Skip to content

Commit

Permalink
Add on_kill method to DataprocSubmitJobOperator (#10847)
Browse files Browse the repository at this point in the history
  • Loading branch information
tszerszen committed Sep 10, 2020
1 parent e773f8b commit 68cc727
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
19 changes: 15 additions & 4 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1790,6 +1790,8 @@ class DataprocSubmitJobOperator(BaseOperator):
This is useful for submitting long running jobs and
waiting on them asynchronously using the DataprocJobSensor
:type asynchronous: bool
:param cancel_on_kill: Flag which indicates whether to cancel the hook's job when on_kill is called
:type cancel_on_kill: bool
"""

template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
Expand All @@ -1808,6 +1810,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1821,11 +1824,14 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.asynchronous = asynchronous
self.cancel_on_kill = cancel_on_kill
self.hook: Optional[DataprocHook] = None
self.job_id: Optional[str] = None

def execute(self, context: Dict):
self.log.info("Submitting job")
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
job_object = hook.submit_job(
self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
job_object = self.hook.submit_job(
project_id=self.project_id,
location=self.location,
job=self.job,
Expand All @@ -1839,10 +1845,15 @@ def execute(self, context: Dict):

if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
hook.wait_for_job(job_id=job_id, location=self.location, project_id=self.project_id)
self.hook.wait_for_job(job_id=job_id, location=self.location, project_id=self.project_id)
self.log.info('Job %s completed successfully.', job_id)

return job_id
self.job_id = job_id
return self.job_id

def on_kill(self):
    """Cancel the submitted Dataproc job when the task is externally killed.

    Airflow invokes this hook when the task instance is killed. The job is
    cancelled only if ``cancel_on_kill`` is enabled and ``execute`` has
    already recorded a ``job_id``; otherwise this is a no-op.
    """
    if not self.cancel_on_kill:
        return
    if not self.job_id:
        # Nothing was submitted yet (or submission failed), so there is no job to cancel.
        return
    self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location)


class DataprocUpdateClusterOperator(BaseOperator):
Expand Down
31 changes: 31 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,37 @@ def test_execute_async(self, mock_hook):
)
mock_hook.return_value.wait_for_job.assert_not_called()

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
    """on_kill cancels the submitted job only when cancel_on_kill is enabled."""
    expected_job_id = "job_id"
    mock_hook.return_value.wait_for_job.return_value = None
    mock_hook.return_value.submit_job.return_value.reference.job_id = expected_job_id

    operator = DataprocSubmitJobOperator(
        task_id=TASK_ID,
        location=GCP_LOCATION,
        project_id=GCP_PROJECT,
        job={},
        gcp_conn_id=GCP_CONN_ID,
        retry=RETRY,
        timeout=TIMEOUT,
        metadata=METADATA,
        request_id=REQUEST_ID,
        impersonation_chain=IMPERSONATION_CHAIN,
        cancel_on_kill=False,
    )
    operator.execute(context={})

    # Cancellation disabled: on_kill must not touch the job.
    operator.on_kill()
    mock_hook.return_value.cancel_job.assert_not_called()

    # Cancellation re-enabled: on_kill should cancel the recorded job exactly once.
    operator.cancel_on_kill = True
    operator.on_kill()
    mock_hook.return_value.cancel_job.assert_called_once_with(
        project_id=GCP_PROJECT, location=GCP_LOCATION, job_id=expected_job_id
    )


class TestDataprocUpdateClusterOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down

0 comments on commit 68cc727

Please sign in to comment.