Skip to content

Commit

Permalink
Support all RuntimeEnvironment parameters in DataflowTemplatedJobStar…
Browse files Browse the repository at this point in the history
…tOperator (#8531)
  • Loading branch information
mik-laj committed May 6, 2020
1 parent 520aeed commit 25ee421
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 120 deletions.
26 changes: 13 additions & 13 deletions airflow/providers/google/cloud/hooks/dataflow.py
Expand Up @@ -532,7 +532,13 @@ def start_template_dataflow(
:param job_name: The name of the job.
:type job_name: str
:param variables: Variables passed to the job.
:param variables: Map of job runtime environment options.
.. seealso::
For more information on possible configurations, look at the API documentation
`RuntimeEnvironment REST API reference
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:type variables: dict
:param parameters: Parameters for the template
:type parameters: dict
Expand All @@ -548,23 +554,17 @@ def start_template_dataflow(
:type location: str
"""
name = self._build_dataflow_job_name(job_name, append_job_name)
# Builds RuntimeEnvironment from variables dictionary
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
environment = {}
for key in ['numWorkers', 'maxWorkers', 'zone', 'serviceAccountEmail',
'tempLocation', 'bypassTempDirValidation', 'machineType',
'additionalExperiments', 'network', 'subnetwork', 'additionalUserLabels']:
if key in variables:
environment.update({key: variables[key]})
body = {"jobName": name,
"parameters": parameters,
"environment": environment}

service = self.get_conn()
request = service.projects().locations().templates().launch( # pylint: disable=no-member
projectId=project_id,
location=location,
gcsPath=dataflow_template,
body=body
body={
"jobName": name,
"parameters": parameters,
"environment": variables
}
)
response = request.execute(num_retries=self.num_retries)

Expand Down
40 changes: 28 additions & 12 deletions airflow/providers/google/cloud/operators/dataflow.py
Expand Up @@ -23,7 +23,7 @@
import re
from contextlib import ExitStack
from enum import Enum
from typing import List, Optional
from typing import Any, Dict, List, Optional

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION, DataflowHook
Expand Down Expand Up @@ -277,6 +277,14 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
:type template: str
:param job_name: The 'jobName' to use when executing the DataFlow template
(templated).
:param options: Map of job runtime environment options.
.. seealso::
For more information on possible configurations, look at the API documentation
`RuntimeEnvironment REST API reference
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:type options: dict
:param dataflow_default_options: Map of default job environment options.
:type dataflow_default_options: dict
:param parameters: Map of job specific parameters for the template.
Expand Down Expand Up @@ -344,16 +352,25 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
For more detail on job template execution have a look at the reference:
https://cloud.google.com/dataflow/docs/templates/executing-templates
"""
template_fields = ['parameters', 'dataflow_default_options', 'template', 'job_name']
template_fields = [
'template',
'job_name',
'options',
'parameters',
'project_id',
'location',
'gcp_conn_id'
]
ui_color = '#0273d4'

@apply_defaults
def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
template: str,
job_name: str = '{{task.task_id}}',
dataflow_default_options: Optional[dict] = None,
parameters: Optional[dict] = None,
options: Optional[Dict[str, Any]] = None,
dataflow_default_options: Optional[Dict[str, Any]] = None,
parameters: Optional[Dict[str, str]] = None,
project_id: Optional[str] = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = 'google_cloud_default',
Expand All @@ -362,14 +379,11 @@ def __init__(
*args,
**kwargs) -> None:
super().__init__(*args, **kwargs)

dataflow_default_options = dataflow_default_options or {}
parameters = parameters or {}

self.template = template
self.job_name = job_name
self.dataflow_default_options = dataflow_default_options
self.parameters = parameters
self.options = options or {}
self.dataflow_default_options = dataflow_default_options or {}
self.parameters = parameters or {}
self.project_id = project_id
self.location = location
self.gcp_conn_id = gcp_conn_id
Expand All @@ -387,10 +401,12 @@ def execute(self, context):

def set_current_job_id(job_id):
self.job_id = job_id
options = self.dataflow_default_options
options.update(self.options)

job = self.hook.start_template_dataflow(
job_name=self.job_name,
variables=self.dataflow_default_options,
variables=options,
parameters=self.parameters,
dataflow_template=self.template,
on_new_job_id_callback=set_current_job_id,
Expand Down

0 comments on commit 25ee421

Please sign in to comment.