Skip to content

Commit 495ae23

Browse files
authored
Optimize deferrable mode execution for DataprocSubmitJobOperator (#31317)
1 parent 9276310 commit 495ae23

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2030,6 +2030,14 @@ def execute(self, context: Context):
20302030

20312031
self.job_id = new_job_id
20322032
if self.deferrable:
2033+
job = self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
2034+
state = job.status.state
2035+
if state == JobStatus.State.DONE:
2036+
return self.job_id
2037+
elif state == JobStatus.State.ERROR:
2038+
raise AirflowException(f"Job failed:\n{job}")
2039+
elif state == JobStatus.State.CANCELLED:
2040+
raise AirflowException(f"Job was cancelled:\n{job}")
20332041
self.defer(
20342042
trigger=DataprocSubmitTrigger(
20352043
job_id=self.job_id,

tests/providers/google/cloud/operators/test_dataproc.py

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import pytest
2424
from google.api_core.exceptions import AlreadyExists, NotFound
2525
from google.api_core.retry import Retry
26-
from google.cloud.dataproc_v1 import Batch
26+
from google.cloud.dataproc_v1 import Batch, JobStatus
2727

2828
from airflow.exceptions import (
2929
AirflowException,
@@ -1058,6 +1058,32 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
10581058
assert isinstance(exc.value.trigger, DataprocSubmitTrigger)
10591059
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
10601060

1061+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
1062+
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
1063+
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
1064+
def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook):
1065+
mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
1066+
job_status = mock_hook.return_value.get_job.return_value.status
1067+
job_status.state = JobStatus.State.DONE
1068+
1069+
op = DataprocSubmitJobOperator(
1070+
task_id=TASK_ID,
1071+
region=GCP_REGION,
1072+
project_id=GCP_PROJECT,
1073+
job={},
1074+
gcp_conn_id=GCP_CONN_ID,
1075+
retry=RETRY,
1076+
asynchronous=True,
1077+
timeout=TIMEOUT,
1078+
metadata=METADATA,
1079+
request_id=REQUEST_ID,
1080+
impersonation_chain=IMPERSONATION_CHAIN,
1081+
deferrable=True,
1082+
)
1083+
1084+
op.execute(context=self.mock_context)
1085+
assert not mock_defer.called
1086+
10611087
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
10621088
def test_on_kill(self, mock_hook):
10631089
job = {}

0 commit comments

Comments
 (0)