 from copy import deepcopy
 from typing import Any, Callable, Generator, Sequence, TypeVar, cast
 
+from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
 from googleapiclient.discovery import build
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timeout import timeout
 
@@ -645,36 +650,10 @@ def start_template_dataflow(
         """
         name = self.build_dataflow_job_name(job_name, append_job_name)
 
-        environment = environment or {}
-        # available keys for runtime environment are listed here:
-        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
-        environment_keys = [
-            "numWorkers",
-            "maxWorkers",
-            "zone",
-            "serviceAccountEmail",
-            "tempLocation",
-            "bypassTempDirValidation",
-            "machineType",
-            "additionalExperiments",
-            "network",
-            "subnetwork",
-            "additionalUserLabels",
-            "kmsKeyName",
-            "ipConfiguration",
-            "workerRegion",
-            "workerZone",
-        ]
-
-        for key in variables:
-            if key in environment_keys:
-                if key in environment:
-                    self.log.warning(
-                        "'%s' parameter in 'variables' will override of "
-                        "the same one passed in 'environment'!",
-                        key,
-                    )
-                environment.update({key: variables[key]})
+        environment = self._update_environment(
+            variables=variables,
+            environment=environment,
+        )
 
         service = self.get_conn()
 
@@ -723,6 +702,40 @@ def start_template_dataflow(
         jobs_controller.wait_for_done()
         return response["job"]
 
+    def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
+        environment = environment or {}
+        # available keys for runtime environment are listed here:
+        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+        environment_keys = {
+            "numWorkers",
+            "maxWorkers",
+            "zone",
+            "serviceAccountEmail",
+            "tempLocation",
+            "bypassTempDirValidation",
+            "machineType",
+            "additionalExperiments",
+            "network",
+            "subnetwork",
+            "additionalUserLabels",
+            "kmsKeyName",
+            "ipConfiguration",
+            "workerRegion",
+            "workerZone",
+        }
+
+        def _check_one(key, val):
+            if key in environment:
+                self.log.warning(
+                    "%r parameter in 'variables' will override the same one passed in 'environment'!",
+                    key,
+                )
+            return key, val
+
+        environment.update(_check_one(key, val) for key, val in variables.items() if key in environment_keys)
+
+        return environment
+
     @GoogleBaseHook.fallback_to_default_project_id
     def start_flex_template(
         self,
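
A rough sketch of the merge behaviour the extracted `_update_environment` helper provides. The input values, bucket name, and the direct hook instantiation below are illustrative only and are not part of this change:

```python
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

# Hypothetical inputs: only keys that are valid RuntimeEnvironment keys are merged into
# `environment`, and a warning is logged when a key from `variables` overrides one that
# was already passed in `environment`.
variables = {"maxWorkers": 10, "tempLocation": "gs://example-bucket/tmp", "jobName": "ignored-key"}
environment = {"maxWorkers": 2}

hook = DataflowHook()
merged = hook._update_environment(variables=variables, environment=environment)
# merged == {"maxWorkers": 10, "tempLocation": "gs://example-bucket/tmp"}
# "maxWorkers" triggers the override warning because it was already set in `environment`;
# "jobName" is silently dropped because it is not a RuntimeEnvironment key.
```
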
@@ -731,9 +744,9 @@ def start_flex_template(
         project_id: str,
         on_new_job_id_callback: Callable[[str], None] | None = None,
         on_new_job_callback: Callable[[dict], None] | None = None,
-    ):
+    ) -> dict:
         """
-        Starts flex templates with the Dataflow pipeline.
+        Starts flex templates with the Dataflow pipeline.
 
         :param body: The request body. See:
             https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
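
For reference, a minimal `body` for `start_flex_template` might look like the sketch below. The bucket, job name, and parameter names are hypothetical; the authoritative list of accepted fields is the launch request body documentation linked above:

```python
# Hypothetical flex template launch body (see the request-body link above for all fields).
body = {
    "launchParameter": {
        "jobName": "example-flex-job",
        "containerSpecGcsPath": "gs://example-bucket/templates/streaming-template.json",
        "parameters": {
            "inputSubscription": "projects/example-project/subscriptions/example-sub",
            "outputTable": "example-project:example_dataset.example_table",
        },
    }
}
```
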
@@ -1041,7 +1054,7 @@ def start_sql_job(
     def get_job(
         self,
         job_id: str,
-        project_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str = DEFAULT_DATAFLOW_LOCATION,
     ) -> dict:
         """
@@ -1169,3 +1182,88 @@ def wait_for_done(
             wait_until_finished=self.wait_until_finished,
         )
         job_controller.wait_for_done()
+
+
+class AsyncDataflowHook(GoogleBaseAsyncHook):
+    """Async hook class for dataflow service."""
+
+    sync_hook_class = DataflowHook
+
+    async def initialize_client(self, client_class):
+        """
+        Initialize an object of the given client class.
+
+        The method is used to initialize an asynchronous client. Because many different client classes
+        are used for the Dataflow service, they are all initialized in the same way, with credentials
+        obtained from the synchronous hook (GoogleBaseHook).
+        :param client_class: Class of the Google Cloud SDK client to initialize
+        """
+        credentials = (await self.get_sync_hook()).get_credentials()
+        return client_class(
+            credentials=credentials,
+        )
+
+    async def get_project_id(self) -> str:
+        project_id = (await self.get_sync_hook()).project_id
+        return project_id
+
+    async def get_job(
+        self,
+        job_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        job_view: int = JobView.JOB_VIEW_SUMMARY,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> Job:
+        """
+        Gets the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :param project_id: the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param job_view: Optional. JobView object which determines representation of the returned data
+        :param location: Optional. The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        """
+        project_id = project_id or (await self.get_project_id())
+        client = await self.initialize_client(JobsV1Beta3AsyncClient)
+
+        request = GetJobRequest(
+            dict(
+                project_id=project_id,
+                job_id=job_id,
+                view=job_view,
+                location=location,
+            )
+        )
+
+        job = await client.get_job(
+            request=request,
+        )
+
+        return job
+
+    async def get_job_status(
+        self,
+        job_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        job_view: int = JobView.JOB_VIEW_SUMMARY,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> JobState:
+        """
+        Gets the job status with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :param project_id: the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param job_view: Optional. JobView object which determines representation of the returned data
+        :param location: Optional. The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        """
+        job = await self.get_job(
+            project_id=project_id,
+            job_id=job_id,
+            job_view=job_view,
+            location=location,
+        )
+        state = job.current_state
+        return state
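
A short usage sketch of the new async hook, for example from a deferrable trigger's async context. The polling loop, connection id, and interval are illustrative and not part of this diff:

```python
import asyncio

from google.cloud.dataflow_v1beta3 import JobState

from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook


async def wait_for_terminal_state(job_id: str, project_id: str, location: str) -> JobState:
    """Poll the async hook until the Dataflow job reaches a terminal state."""
    hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
    terminal_states = {
        JobState.JOB_STATE_DONE,
        JobState.JOB_STATE_FAILED,
        JobState.JOB_STATE_CANCELLED,
    }
    while True:
        state = await hook.get_job_status(job_id=job_id, project_id=project_id, location=location)
        if state in terminal_states:
            return state
        await asyncio.sleep(30)
```
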