Add deferrable mode to dataflow operators (#27776)
* Add deferrable mode to DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator operators

* Change project_id param to be optional, add fixes for tests and docs build

* Add comment about upper-bound for google-cloud-dataflow-client lib and
change warning message

---------

Co-authored-by: Heorhi Parkhomenka <[email protected]>
MrGeorgeOwl and Heorhi Parkhomenka committed Jan 30, 2023
1 parent 5e470c1 commit 094d6bf
Showing 12 changed files with 769 additions and 92 deletions.
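
For orientation, a minimal DAG sketch of how the new deferrable mode might be used. The operator changes themselves are not part of the hunks shown on this page, so the deferrable flag below is an assumption about the interface this PR describes; the project and bucket values are placeholders.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

with DAG("example_dataflow_deferrable", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    # `deferrable=True` is assumed here; it matches the PR title, but the operator hunks
    # are not shown on this page.
    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start-template-job",
        project_id="my-gcp-project",  # placeholder project ID
        template="gs://dataflow-templates/latest/Word_Count",
        parameters={
            "inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt",
            "output": "gs://my-bucket/output",  # placeholder bucket
        },
        location="europe-west3",
        deferrable=True,  # frees the worker slot while the Dataflow job runs
    )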
@@ -52,6 +52,7 @@
GCS_OUTPUT = os.environ.get("GCP_DATAFLOW_GCS_OUTPUT", "gs://INVALID BUCKET NAME/output")
GCS_JAR = os.environ.get("GCP_DATAFLOW_JAR", "gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar")
GCS_PYTHON = os.environ.get("GCP_DATAFLOW_PYTHON", "gs://INVALID BUCKET NAME/wordcount_debugging.py")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")

GCS_JAR_PARTS = urlsplit(GCS_JAR)
GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc
@@ -257,6 +258,7 @@ def check_autoscaling_event(autoscaling_events: list[dict]) -> bool:
# [START howto_operator_start_template_job]
start_template_job = DataflowTemplatedJobStartOperator(
task_id="start-template-job",
project_id=PROJECT_ID,
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
location="europe-west3",
@@ -279,6 +281,7 @@ def check_autoscaling_event(autoscaling_events: list[dict]) -> bool:
# [END howto_operator_stop_dataflow_job]
start_template_job = DataflowTemplatedJobStartOperator(
task_id="start-template-job",
project_id=PROJECT_ID,
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
location="europe-west3",
@@ -52,6 +52,7 @@
# [START howto_operator_start_template_job]
start_flex_template = DataflowStartFlexTemplateOperator(
task_id="start_flex_template_streaming_beam_sql",
project_id=GCP_PROJECT_ID,
body={
"launchParameter": {
"containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
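
For context, a hedged sketch of a complete `body` argument for DataflowStartFlexTemplateOperator. The field names follow the flex-template launch request documented at https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch; the paths, subscription, and table names are placeholders, and the `parameters` keys depend on the specific template.

body = {
    "launchParameter": {
        "containerSpecGcsPath": "gs://my-bucket/templates/streaming-beam-sql.json",  # placeholder path
        "jobName": "streaming-beam-sql",
        "parameters": {
            # template-specific parameters; both values are placeholders
            "inputSubscription": "projects/my-project/subscriptions/my-subscription",
            "outputTable": "my-project:my_dataset.my_table",
        },
        "environment": {"tempLocation": "gs://my-bucket/tmp"},  # placeholder temp location
    }
}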
166 changes: 132 additions & 34 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -29,11 +29,16 @@
from copy import deepcopy
from typing import Any, Callable, Generator, Sequence, TypeVar, cast

from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
from googleapiclient.discovery import build

from airflow.exceptions import AirflowException
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timeout import timeout

@@ -645,36 +650,10 @@ def start_template_dataflow(
"""
name = self.build_dataflow_job_name(job_name, append_job_name)

environment = environment or {}
# available keys for runtime environment are listed here:
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
environment_keys = [
"numWorkers",
"maxWorkers",
"zone",
"serviceAccountEmail",
"tempLocation",
"bypassTempDirValidation",
"machineType",
"additionalExperiments",
"network",
"subnetwork",
"additionalUserLabels",
"kmsKeyName",
"ipConfiguration",
"workerRegion",
"workerZone",
]

for key in variables:
if key in environment_keys:
if key in environment:
self.log.warning(
"'%s' parameter in 'variables' will override of "
"the same one passed in 'environment'!",
key,
)
environment.update({key: variables[key]})
environment = self._update_environment(
variables=variables,
environment=environment,
)

service = self.get_conn()

@@ -723,6 +702,40 @@ def start_template_dataflow(
jobs_controller.wait_for_done()
return response["job"]

def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
environment = environment or {}
# available keys for runtime environment are listed here:
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
environment_keys = {
"numWorkers",
"maxWorkers",
"zone",
"serviceAccountEmail",
"tempLocation",
"bypassTempDirValidation",
"machineType",
"additionalExperiments",
"network",
"subnetwork",
"additionalUserLabels",
"kmsKeyName",
"ipConfiguration",
"workerRegion",
"workerZone",
}

def _check_one(key, val):
if key in environment:
self.log.warning(
"%r parameter in 'variables' will override the same one passed in 'environment'!",
key,
)
return key, val

environment.update(_check_one(key, val) for key, val in variables.items() if key in environment_keys)

return environment
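
A standalone illustration of the merge rule the `_update_environment` helper above implements: runtime-environment keys found in `variables` are copied into `environment`, and a key present in both takes the value from `variables` while the hook logs a warning. The dictionaries below are made-up inputs, not values from this commit.

environment_keys = {"numWorkers", "maxWorkers", "tempLocation", "serviceAccountEmail"}  # subset, for brevity

variables = {"maxWorkers": 10, "jobName": "wordcount", "tempLocation": "gs://my-bucket/tmp"}
environment = {"maxWorkers": 2}

environment.update((k, v) for k, v in variables.items() if k in environment_keys)
print(environment)  # {'maxWorkers': 10, 'tempLocation': 'gs://my-bucket/tmp'}  ('jobName' is not an environment key)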

@GoogleBaseHook.fallback_to_default_project_id
def start_flex_template(
self,
@@ -731,9 +744,9 @@ def start_flex_template(
project_id: str,
on_new_job_id_callback: Callable[[str], None] | None = None,
on_new_job_callback: Callable[[dict], None] | None = None,
):
) -> dict:
"""
Starts flex templates with the Dataflow pipeline.
Starts flex templates with the Dataflow pipeline.
:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -1041,7 +1054,7 @@ def start_sql_job(
def get_job(
self,
job_id: str,
project_id: str,
project_id: str = PROVIDE_PROJECT_ID,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
"""
@@ -1169,3 +1182,88 @@ def wait_for_done(
wait_until_finished=self.wait_until_finished,
)
job_controller.wait_for_done()


class AsyncDataflowHook(GoogleBaseAsyncHook):
"""Async hook class for dataflow service."""

sync_hook_class = DataflowHook

async def initialize_client(self, client_class):
"""
Initialize an object of the given class.
This method is used to create asynchronous clients: because the Dataflow service relies on many
client classes, they are all initialized in the same way, with credentials obtained from the
underlying GoogleBaseHook.
:param client_class: Class of the Google Cloud SDK client to instantiate
"""
credentials = (await self.get_sync_hook()).get_credentials()
return client_class(
credentials=credentials,
)

async def get_project_id(self) -> str:
project_id = (await self.get_sync_hook()).project_id
return project_id

async def get_job(
self,
job_id: str,
project_id: str = PROVIDE_PROJECT_ID,
job_view: int = JobView.JOB_VIEW_SUMMARY,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> Job:
"""
Gets the job with the specified Job ID.
:param job_id: Job ID to get.
:param project_id: the Google Cloud project ID in which the job was started.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param job_view: Optional. JobView object which determines representation of the returned data
:param location: Optional. The location of the Dataflow job (for example europe-west1). See:
https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
"""
project_id = project_id or (await self.get_project_id())
client = await self.initialize_client(JobsV1Beta3AsyncClient)

request = GetJobRequest(
dict(
project_id=project_id,
job_id=job_id,
view=job_view,
location=location,
)
)

job = await client.get_job(
request=request,
)

return job

async def get_job_status(
self,
job_id: str,
project_id: str = PROVIDE_PROJECT_ID,
job_view: int = JobView.JOB_VIEW_SUMMARY,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> JobState:
"""
Gets the job status with the specified Job ID.
:param job_id: Job ID to get.
:param project_id: the Google Cloud project ID in which the job was started.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param job_view: Optional. JobView object which determines representation of the returned data
:param location: Optional. The location of the Dataflow job (for example europe-west1). See:
https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
"""
job = await self.get_job(
project_id=project_id,
job_id=job_id,
job_view=job_view,
location=location,
)
state = job.current_state
return state
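
A minimal sketch of how a deferrable trigger could poll the async hook above. The trigger class this PR adds is not shown on this page, so the loop, connection ID, and terminal-state set below are illustrative assumptions rather than the committed implementation.

import asyncio

from google.cloud.dataflow_v1beta3 import JobState

from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook

TERMINAL_STATES = {
    JobState.JOB_STATE_DONE,
    JobState.JOB_STATE_FAILED,
    JobState.JOB_STATE_CANCELLED,
}


async def wait_for_job(job_id: str, project_id: str, location: str, poll_sleep: int = 10) -> JobState:
    """Poll the Dataflow job until it reaches a terminal state and return that state."""
    hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")  # assumed connection ID
    while True:
        state = await hook.get_job_status(job_id=job_id, project_id=project_id, location=location)
        if state in TERMINAL_STATES:
            return state
        await asyncio.sleep(poll_sleep)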
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def persist(
operator_instance.xcom_push(
context,
key=DataflowJobLink.key,
value={"project_id": project_id, "location": region, "job_id": job_id},
value={"project_id": project_id, "region": region, "job_id": job_id},
)
