Add drain option when canceling Dataflow pipelines (#11374)
* Add drain option when canceling Dataflow pipelines

* fixup! Add drain option when canceling Dataflow pipelines

* fixup! fixup! Add drain option when canceling Dataflow pipelines

* fixup! fixup! fixup! Add drain option when canceling Dataflow pipelines
Tobiasz Kędzierski committed Oct 29, 2020
1 parent 039a86b commit e5713e0
Showing 4 changed files with 101 additions and 9 deletions.
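
At the API level, the change comes down to which ``requestedState`` the hook sends to the Dataflow ``projects.locations.jobs.update`` endpoint. A minimal sketch of the two request bodies (mirroring the hook code in the diff below):

# Request body sent when stopping a job (sketch; mirrors the hook code below)
cancel_body = {"requestedState": "JOB_STATE_CANCELLED"}  # default: cancel
drain_body = {"requestedState": "JOB_STATE_DRAINED"}     # drain_pipeline=True and the job is streaming
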
23 changes: 19 additions & 4 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -145,6 +145,8 @@ class _DataflowJobsController(LoggingMixin):
:param num_retries: Maximum number of retries in case of connection problems.
:param multiple_jobs: If set to true, jobs will be looked up by name prefix (the ``name`` parameter)
rather than by a specific job ID, and actions will be performed on all matching jobs.
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it.
"""

def __init__(
@@ -157,6 +159,7 @@ def __init__(
job_id: Optional[str] = None,
num_retries: int = 0,
multiple_jobs: bool = False,
drain_pipeline: bool = False,
) -> None:

super().__init__()
@@ -168,6 +171,7 @@ def __init__
self._job_id = job_id
self._num_retries = num_retries
self._poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self._jobs: Optional[List[dict]] = None

def is_job_running(self) -> bool:
@@ -304,22 +308,27 @@ def get_jobs(self, refresh=False) -> List[dict]:
return self._jobs

def cancel(self) -> None:
"""Cancels current job"""
"""Cancels or drains current job"""
jobs = self.get_jobs()
job_ids = [job['id'] for job in jobs if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES]
if job_ids:
batch = self._dataflow.new_batch_http_request()
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
for job in jobs:
requested_state = (
DataflowJobStatus.JOB_STATE_DRAINED
if self.drain_pipeline and job['type'] == DataflowJobType.JOB_TYPE_STREAMING
else DataflowJobStatus.JOB_STATE_CANCELLED
)
batch.add(
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
jobId=job['id'],
body={"requestedState": requested_state},
)
)
batch.execute()
@@ -427,8 +436,10 @@ def __init__(
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
drain_pipeline: bool = False,
) -> None:
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
@@ -464,6 +475,7 @@ def _start_dataflow(
job_id=job_id,
num_retries=self.num_retries,
multiple_jobs=multiple_jobs,
drain_pipeline=self.drain_pipeline,
)
job_controller.wait_for_done()

@@ -633,6 +645,7 @@ def start_template_dataflow(
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
drain_pipeline=self.drain_pipeline,
)
jobs_controller.wait_for_done()
return response["job"]
@@ -870,6 +883,7 @@ def is_job_dataflow_running(
name=name,
location=location,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
)
return jobs_controller.is_job_running()

@@ -903,5 +917,6 @@ def cancel_job(
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
)
jobs_controller.cancel()
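
For reference, here is a minimal standalone sketch (not the Airflow hook itself) of the drain-versus-cancel decision that the updated ``cancel()`` method makes, using the same ``projects.locations.jobs.update`` call; the project, location, and job values are hypothetical.

# Sketch only: choose between draining and canceling a Dataflow job.
# Assumes google-api-python-client is installed; all values are hypothetical.
from googleapiclient.discovery import build


def stop_dataflow_job(job: dict, drain_pipeline: bool, project_id: str, location: str) -> None:
    # Only streaming jobs can be drained; batch jobs are always cancelled.
    requested_state = (
        "JOB_STATE_DRAINED"
        if drain_pipeline and job["type"] == "JOB_TYPE_STREAMING"
        else "JOB_STATE_CANCELLED"
    )
    dataflow = build("dataflow", "v1b3", cache_discovery=False)
    request = (
        dataflow.projects()
        .locations()
        .jobs()
        .update(
            projectId=project_id,
            location=location,
            jobId=job["id"],
            body={"requestedState": requested_state},
        )
    )
    request.execute()

A call such as stop_dataflow_job(job, drain_pipeline=True, project_id="my-project", location="us-central1") would drain a running streaming job and cancel anything else.
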
25 changes: 21 additions & 4 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -219,7 +219,9 @@ def __init__(

def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, poll_sleep=self.poll_sleep
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
)
dataflow_options = copy.copy(self.dataflow_default_options)
dataflow_options.update(self.options)
@@ -467,6 +469,10 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it when the task instance is killed. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
"""

template_fields = ["body", 'location', 'project_id', 'gcp_conn_id']
@@ -479,6 +485,7 @@ def __init__(
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
*args,
**kwargs,
) -> None:
@@ -490,11 +497,11 @@ def __init__(
self.delegate_to = delegate_to
self.job_id = None
self.hook: Optional[DataflowHook] = None
self.drain_pipeline = drain_pipeline

def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, drain_pipeline=self.drain_pipeline
)

def set_current_job_id(job_id):
@@ -515,6 +522,7 @@ def on_kill(self) -> None:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)


# pylint: disable=too-many-instance-attributes
class DataflowCreatePythonJobOperator(BaseOperator):
"""
Launching Cloud Dataflow jobs written in python. Note that both
@@ -582,6 +590,10 @@ class DataflowCreatePythonJobOperator(BaseOperator):
Cloud Platform for the dataflow job status while the job is in the
JOB_STATE_RUNNING state.
:type poll_sleep: int
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it when the task instance is killed. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
"""

template_fields = ['options', 'dataflow_default_options', 'job_name', 'py_file']
@@ -603,6 +615,7 @@ def __init__( # pylint: disable=too-many-arguments
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
drain_pipeline: bool = False,
**kwargs,
) -> None:

Expand All @@ -624,6 +637,7 @@ def __init__( # pylint: disable=too-many-arguments
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self.job_id = None
self.hook = None

@@ -638,7 +652,10 @@ def execute(self, context):
self.py_file = tmp_gcs_file.name

self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, poll_sleep=self.poll_sleep
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
)
dataflow_options = self.dataflow_default_options.copy()
dataflow_options.update(self.options)
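
To illustrate the new operator argument, here is a hedged usage sketch of a DAG that starts a streaming Dataflow job and drains it (rather than canceling it) when the task instance is killed; the DAG id, GCS path, job name, and options are hypothetical examples, not taken from this commit.

# Hypothetical example DAG using the new drain_pipeline flag.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator

with DAG(
    dag_id="example_dataflow_drain",
    start_date=datetime(2020, 10, 1),
    schedule_interval=None,
) as dag:
    start_streaming_job = DataflowCreatePythonJobOperator(
        task_id="start_streaming_job",
        py_file="gs://example-bucket/streaming_wordcount.py",  # hypothetical GCS path
        job_name="example-streaming-job",
        options={"streaming": True},
        drain_pipeline=True,  # drain instead of cancel when the task is killed
    )
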
60 changes: 60 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataflow.py
@@ -656,6 +656,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
poll_sleep=10,
project_number=TEST_PROJECT,
location=DEFAULT_DATAFLOW_LOCATION,
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -692,6 +693,7 @@ def test_start_template_dataflow_with_custom_region_as_variable(
poll_sleep=10,
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -730,6 +732,7 @@ def test_start_template_dataflow_with_custom_region_as_parameter(
poll_sleep=10,
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -772,6 +775,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow
num_retries=5,
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
)
mock_uuid.assert_called_once_with()

@@ -818,6 +822,7 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl
num_retries=5,
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
)
mock_uuid.assert_called_once_with()

@@ -868,6 +873,7 @@ def test_cancel_job(self, mock_get_conn, jobs_controller):
name=UNIQUE_JOB_NAME,
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
)
jobs_controller.cancel()

@@ -1196,6 +1202,60 @@ def test_dataflow_job_cancel_job(self):
)
mock_batch.add.assert_called_once_with(mock_update.return_value)

@parameterized.expand(
[
(False, "JOB_TYPE_BATCH", "JOB_STATE_CANCELLED"),
(False, "JOB_TYPE_STREAMING", "JOB_STATE_CANCELLED"),
(True, "JOB_TYPE_BATCH", "JOB_STATE_CANCELLED"),
(True, "JOB_TYPE_STREAMING", "JOB_STATE_DRAINED"),
]
)
def test_dataflow_job_cancel_or_drain_job(self, drain_pipeline, job_type, requested_state):
job = {
"id": TEST_JOB_ID,
"name": UNIQUE_JOB_NAME,
"currentState": DataflowJobStatus.JOB_STATE_RUNNING,
"type": job_type,
}
get_method = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.get
get_method.return_value.execute.return_value = job
# fmt: off
job_list_nest_method = (self.mock_dataflow
.projects.return_value.
locations.return_value.
jobs.return_value.list_next)
job_list_nest_method.return_value = None
# fmt: on
dataflow_job = _DataflowJobsController(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name=UNIQUE_JOB_NAME,
location=TEST_LOCATION,
poll_sleep=10,
job_id=TEST_JOB_ID,
num_retries=20,
multiple_jobs=False,
drain_pipeline=drain_pipeline,
)
dataflow_job.cancel()

get_method.assert_called_once_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)

get_method.return_value.execute.assert_called_once_with(num_retries=20)

self.mock_dataflow.new_batch_http_request.assert_called_once_with()

mock_batch = self.mock_dataflow.new_batch_http_request.return_value
mock_update = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.update
mock_update.assert_called_once_with(
body={'requestedState': requested_state},
jobId='test-job-id',
location=TEST_LOCATION,
projectId='test-project',
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
mock_batch.execute.assert_called_once()

def test_dataflow_job_cancel_job_no_running_jobs(self):
mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
get_method = mock_jobs.return_value.get
@@ -110,7 +110,7 @@ def test_successful_run(self):
hook_instance.start_python_dataflow.return_value = None
summary.execute(None)
mock_dataflow_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default', delegate_to=None, poll_sleep=10
gcp_conn_id='google_cloud_default', delegate_to=None, poll_sleep=10, drain_pipeline=False
)
hook_instance.start_python_dataflow.assert_called_once_with(
job_name='{{task.task_id}}',
