Dataflow - add waiting for successful job cancel (#11501)
Co-authored-by: Kamil Breguła <[email protected]>
Tobiasz Kędzierski and Kamil Breguła committed Nov 6, 2020
1 parent bdcb6f8 commit 0caec9f
Showing 4 changed files with 129 additions and 5 deletions.
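The change below makes the Dataflow hook block after issuing a cancel request until the affected jobs actually reach the JOB_STATE_CANCELLED terminal state, bounded by a new cancel_timeout parameter, instead of returning as soon as the cancel call is sent. The following is a minimal sketch of that poll-until-cancelled pattern, not the hook's exact code: fetch_job_state is a hypothetical stand-in for the Dataflow jobs.get API call, and the state names mirror the Dataflow job states used in the diff.

import time


def wait_for_cancel(job_ids, fetch_job_state, poll_sleep=10, cancel_timeout=5 * 60):
    """Poll until every job is JOB_STATE_CANCELLED; fail on job failure or timeout."""
    deadline = time.monotonic() + cancel_timeout
    while time.monotonic() < deadline:
        # Ask the (hypothetical) API for the current state of each job.
        states = {fetch_job_state(job_id) for job_id in job_ids}
        if states == {"JOB_STATE_CANCELLED"}:
            return  # every job reached the expected terminal state
        if "JOB_STATE_FAILED" in states:
            raise RuntimeError(f"Jobs failed while waiting for cancellation: {job_ids}")
        time.sleep(poll_sleep)
    raise TimeoutError(
        f"Canceling jobs failed due to timeout ({cancel_timeout}s): {', '.join(job_ids)}"
    )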
46 changes: 43 additions & 3 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -28,14 +28,15 @@
import warnings
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TypeVar, Union, cast

from googleapiclient.discovery import build

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.python_virtualenv import prepare_virtualenv
from airflow.utils.timeout import timeout

# This is the default location
# https://cloud.google.com/dataflow/pipelines/specifying-exec-params
@@ -147,9 +148,10 @@ class _DataflowJobsController(LoggingMixin):
not by specific job ID, then actions will be performed on all matching jobs.
:param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
instead of canceling it.
:param cancel_timeout: wait time in seconds required to successfully cancel a job
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
dataflow: Any,
project_number: str,
@@ -160,6 +162,7 @@ def __init__(
num_retries: int = 0,
multiple_jobs: bool = False,
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 5 * 60,
) -> None:

super().__init__()
@@ -171,8 +174,9 @@ def __init__(
self._job_id = job_id
self._num_retries = num_retries
self._poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self._cancel_timeout = cancel_timeout
self._jobs: Optional[List[dict]] = None
self.drain_pipeline = drain_pipeline

def is_job_running(self) -> bool:
"""
@@ -317,6 +321,29 @@ def get_jobs(self, refresh: bool = False) -> List[dict]:

return self._jobs

def _wait_for_states(self, expected_states: Set[str]):
"""Waiting for the jobs to reach a certain state."""
if not self._jobs:
raise ValueError("The _jobs should be set")
while True:
self._refresh_jobs()
job_states = {job['currentState'] for job in self._jobs}
if not job_states.difference(expected_states):
return
unexpected_failed_end_states = DataflowJobStatus.FAILED_END_STATES - expected_states
if unexpected_failed_end_states.intersection(job_states):
unexpected_failed_jobs = [
job for job in self._jobs if job['currentState'] in unexpected_failed_end_states
]
raise AirflowException(
"Jobs failed: "
+ ", ".join(
f"ID: {job['id']} name: {job['name']} state: {job['currentState']}"
for job in unexpected_failed_jobs
)
)
time.sleep(self._poll_sleep)

def cancel(self) -> None:
"""Cancels or drains current job"""
jobs = self.get_jobs()
@@ -342,6 +369,12 @@ def cancel(self) -> None:
)
)
batch.execute()
if self._cancel_timeout and isinstance(self._cancel_timeout, int):
timeout_error_message = "Canceling jobs failed due to timeout ({}s): {}".format(
self._cancel_timeout, ", ".join(job_ids)
)
with timeout(seconds=self._cancel_timeout, error_message=timeout_error_message):
self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
else:
self.log.info("No jobs to cancel")

@@ -453,9 +486,11 @@ def __init__(
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 5 * 60,
) -> None:
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
@@ -496,6 +531,7 @@ def _start_dataflow(
num_retries=self.num_retries,
multiple_jobs=multiple_jobs,
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
)
job_controller.wait_for_done()

@@ -669,6 +705,7 @@ def start_template_dataflow(
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
)
jobs_controller.wait_for_done()
return response["job"]
@@ -714,6 +751,7 @@ def start_flex_template(
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
cancel_timeout=self.cancel_timeout,
)
jobs_controller.wait_for_done()

@@ -898,6 +936,7 @@ def is_job_dataflow_running(
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
num_retries=self.num_retries,
cancel_timeout=self.cancel_timeout,
)
return jobs_controller.is_job_running()

@@ -933,6 +972,7 @@ def cancel_job(
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
num_retries=self.num_retries,
cancel_timeout=self.cancel_timeout,
)
jobs_controller.cancel()

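The cancel timeout itself is enforced with Airflow's airflow.utils.timeout.timeout context manager (the new import at the top of the hook), which raises AirflowTaskTimeout once the given number of seconds has elapsed. A minimal sketch of that wrapping, assuming a hypothetical job ID and the 5-minute hook-level default:

from airflow.utils.timeout import timeout

cancel_timeout = 5 * 60  # default used by _DataflowJobsController
job_ids = ["2020-11-06_00_00_00-123456789012345678"]  # hypothetical Dataflow job ID

error_message = "Canceling jobs failed due to timeout ({}s): {}".format(
    cancel_timeout, ", ".join(job_ids)
)
with timeout(seconds=cancel_timeout, error_message=error_message):
    # In the hook, _wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED}) runs here;
    # if the polling loop is still blocked when the deadline passes, the timeout fires.
    pass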
26 changes: 25 additions & 1 deletion airflow/providers/google/cloud/operators/dataflow.py
@@ -132,6 +132,9 @@ class DataflowCreateJavaJobOperator(BaseOperator):
:type check_if_running: CheckJobRunning(IgnoreJob = do not check if running, FinishIfRunning=
if job is running finish with nothing, WaitForRun= wait until job finishes and then run the job)
``jar``, ``options``, and ``job_name`` are templated so you can use variables in them.
:param cancel_timeout: How long (in seconds) the operator should wait for the pipeline to be
successfully cancelled when the task is being killed.
:type cancel_timeout: Optional[int]
Note that both
``dataflow_default_options`` and ``options`` will be merged to specify pipeline
@@ -193,6 +196,7 @@ def __init__(
job_class: Optional[str] = None,
check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
multiple_jobs: Optional[bool] = None,
cancel_timeout: Optional[int] = 10 * 60,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -214,6 +218,7 @@ def __init__(
self.poll_sleep = poll_sleep
self.job_class = job_class
self.check_if_running = check_if_running
self.cancel_timeout = cancel_timeout
self.job_id = None
self.hook = None

@@ -222,6 +227,7 @@ def execute(self, context):
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
cancel_timeout=self.cancel_timeout,
)
dataflow_options = copy.copy(self.dataflow_default_options)
dataflow_options.update(self.options)
@@ -324,6 +330,9 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
`https://cloud.google.com/dataflow/pipelines/specifying-exec-params
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:type environment: Optional[dict]
:param cancel_timeout: How long (in seconds) the operator should wait for the pipeline to be
successfully cancelled when the task is being killed.
:type cancel_timeout: Optional[int]
It's a good practice to define dataflow_* parameters, such as the project, zone and staging
location, in the default_args of the DAG.
@@ -401,6 +410,7 @@ def __init__( # pylint: disable=too-many-arguments
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
environment: Optional[Dict] = None,
cancel_timeout: Optional[int] = 10 * 60,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -418,13 +428,15 @@ def __init__( # pylint: disable=too-many-arguments
self.hook: Optional[DataflowHook] = None
self.impersonation_chain = impersonation_chain
self.environment = environment
self.cancel_timeout = cancel_timeout

def execute(self, context) -> dict:
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.cancel_timeout,
)

def set_current_job_id(job_id):
@@ -473,6 +485,9 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
instead of canceling it when the task instance is killed. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
:param cancel_timeout: How long (in seconds) the operator should wait for the pipeline to be
successfully cancelled when the task is being killed.
:type cancel_timeout: Optional[int]
"""

template_fields = ["body", "location", "project_id", "gcp_conn_id"]
@@ -486,6 +501,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 10 * 60,
*args,
**kwargs,
) -> None:
@@ -495,15 +511,17 @@ def __init__(
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
self.job_id = None
self.hook: Optional[DataflowHook] = None
self.drain_pipeline = drain_pipeline

def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
)

def set_current_job_id(job_id):
@@ -692,6 +710,9 @@ class DataflowCreatePythonJobOperator(BaseOperator):
instead of canceling it when the task instance is killed. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
:param cancel_timeout: How long (in seconds) the operator should wait for the pipeline to be
successfully cancelled when the task is being killed.
:type cancel_timeout: Optional[int]
"""

template_fields = ["options", "dataflow_default_options", "job_name", "py_file"]
@@ -714,6 +735,7 @@ def __init__( # pylint: disable=too-many-arguments
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 10 * 60,
**kwargs,
) -> None:

@@ -736,6 +758,7 @@ def __init__( # pylint: disable=too-many-arguments
self.delegate_to = delegate_to
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
self.job_id = None
self.hook = None

@@ -754,6 +777,7 @@ def execute(self, context):
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
)
dataflow_options = self.dataflow_default_options.copy()
dataflow_options.update(self.options)
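On the operator side, cancel_timeout (defaulting to 10 minutes) is simply passed through to the hook, so when a task is killed the operator waits for the pipeline to reach the cancelled state before giving up. A minimal usage sketch, where the template path, bucket names and region are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

with DAG(
    dag_id="example_dataflow_cancel_timeout",
    start_date=datetime(2020, 11, 6),
    schedule_interval=None,
) as dag:
    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start_template_job",
        template="gs://dataflow-templates/latest/Word_Count",  # placeholder template path
        parameters={
            "inputFile": "gs://my-bucket/input.txt",  # placeholder input object
            "output": "gs://my-bucket/output",  # placeholder output prefix
        },
        location="europe-west3",  # placeholder region
        cancel_timeout=10 * 60,  # wait up to 10 minutes for a clean cancel on task kill
    )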
