
Commit 094d6bf

MrGeorgeOwl and Heorhi Parkhomenka authored
Add deferrable mode to dataflow operators (#27776)
* Add deferrable mode to DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator operators
* Change project_id param to be optional; add fixes for tests and docs build
* Add comment about upper bound for google-cloud-dataflow-client lib and change warning message

Co-authored-by: Heorhi Parkhomenka <[email protected]>
1 parent 5e470c1 commit 094d6bf
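The headline change is the new deferrable mode on the two template operators. As a rough usage sketch (hedged: the `deferrable` flag follows the Google provider's usual convention for deferrable operators, and the output bucket below is a placeholder, not taken from this commit):

import os

from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

# Hypothetical sketch: with deferrable=True the operator submits the Dataflow job,
# releases its worker slot, and waits for completion on the triggerer, polling job
# state through the AsyncDataflowHook introduced in this commit.
start_template_job = DataflowTemplatedJobStartOperator(
    task_id="start-template-job-deferrable",
    project_id=os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default"),
    template="gs://dataflow-templates/latest/Word_Count",
    parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": "gs://my-bucket/output"},
    location="europe-west3",
    deferrable=True,  # assumed flag name, per the provider's deferrable-operator convention
)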

File tree

12 files changed (+769, −92 lines)

airflow/providers/google/cloud/example_dags/example_dataflow.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,7 @@
 GCS_OUTPUT = os.environ.get("GCP_DATAFLOW_GCS_OUTPUT", "gs://INVALID BUCKET NAME/output")
 GCS_JAR = os.environ.get("GCP_DATAFLOW_JAR", "gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar")
 GCS_PYTHON = os.environ.get("GCP_DATAFLOW_PYTHON", "gs://INVALID BUCKET NAME/wordcount_debugging.py")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
 
 GCS_JAR_PARTS = urlsplit(GCS_JAR)
 GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc
@@ -257,6 +258,7 @@ def check_autoscaling_event(autoscaling_events: list[dict]) -> bool:
     # [START howto_operator_start_template_job]
     start_template_job = DataflowTemplatedJobStartOperator(
         task_id="start-template-job",
+        project_id=PROJECT_ID,
         template="gs://dataflow-templates/latest/Word_Count",
         parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
         location="europe-west3",
@@ -279,6 +281,7 @@ def check_autoscaling_event(autoscaling_events: list[dict]) -> bool:
     # [END howto_operator_stop_dataflow_job]
     start_template_job = DataflowTemplatedJobStartOperator(
         task_id="start-template-job",
+        project_id=PROJECT_ID,
         template="gs://dataflow-templates/latest/Word_Count",
         parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
         location="europe-west3",

airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     # [START howto_operator_start_template_job]
     start_flex_template = DataflowStartFlexTemplateOperator(
         task_id="start_flex_template_streaming_beam_sql",
+        project_id=GCP_PROJECT_ID,
         body={
             "launchParameter": {
                 "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,

airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 132 additions & 34 deletions
@@ -29,11 +29,16 @@
 from copy import deepcopy
 from typing import Any, Callable, Generator, Sequence, TypeVar, cast
 
+from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
 from googleapiclient.discovery import build
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timeout import timeout
 
@@ -645,36 +650,10 @@ def start_template_dataflow(
         """
         name = self.build_dataflow_job_name(job_name, append_job_name)
 
-        environment = environment or {}
-        # available keys for runtime environment are listed here:
-        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
-        environment_keys = [
-            "numWorkers",
-            "maxWorkers",
-            "zone",
-            "serviceAccountEmail",
-            "tempLocation",
-            "bypassTempDirValidation",
-            "machineType",
-            "additionalExperiments",
-            "network",
-            "subnetwork",
-            "additionalUserLabels",
-            "kmsKeyName",
-            "ipConfiguration",
-            "workerRegion",
-            "workerZone",
-        ]
-
-        for key in variables:
-            if key in environment_keys:
-                if key in environment:
-                    self.log.warning(
-                        "'%s' parameter in 'variables' will override of "
-                        "the same one passed in 'environment'!",
-                        key,
-                    )
-                environment.update({key: variables[key]})
+        environment = self._update_environment(
+            variables=variables,
+            environment=environment,
+        )
 
         service = self.get_conn()
 
@@ -723,6 +702,40 @@ def start_template_dataflow(
         jobs_controller.wait_for_done()
         return response["job"]
 
+    def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
+        environment = environment or {}
+        # available keys for runtime environment are listed here:
+        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+        environment_keys = {
+            "numWorkers",
+            "maxWorkers",
+            "zone",
+            "serviceAccountEmail",
+            "tempLocation",
+            "bypassTempDirValidation",
+            "machineType",
+            "additionalExperiments",
+            "network",
+            "subnetwork",
+            "additionalUserLabels",
+            "kmsKeyName",
+            "ipConfiguration",
+            "workerRegion",
+            "workerZone",
+        }
+
+        def _check_one(key, val):
+            if key in environment:
+                self.log.warning(
+                    "%r parameter in 'variables' will override the same one passed in 'environment'!",
+                    key,
+                )
+            return key, val
+
+        environment.update(_check_one(key, val) for key, val in variables.items() if key in environment_keys)
+
+        return environment
+
     @GoogleBaseHook.fallback_to_default_project_id
     def start_flex_template(
         self,
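The extracted `_update_environment` helper above preserves the old precedence rule: a runtime-environment key supplied in `variables` wins over the same key in `environment`, with a warning. A minimal standalone sketch of that merge logic (plain Python, no Airflow imports; `log.warning` stands in for the hook's logger, and the key set is abridged):

import logging

log = logging.getLogger(__name__)

ENVIRONMENT_KEYS = {"numWorkers", "maxWorkers", "zone", "tempLocation"}  # abridged for illustration

def update_environment(variables: dict, environment: dict | None = None) -> dict:
    # Standalone rendition of DataflowHook._update_environment, for illustration only.
    environment = dict(environment or {})
    for key, val in variables.items():
        if key in ENVIRONMENT_KEYS:
            if key in environment:
                log.warning("%r parameter in 'variables' will override the same one passed in 'environment'!", key)
            environment[key] = val
    return environment

# 'numWorkers' from variables overrides the value in environment and logs a warning:
merged = update_environment({"numWorkers": 5}, {"numWorkers": 2, "zone": "us-central1-a"})
assert merged == {"numWorkers": 5, "zone": "us-central1-a"}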
@@ -731,9 +744,9 @@ def start_flex_template(
         project_id: str,
         on_new_job_id_callback: Callable[[str], None] | None = None,
         on_new_job_callback: Callable[[dict], None] | None = None,
-    ):
+    ) -> dict:
         """
-        Starts flex templates with the Dataflow pipeline.
+        Starts flex templates with the Dataflow pipeline.
 
         :param body: The request body. See:
             https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -1041,7 +1054,7 @@ def start_sql_job(
     def get_job(
         self,
         job_id: str,
-        project_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str = DEFAULT_DATAFLOW_LOCATION,
     ) -> dict:
         """
@@ -1169,3 +1182,88 @@ def wait_for_done(
             wait_until_finished=self.wait_until_finished,
         )
         job_controller.wait_for_done()
+
+
+class AsyncDataflowHook(GoogleBaseAsyncHook):
+    """Async hook class for the Dataflow service."""
+
+    sync_hook_class = DataflowHook
+
+    async def initialize_client(self, client_class):
+        """
+        Initialize an object of the given class.
+
+        This method is used to initialize an asynchronous client. Because a large number of classes are
+        used for the Dataflow service, it was decided to initialize them all the same way, with credentials
+        obtained from the GoogleBaseHook method.
+
+        :param client_class: Class of the Google Cloud SDK
+        """
+        credentials = (await self.get_sync_hook()).get_credentials()
+        return client_class(
+            credentials=credentials,
+        )
+
+    async def get_project_id(self) -> str:
+        project_id = (await self.get_sync_hook()).project_id
+        return project_id
+
+    async def get_job(
+        self,
+        job_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        job_view: int = JobView.JOB_VIEW_SUMMARY,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> Job:
+        """
+        Gets the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :param project_id: The Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param job_view: Optional. JobView object which determines the representation of the returned data.
+        :param location: Optional. The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        """
+        project_id = project_id or (await self.get_project_id())
+        client = await self.initialize_client(JobsV1Beta3AsyncClient)
+
+        request = GetJobRequest(
+            dict(
+                project_id=project_id,
+                job_id=job_id,
+                view=job_view,
+                location=location,
+            )
+        )
+
+        job = await client.get_job(
+            request=request,
+        )
+
+        return job
+
+    async def get_job_status(
+        self,
+        job_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        job_view: int = JobView.JOB_VIEW_SUMMARY,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> JobState:
+        """
+        Gets the job status with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :param project_id: The Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param job_view: Optional. JobView object which determines the representation of the returned data.
+        :param location: Optional. The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        """
+        job = await self.get_job(
+            project_id=project_id,
+            job_id=job_id,
+            job_view=job_view,
+            location=location,
+        )
+        state = job.current_state
+        return state
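These two methods are what a deferrable operator's trigger can poll from the triggerer process. A hedged sketch of such a polling loop (the loop itself, the sleep interval, and the terminal-state set are assumptions for illustration, not the trigger shipped in this commit; JobState comes from google.cloud.dataflow_v1beta3 as imported above):

import asyncio

from google.cloud.dataflow_v1beta3 import JobState

from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook

async def wait_for_job(job_id: str, location: str, poll_sleep: int = 10) -> JobState:
    # Poll the async hook until the Dataflow job reaches a terminal state,
    # sleeping between checks so the triggerer event loop stays free.
    hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
    terminal = {JobState.JOB_STATE_DONE, JobState.JOB_STATE_FAILED, JobState.JOB_STATE_CANCELLED}  # abridged
    while True:
        state = await hook.get_job_status(job_id=job_id, location=location)
        if state in terminal:
            return state
        await asyncio.sleep(poll_sleep)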

airflow/providers/google/cloud/links/dataflow.py

Lines changed: 1 addition & 1 deletion
@@ -48,5 +48,5 @@ def persist(
         operator_instance.xcom_push(
             context,
             key=DataflowJobLink.key,
-            value={"project_id": project_id, "location": region, "job_id": job_id},
+            value={"project_id": project_id, "region": region, "job_id": job_id},
         )
