Add expected_terminal_state parameter to Dataflow operators (#34217)
shahar1 committed Sep 11, 2023
1 parent 25d463c commit 050a47a
Showing 4 changed files with 120 additions and 34 deletions.
60 changes: 41 additions & 19 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -175,16 +175,16 @@ class _DataflowJobsController(LoggingMixin):
:param num_retries: Maximum number of retries in case of connection problems.
:param multiple_jobs: If set to true this task will be searched by name prefix (``name`` parameter),
not by specific job ID, then actions will be performed on all matching jobs.
:param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
:param drain_pipeline: Optional, set to True if we want to stop streaming job by draining it
instead of canceling.
:param cancel_timeout: wait time in seconds for successful job canceling
:param wait_until_finished: If True, wait for the end of pipeline execution before exiting. If False,
it only submits the job and checks once whether the job is not in a terminal state.
The default behavior depends on the type of pipeline:
* for the streaming pipeline, wait for jobs to start,
* for the batch pipeline, wait for the jobs to complete.
* for the streaming pipeline, wait for jobs to be in JOB_STATE_RUNNING,
* for the batch pipeline, wait for the jobs to be in JOB_STATE_DONE.
"""

def __init__(
@@ -200,6 +200,7 @@ def __init__(
drain_pipeline: bool = False,
cancel_timeout: int | None = 5 * 60,
wait_until_finished: bool | None = None,
expected_terminal_state: str | None = None,
) -> None:

super().__init__()
@@ -215,6 +216,7 @@ def __init__(
self._jobs: list[dict] | None = None
self.drain_pipeline = drain_pipeline
self._wait_until_finished = wait_until_finished
self._expected_terminal_state = expected_terminal_state

def is_job_running(self) -> bool:
"""
@@ -391,27 +393,44 @@ def _check_dataflow_job_state(self, job) -> bool:
:return: True if job is done.
:raise: Exception
"""
if self._wait_until_finished is None:
wait_for_running = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
current_state = job["currentState"]
is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING

if self._expected_terminal_state is None:
if is_streaming:
self._expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING
else:
self._expected_terminal_state = DataflowJobStatus.JOB_STATE_DONE
else:
wait_for_running = not self._wait_until_finished
terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
if self._expected_terminal_state not in terminal_states:
raise Exception(
f"Google Cloud Dataflow job's expected terminal state "
f"'{self._expected_terminal_state}' is invalid."
f" The value should be any of the following: {terminal_states}"
)
elif is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DONE:
raise Exception(
"Google Cloud Dataflow job's expected terminal state cannot be "
"JOB_STATE_DONE while it is a streaming job"
)
elif not is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DRAINED:
raise Exception(
"Google Cloud Dataflow job's expected terminal state cannot be "
"JOB_STATE_DRAINED while it is a batch job"
)

if job["currentState"] == DataflowJobStatus.JOB_STATE_DONE:
if not self._wait_until_finished and current_state == self._expected_terminal_state:
return True
elif job["currentState"] == DataflowJobStatus.JOB_STATE_FAILED:
raise Exception(f"Google Cloud Dataflow job {job['name']} has failed.")
elif job["currentState"] == DataflowJobStatus.JOB_STATE_CANCELLED:
raise Exception(f"Google Cloud Dataflow job {job['name']} was cancelled.")
elif job["currentState"] == DataflowJobStatus.JOB_STATE_DRAINED:
raise Exception(f"Google Cloud Dataflow job {job['name']} was drained.")
elif job["currentState"] == DataflowJobStatus.JOB_STATE_UPDATED:
raise Exception(f"Google Cloud Dataflow job {job['name']} was updated.")
elif job["currentState"] == DataflowJobStatus.JOB_STATE_RUNNING and wait_for_running:
return True
elif job["currentState"] in DataflowJobStatus.AWAITING_STATES:

if current_state in DataflowJobStatus.AWAITING_STATES:
return self._wait_until_finished is False

self.log.debug("Current job: %s", str(job))
raise Exception(f"Google Cloud Dataflow job {job['name']} was unknown state: {job['currentState']}")
raise Exception(
f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
f"expected terminal state: {self._expected_terminal_state}"
)

def wait_for_done(self) -> None:
"""Helper method to wait for result of submitted job."""
@@ -514,6 +533,7 @@ def __init__(
drain_pipeline: bool = False,
cancel_timeout: int | None = 5 * 60,
wait_until_finished: bool | None = None,
expected_terminal_state: str | None = None,
**kwargs,
) -> None:
if kwargs.get("delegate_to") is not None:
@@ -527,6 +547,7 @@ def __init__(
self.wait_until_finished = wait_until_finished
self.job_id: str | None = None
self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner)
self.expected_terminal_state = expected_terminal_state
super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
@@ -691,6 +712,7 @@ def start_template_dataflow(
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
expected_terminal_state=self.expected_terminal_state,
)
jobs_controller.wait_for_done()
return response["job"]
15 changes: 15 additions & 0 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -288,6 +288,8 @@ class DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
If you do not call the wait_for_pipeline method in your pipeline and pass wait_until_finished=False
to the operator, the second loop will check once whether the job is not in a terminal state and exit the loop.
:param expected_terminal_state: The expected terminal state of the Dataflow job at which the corresponding
Airflow task succeeds. When not specified, it will be determined by the hook.
Note that both
``dataflow_default_options`` and ``options`` will be merged to specify pipeline
@@ -349,6 +351,7 @@ def __init__(
multiple_jobs: bool = False,
cancel_timeout: int | None = 10 * 60,
wait_until_finished: bool | None = None,
expected_terminal_state: str | None = None,
**kwargs,
) -> None:
# TODO: Remove one day
@@ -378,6 +381,7 @@ def __init__(
self.check_if_running = check_if_running
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
self.expected_terminal_state = expected_terminal_state
self.job_id = None
self.beam_hook: BeamHook | None = None
self.dataflow_hook: DataflowHook | None = None
@@ -390,6 +394,7 @@ def execute(self, context: Context):
poll_sleep=self.poll_sleep,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
expected_terminal_state=self.expected_terminal_state,
)
job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
pipeline_options = copy.deepcopy(self.dataflow_default_options)
@@ -531,6 +536,8 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
If you do not call the wait_for_pipeline method in your pipeline and pass wait_until_finished=False
to the operator, the second loop will check once whether the job is not in a terminal state and exit the loop.
:param expected_terminal_state: The expected terminal state of the Dataflow job at which the corresponding
Airflow task succeeds. When not specified, it will be determined by the hook.
It's a good practice to define dataflow_* parameters in the default_args of the dag
like the project, zone and staging location.
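
As an illustration of how the new parameter is meant to be used from a DAG, the hedged snippet below starts a batch template job and pins the expected terminal state explicitly (JOB_STATE_DONE is also what the hook would pick for a batch job). The project, bucket, and template path are placeholders, not values from this commit:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

with DAG(
    dag_id="example_dataflow_expected_terminal_state",
    start_date=datetime(2023, 9, 1),
    schedule=None,
    catchup=False,
) as dag:
    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start_template_job",
        project_id="my-project",  # placeholder
        location="europe-west3",
        template="gs://my-bucket/templates/my-batch-template",  # placeholder
        parameters={"inputFile": "gs://my-bucket/input.txt", "output": "gs://my-bucket/output"},
        # The task succeeds once the job reaches this state; for a batch job this is also the default.
        expected_terminal_state="JOB_STATE_DONE",
    )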
@@ -614,6 +621,7 @@ def __init__(
wait_until_finished: bool | None = None,
append_job_name: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
expected_terminal_state: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -633,6 +641,7 @@ def __init__(
self.wait_until_finished = wait_until_finished
self.append_job_name = append_job_name
self.deferrable = deferrable
self.expected_terminal_state = expected_terminal_state

self.job: dict | None = None

@@ -657,6 +666,7 @@ def hook(self) -> DataflowHook:
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
expected_terminal_state=self.expected_terminal_state,
)
return hook

@@ -787,6 +797,8 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param expected_terminal_state: The expected terminal state of the Dataflow job at which the corresponding
Airflow task succeeds. When not specified, it will be determined by the hook.
:param append_job_name: True if unique suffix has to be appended to job name.
"""

@@ -805,6 +817,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
append_job_name: bool = True,
expected_terminal_state: str | None = None,
*args,
**kwargs,
) -> None:
@@ -819,6 +832,7 @@ def __init__(
self.job: dict | None = None
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.expected_terminal_state = expected_terminal_state
self.append_job_name = append_job_name

self._validate_deferrable_params()
@@ -842,6 +856,7 @@ def hook(self) -> DataflowHook:
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
impersonation_chain=self.impersonation_chain,
expected_terminal_state=self.expected_terminal_state,
)
return hook

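For a streaming Flex Template job the default expectation resolved by the hook is JOB_STATE_RUNNING; the hedged sketch below spells that out explicitly on DataflowStartFlexTemplateOperator. All resource names are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator

with DAG(
    dag_id="example_dataflow_flex_expected_terminal_state",
    start_date=datetime(2023, 9, 1),
    schedule=None,
    catchup=False,
) as dag:
    start_flex_job = DataflowStartFlexTemplateOperator(
        task_id="start_flex_job",
        project_id="my-project",  # placeholder
        location="europe-west3",
        body={
            "launchParameter": {
                "containerSpecGcsPath": "gs://my-bucket/templates/streaming-spec.json",  # placeholder
                "jobName": "my-streaming-job",
                "parameters": {"inputSubscription": "projects/my-project/subscriptions/my-sub"},
            }
        },
        # Streaming jobs never reach JOB_STATE_DONE; the task succeeds once the job is running.
        expected_terminal_state="JOB_STATE_RUNNING",
    )
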
76 changes: 61 additions & 15 deletions tests/providers/google/cloud/hooks/test_dataflow.py
@@ -861,6 +861,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
project_number=TEST_PROJECT,
location=DEFAULT_DATAFLOW_LOCATION,
drain_pipeline=False,
expected_terminal_state=None,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
)
@@ -900,6 +901,7 @@ def test_start_template_dataflow_with_custom_region_as_variable(
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
expected_terminal_state=None,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
)
@@ -943,6 +945,7 @@ def test_start_template_dataflow_with_custom_region_as_parameter(
drain_pipeline=False,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
expected_terminal_state=None,
)
mock_controller.return_value.wait_for_done.assert_called_once()

@@ -986,6 +989,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow
drain_pipeline=False,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
expected_terminal_state=None,
)
mock_uuid.assert_called_once_with()

@@ -1033,6 +1037,7 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl
drain_pipeline=False,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
expected_terminal_state=None,
)
mock_uuid.assert_called_once_with()

@@ -1232,13 +1237,13 @@ def test_dataflow_job_wait_for_multiple_jobs(self):
@pytest.mark.parametrize(
"state, exception_regex",
[
(DataflowJobStatus.JOB_STATE_FAILED, "Google Cloud Dataflow job name-2 has failed\\."),
(DataflowJobStatus.JOB_STATE_CANCELLED, "Google Cloud Dataflow job name-2 was cancelled\\."),
(DataflowJobStatus.JOB_STATE_DRAINED, "Google Cloud Dataflow job name-2 was drained\\."),
(DataflowJobStatus.JOB_STATE_UPDATED, "Google Cloud Dataflow job name-2 was updated\\."),
(DataflowJobStatus.JOB_STATE_FAILED, "unexpected terminal state: JOB_STATE_FAILED"),
(DataflowJobStatus.JOB_STATE_CANCELLED, "unexpected terminal state: JOB_STATE_CANCELLED"),
(DataflowJobStatus.JOB_STATE_DRAINED, "unexpected terminal state: JOB_STATE_DRAINED"),
(DataflowJobStatus.JOB_STATE_UPDATED, "unexpected terminal state: JOB_STATE_UPDATED"),
(
DataflowJobStatus.JOB_STATE_UNKNOWN,
"Google Cloud Dataflow job name-2 was unknown state: JOB_STATE_UNKNOWN",
"JOB_STATE_UNKNOWN",
),
],
)
@@ -1446,52 +1451,52 @@ def test_check_dataflow_job_state_without_job_type(self, job_state, wait_until_f
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_FAILED,
"Google Cloud Dataflow job name-2 has failed\\.",
"JOB_STATE_FAILED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_FAILED,
"Google Cloud Dataflow job name-2 has failed\\.",
"JOB_STATE_FAILED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_UNKNOWN,
"Google Cloud Dataflow job name-2 was unknown state: JOB_STATE_UNKNOWN",
"JOB_STATE_UNKNOWN",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_UNKNOWN,
"Google Cloud Dataflow job name-2 was unknown state: JOB_STATE_UNKNOWN",
"JOB_STATE_UNKNOWN",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_CANCELLED,
"Google Cloud Dataflow job name-2 was cancelled\\.",
"JOB_STATE_CANCELLED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_CANCELLED,
"Google Cloud Dataflow job name-2 was cancelled\\.",
"JOB_STATE_CANCELLED",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_DRAINED,
"Google Cloud Dataflow job name-2 was drained\\.",
"JOB_STATE_DRAINED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_DRAINED,
"Google Cloud Dataflow job name-2 was drained\\.",
"JOB_STATE_DRAINED",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_UPDATED,
"Google Cloud Dataflow job name-2 was updated\\.",
"JOB_STATE_UPDATED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_UPDATED,
"Google Cloud Dataflow job name-2 was updated\\.",
"JOB_STATE_UPDATED",
),
],
)
@@ -1510,6 +1515,47 @@ def test_check_dataflow_job_state_terminal_state(self, job_type, job_state, exce
with pytest.raises(Exception, match=exception_regex):
dataflow_job._check_dataflow_job_state(job)

@pytest.mark.parametrize(
"job_type, expected_terminal_state, match",
[
(
DataflowJobType.JOB_TYPE_BATCH,
"test",
"invalid",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_DONE,
"cannot be JOB_STATE_DONE while it is a streaming job",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_DRAINED,
"cannot be JOB_STATE_DRAINED while it is a batch job",
),
],
)
def test_check_dataflow_job_state__invalid_expected_state(self, job_type, expected_terminal_state, match):
job = {
"id": "id-2",
"name": "name-2",
"type": job_type,
"currentState": DataflowJobStatus.JOB_STATE_QUEUED,
}
dataflow_job = _DataflowJobsController(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name=UNIQUE_JOB_NAME,
location=TEST_LOCATION,
poll_sleep=0,
job_id=TEST_JOB_ID,
num_retries=20,
multiple_jobs=False,
expected_terminal_state=expected_terminal_state,
)
with pytest.raises(Exception, match=match):
dataflow_job._check_dataflow_job_state(job)

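A complementary happy-path test is not part of this commit, but a hedged sketch of one could assert that the controller reports success once the job reaches the requested terminal state:

def test_check_dataflow_job_state__expected_terminal_state_reached(self):
    # Hypothetical positive case (not in this commit): a streaming job that has reached the
    # requested terminal state should be reported as done.
    job = {
        "id": "id-2",
        "name": "name-2",
        "type": DataflowJobType.JOB_TYPE_STREAMING,
        "currentState": DataflowJobStatus.JOB_STATE_CANCELLED,
    }
    dataflow_job = _DataflowJobsController(
        dataflow=self.mock_dataflow,
        project_number=TEST_PROJECT,
        name=UNIQUE_JOB_NAME,
        location=TEST_LOCATION,
        poll_sleep=0,
        job_id=TEST_JOB_ID,
        num_retries=20,
        multiple_jobs=False,
        expected_terminal_state=DataflowJobStatus.JOB_STATE_CANCELLED,
    )
    assert dataflow_job._check_dataflow_job_state(job) is True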
def test_dataflow_job_cancel_job(self):
mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
get_method = mock_jobs.return_value.get