Dataflow Operators - use project and location from job in on_kill method. (#18699)

The reason we need this is that we can have a situation where project_id is set to None but is defined in dataflow_default_options. The job will start normally without error, but when we later decide to mark the running task to a different state, we get an error that the job does not exist.
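For context, a minimal sketch of the failing setup described above (the DAG id, template path, bucket names, and project values are hypothetical, not taken from this commit):

from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import (
    DataflowTemplatedJobStartOperator,
)

with models.DAG(
    dag_id="example_dataflow_on_kill",  # hypothetical DAG id
    start_date=datetime(2021, 10, 1),
    schedule_interval=None,
) as dag:
    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start_template_job",
        template="gs://dataflow-templates/latest/Word_Count",  # hypothetical template
        parameters={"inputFile": "gs://my-bucket/in.txt", "output": "gs://my-bucket/out"},
        # The project is supplied only here; the operator's own project_id stays None.
        dataflow_default_options={"project": "my-gcp-project", "region": "europe-west3"},
    )
    # Before this commit, marking the running task to another state invoked
    # on_kill(), which called cancel_job(..., project_id=None) and could not
    # find the job; it now reads projectId and location from the job itself.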
Łukasz Wyszomirski committed Oct 7, 2021
1 parent 6103b26 commit 20df60d
Showing 4 changed files with 97 additions and 45 deletions.
59 changes: 47 additions & 12 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -623,6 +623,7 @@ def start_template_dataflow(
project_id: str,
append_job_name: bool = True,
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
on_new_job_callback: Optional[Callable[[dict], None]] = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
environment: Optional[dict] = None,
) -> dict:
@@ -648,8 +649,10 @@ def start_template_dataflow(
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param append_job_name: True if unique suffix has to be appended to job name.
:type append_job_name: bool
:param on_new_job_id_callback: Callback called when the job ID is known.
:param on_new_job_id_callback: (Deprecated) Callback called when the Job is known.
:type on_new_job_id_callback: callable
:param on_new_job_callback: Callback called when the Job is known.
:type on_new_job_callback: callable
:param location: Job location.
:type location: str
:type environment: Optional, Map of job runtime environment options.
@@ -713,15 +716,24 @@ def start_template_dataflow(
)
response = request.execute(num_retries=self.num_retries)

job_id = response["job"]["id"]
job = response["job"]

if on_new_job_id_callback:
on_new_job_id_callback(job_id)
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
DeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))

if on_new_job_callback:
on_new_job_callback(job)

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
name=name,
job_id=job_id,
job_id=job["id"],
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
@@ -739,6 +751,7 @@ def start_flex_template(
location: str,
project_id: str,
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
on_new_job_callback: Optional[Callable[[dict], None]] = None,
):
"""
Starts flex templates with the Dataflow pipeline.
@@ -750,7 +763,8 @@
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:type project_id: Optional[str]
:param on_new_job_id_callback: A callback that is called when a Job ID is detected.
:param on_new_job_id_callback: (Deprecated) A callback that is called when a Job ID is detected.
:param on_new_job_callback: A callback that is called when a Job is detected.
:return: the Job
"""
service = self.get_conn()
@@ -761,15 +775,23 @@
.launch(projectId=project_id, body=body, location=location)
)
response = request.execute(num_retries=self.num_retries)
job_id = response["job"]["id"]
job = response["job"]

if on_new_job_id_callback:
on_new_job_id_callback(job_id)
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
DeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))

if on_new_job_callback:
on_new_job_callback(job)

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
job_id=job_id,
job_id=job.get("id"),
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
@@ -973,6 +995,7 @@ def start_sql_job(
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
on_new_job_callback: Optional[Callable[[dict], None]] = None,
):
"""
Starts Dataflow SQL query.
@@ -991,8 +1014,10 @@
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:type project_id: Optional[str]
:param on_new_job_id_callback: Callback called when the job ID is known.
:param on_new_job_id_callback: (Deprecated) Callback called when the job ID is known.
:type on_new_job_id_callback: callable
:param on_new_job_callback: Callback called when the job is known.
:type on_new_job_callback: callable
:return: the new job object
"""
cmd = [
@@ -1018,8 +1043,6 @@
job_id = proc.stdout.decode().strip()

self.log.info("Created job ID: %s", job_id)
if on_new_job_id_callback:
on_new_job_id_callback(job_id)

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
@@ -1031,8 +1054,20 @@
drain_pipeline=self.drain_pipeline,
wait_until_finished=self.wait_until_finished,
)
jobs_controller.wait_for_done()
job = jobs_controller.get_jobs(refresh=True)[0]

if on_new_job_id_callback:
warnings.warn(
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
DeprecationWarning,
stacklevel=3,
)
on_new_job_id_callback(job.get("id"))

if on_new_job_callback:
on_new_job_callback(job)

jobs_controller.wait_for_done()
return jobs_controller.get_jobs(refresh=True)[0]

@GoogleBaseHook.fallback_to_default_project_id
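To make the hook-level change above concrete: a caller can now receive the whole job resource instead of only its ID. A minimal migration sketch, assuming an existing DataflowHook; the job name, template path, and project value are placeholders, not from this commit:

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")

# Old style: only the job ID reaches the callback (now raises a DeprecationWarning).
def on_job_id(job_id: str) -> None:
    print("started job:", job_id)

# New style: the full job resource arrives, so id, projectId, and location
# are all available for a later cancel_job() call.
def on_job(job: dict) -> None:
    print("started:", job.get("id"), job.get("projectId"), job.get("location"))

job = hook.start_template_dataflow(
    job_name="example-job",  # placeholder
    variables={},
    parameters={},
    dataflow_template="gs://dataflow-templates/latest/Word_Count",  # placeholder
    project_id="my-gcp-project",  # placeholder
    on_new_job_callback=on_job,
)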
48 changes: 30 additions & 18 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -657,7 +657,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.poll_sleep = poll_sleep
self.job_id = None
self.job = None
self.hook: Optional[DataflowHook] = None
self.impersonation_chain = impersonation_chain
self.environment = environment
@@ -674,8 +674,8 @@ def execute(self, context) -> dict:
wait_until_finished=self.wait_until_finished,
)

def set_current_job_id(job_id):
self.job_id = job_id
def set_current_job(current_job):
self.job = current_job

options = self.dataflow_default_options
options.update(self.options)
@@ -684,7 +684,7 @@ def set_current_job_id(job_id):
variables=options,
parameters=self.parameters,
dataflow_template=self.template,
on_new_job_id_callback=set_current_job_id,
on_new_job_callback=set_current_job,
project_id=self.project_id,
location=self.location,
environment=self.environment,
@@ -694,8 +694,12 @@ def set_current_job_id(job_id):

def on_kill(self) -> None:
self.log.info("On kill.")
if self.job_id:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
if self.job:
self.hook.cancel_job(
job_id=self.job.get("id"),
project_id=self.job.get("projectId"),
location=self.job.get("location"),
)


class DataflowStartFlexTemplateOperator(BaseOperator):
@@ -787,7 +791,7 @@ def __init__(
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
self.job_id = None
self.job = None
self.hook: Optional[DataflowHook] = None

def execute(self, context):
@@ -799,22 +803,26 @@ def execute(self, context):
wait_until_finished=self.wait_until_finished,
)

def set_current_job_id(job_id):
self.job_id = job_id
def set_current_job(current_job):
self.job = current_job

job = self.hook.start_flex_template(
body=self.body,
location=self.location,
project_id=self.project_id,
on_new_job_id_callback=set_current_job_id,
on_new_job_callback=set_current_job,
)

return job

def on_kill(self) -> None:
self.log.info("On kill.")
if self.job_id:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
if self.job:
self.hook.cancel_job(
job_id=self.job.get("id"),
project_id=self.job.get("projectId"),
location=self.job.get("location"),
)


class DataflowStartSqlJobOperator(BaseOperator):
@@ -890,7 +898,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.drain_pipeline = drain_pipeline
self.job_id = None
self.job = None
self.hook: Optional[DataflowHook] = None

def execute(self, context):
@@ -900,24 +908,28 @@ def execute(self, context):
drain_pipeline=self.drain_pipeline,
)

def set_current_job_id(job_id):
self.job_id = job_id
def set_current_job(current_job):
self.job = current_job

job = self.hook.start_sql_job(
job_name=self.job_name,
query=self.query,
options=self.options,
location=self.location,
project_id=self.project_id,
on_new_job_id_callback=set_current_job_id,
on_new_job_callback=set_current_job,
)

return job

def on_kill(self) -> None:
self.log.info("On kill.")
if self.job_id:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
if self.job:
self.hook.cancel_job(
job_id=self.job.get("id"),
project_id=self.job.get("projectId"),
location=self.job.get("location"),
)


class DataflowCreatePythonJobOperator(BaseOperator):
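The operator-side change is the same in all three operators: keep the whole job resource handed to the callback and read id, projectId, and location from it in on_kill. A condensed sketch of that pattern, assuming cancel_job accepts the keyword arguments shown in the diff:

from typing import Optional

class OnKillPattern:
    """Condensed illustration of the pattern used by the three operators."""

    def __init__(self) -> None:
        self.job: Optional[dict] = None  # full job resource, not just the ID
        self.hook = None  # DataflowHook, set in execute()

    def set_current_job(self, current_job: dict) -> None:
        # Passed to the hook as on_new_job_callback.
        self.job = current_job

    def on_kill(self) -> None:
        # Use the project and location recorded on the job itself, so an
        # operator-level project_id of None no longer breaks cancellation.
        if self.job:
            self.hook.cancel_job(
                job_id=self.job.get("id"),
                project_id=self.job.get("projectId"),
                location=self.job.get("location"),
            )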
17 changes: 10 additions & 7 deletions tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1016,19 +1016,21 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_start_flex_template(self, mock_conn, mock_controller):
expected_job = {"id": TEST_JOB_ID}

mock_locations = mock_conn.return_value.projects.return_value.locations
launch_method = mock_locations.return_value.flexTemplates.return_value.launch
launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
launch_method.return_value.execute.return_value = {"job": expected_job}
mock_controller.return_value.get_jobs.return_value = [{"id": TEST_JOB_ID}]

on_new_job_id_callback = mock.MagicMock()
on_new_job_callback = mock.MagicMock()
result = self.dataflow_hook.start_flex_template(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
project_id=TEST_PROJECT_ID,
on_new_job_id_callback=on_new_job_id_callback,
on_new_job_callback=on_new_job_callback,
)
on_new_job_id_callback.assert_called_once_with(TEST_JOB_ID)
on_new_job_callback.assert_called_once_with(expected_job)
launch_method.assert_called_once_with(
projectId='test-project-id',
body={'launchParameter': TEST_FLEX_PARAMETERS},
@@ -1080,14 +1082,15 @@ def test_start_sql_job_failed_to_run(
mock_run.return_value = mock.MagicMock(
stdout=f"{TEST_JOB_ID}\n".encode(), stderr=f"{TEST_JOB_ID}\n".encode(), returncode=0
)
on_new_job_id_callback = mock.MagicMock()
on_new_job_callback = mock.MagicMock()

result = self.dataflow_hook.start_sql_job(
job_name=TEST_SQL_JOB_NAME,
query=TEST_SQL_QUERY,
options=TEST_SQL_OPTIONS,
location=TEST_LOCATION,
project_id=TEST_PROJECT,
on_new_job_id_callback=on_new_job_id_callback,
on_new_job_callback=on_new_job_callback,
)
mock_run.assert_called_once_with(
[
@@ -1135,7 +1138,7 @@ def test_start_sql_job(self, mock_run, mock_provide_authorized_gcloud, mock_get_
options=TEST_SQL_OPTIONS,
location=TEST_LOCATION,
project_id=TEST_PROJECT,
on_new_job_id_callback=mock.MagicMock(),
on_new_job_callback=mock.MagicMock(),
)


18 changes: 10 additions & 8 deletions tests/providers/google/cloud/operators/test_dataflow.py
@@ -89,7 +89,7 @@
bigquery.table.test-project.beam_samples.beam_table
GROUP BY sales_region;
"""
TEST_SQL_JOB_ID = 'test-job-id'
TEST_SQL_JOB = {'id': 'test-job-id'}


class TestDataflowPythonOperator(unittest.TestCase):
@@ -410,7 +410,7 @@ def test_exec(self, dataflow_mock):
variables=expected_options,
parameters=PARAMETERS,
dataflow_template=TEMPLATE,
on_new_job_id_callback=mock.ANY,
on_new_job_callback=mock.ANY,
project_id=None,
location=TEST_LOCATION,
environment={'maxWorkers': 2},
@@ -432,7 +432,7 @@ def test_execute(self, mock_dataflow):
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
project_id=TEST_PROJECT,
on_new_job_id_callback=mock.ANY,
on_new_job_callback=mock.ANY,
)

def test_on_kill(self):
Expand All @@ -444,10 +444,10 @@ def test_on_kill(self):
project_id=TEST_PROJECT,
)
start_flex_template.hook = mock.MagicMock()
start_flex_template.job_id = JOB_ID
start_flex_template.job = {"id": JOB_ID, "projectId": TEST_PROJECT, "location": TEST_LOCATION}
start_flex_template.on_kill()
start_flex_template.hook.cancel_job.assert_called_once_with(
job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT
job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT, location=TEST_LOCATION
)


@@ -473,8 +473,10 @@ def test_execute(self, mock_hook):
options=TEST_SQL_OPTIONS,
location=TEST_LOCATION,
project_id=None,
on_new_job_id_callback=mock.ANY,
on_new_job_callback=mock.ANY,
)
start_sql.job_id = TEST_SQL_JOB_ID
start_sql.job = TEST_SQL_JOB
start_sql.on_kill()
mock_hook.return_value.cancel_job.assert_called_once_with(job_id='test-job-id', project_id=None)
mock_hook.return_value.cancel_job.assert_called_once_with(
job_id='test-job-id', project_id=None, location=None
)
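For completeness, a standalone check in the same unittest.mock style as the tests above, asserting that on_kill forwards project and location from the stored job (the task id and values are illustrative):

from unittest import mock

from airflow.providers.google.cloud.operators.dataflow import (
    DataflowStartFlexTemplateOperator,
)


def test_on_kill_uses_job_project_and_location():
    op = DataflowStartFlexTemplateOperator(
        task_id="start_flex",
        body={"launchParameter": {}},  # placeholder flex-template body
        location="europe-west3",
        project_id="test-project",
    )
    op.hook = mock.MagicMock()
    op.job = {"id": "job-id", "projectId": "test-project", "location": "europe-west3"}
    op.on_kill()
    op.hook.cancel_job.assert_called_once_with(
        job_id="job-id", project_id="test-project", location="europe-west3"
    )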
