Skip to content

Commit

Permalink
Fix BigQueryInsertJobOperator cancel_on_kill (#25342)
Browse files Browse the repository at this point in the history
  • Loading branch information
lidalei committed Aug 4, 2022
1 parent dd06797 commit e84d753
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 11 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,7 @@ def cancel_job(
location: Optional[str] = None,
) -> None:
"""
Cancels a job an wait for cancellation to complete
Cancel a job and wait for cancellation to complete
:param job_id: id of the job.
:param project_id: Google Cloud Project where the job is running
Expand Down
18 changes: 10 additions & 8 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,14 +2143,15 @@ def _submit_job(
hook: BigQueryHook,
job_id: str,
) -> BigQueryJob:
# Submit a new job and wait for it to complete and get the result.
# Submit a new job without waiting for it to complete.
return hook.insert_job(
configuration=self.configuration,
project_id=self.project_id,
location=self.location,
job_id=job_id,
timeout=self.result_timeout,
retry=self.result_retry,
nowait=True,
)

@staticmethod
Expand Down Expand Up @@ -2178,19 +2179,14 @@ def execute(self, context: Any):
try:
self.log.info("Executing: %s'", self.configuration)
job = self._submit_job(hook, job_id)
self._handle_job_error(job)
except Conflict:
# If the job already exists retrieve it
job = hook.get_job(
project_id=self.project_id,
location=self.location,
job_id=job_id,
)
if job.state in self.reattach_states:
# We are reattaching to a job
job.result(timeout=self.result_timeout, retry=self.result_retry)
self._handle_job_error(job)
else:
if job.state not in self.reattach_states:
# Same job configuration so we need force_rerun
raise AirflowException(
f"Job with id: {job_id} already exists and is in {job.state} state. If you "
Expand Down Expand Up @@ -2225,10 +2221,16 @@ def execute(self, context: Any):
BigQueryTableLink.persist(**persist_kwargs)

self.job_id = job.job_id
return job.job_id
# Wait for the job to complete
job.result(timeout=self.result_timeout, retry=self.result_retry)
self._handle_job_error(job)

return self.job_id

def on_kill(self) -> None:
    """
    Cancel the submitted job when the task is killed, if cancellation is enabled.

    The job is cancelled through ``self.hook.cancel_job`` only when a job id is
    known, ``cancel_on_kill`` is truthy, and a hook is available. In every other
    case the method only logs and returns.
    """
    if not (self.job_id and self.cancel_on_kill):
        self.log.info('Skipping to cancel job: %s:%s.%s', self.project_id, self.location, self.job_id)
        return

    # NOTE(review): ``hook`` appears to be optional (the original call needed a
    # ``type: ignore[union-attr]``). Guard explicitly so a kill that arrives
    # before the hook is assigned does not raise AttributeError.
    if self.hook is None:
        self.log.warning(
            'Cannot cancel job %s:%s.%s because no hook is available.',
            self.project_id,
            self.location,
            self.job_id,
        )
        return

    self.hook.cancel_job(
        job_id=self.job_id,
        project_id=self.project_id,
        location=self.location,
    )
46 changes: 44 additions & 2 deletions tests/providers/google/cloud/operators/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.exceptions import Conflict

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCheckOperator,
BigQueryConsoleIndexableLink,
Expand Down Expand Up @@ -830,6 +830,7 @@ def test_execute_query_success(self, mock_hook):
configuration=configuration,
location=TEST_DATASET_LOCATION,
job_id=real_job_id,
nowait=True,
project_id=TEST_GCP_PROJECT_ID,
retry=DEFAULT_RETRY,
timeout=None,
Expand Down Expand Up @@ -870,6 +871,7 @@ def test_execute_copy_success(self, mock_hook):
configuration=configuration,
location=TEST_DATASET_LOCATION,
job_id=real_job_id,
nowait=True,
project_id=TEST_GCP_PROJECT_ID,
retry=DEFAULT_RETRY,
timeout=None,
Expand Down Expand Up @@ -913,6 +915,45 @@ def test_on_kill(self, mock_hook):
project_id=TEST_GCP_PROJECT_ID,
)

@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
@mock.patch('airflow.providers.google.cloud.hooks.bigquery.BigQueryJob')
def test_on_kill_after_execution_timeout(self, mock_job, mock_hook):
    """Check that ``on_kill`` still cancels the job after ``execute`` raised a task timeout."""
    job_id = "123456"
    hash_ = "hash"
    real_job_id = f"{job_id}_{hash_}"

    # The mocked job raises the task timeout as soon as its result is awaited.
    mock_job.job_id = real_job_id
    mock_job.error_result = False
    mock_job.result.side_effect = AirflowTaskTimeout()

    # The mocked hook hands back that job under the generated job id.
    hook_instance = mock_hook.return_value
    hook_instance.generate_job_id.return_value = real_job_id
    hook_instance.insert_job.return_value = mock_job

    query_config = {
        "query": {
            "query": "SELECT * FROM any",
            "useLegacySql": False,
        }
    }
    operator = BigQueryInsertJobOperator(
        task_id="insert_query_job",
        configuration=query_config,
        location=TEST_DATASET_LOCATION,
        job_id=job_id,
        project_id=TEST_GCP_PROJECT_ID,
        cancel_on_kill=True,
    )

    with pytest.raises(AirflowTaskTimeout):
        operator.execute(context=MagicMock())

    # Killing the task after the timeout must cancel exactly the generated job.
    operator.on_kill()
    hook_instance.cancel_job.assert_called_once_with(
        job_id=real_job_id,
        location=TEST_DATASET_LOCATION,
        project_id=TEST_GCP_PROJECT_ID,
    )

@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
def test_execute_failure(self, mock_hook):
job_id = "123456"
Expand Down Expand Up @@ -1018,6 +1059,7 @@ def test_execute_force_rerun(self, mock_hook):
configuration=configuration,
location=TEST_DATASET_LOCATION,
job_id=real_job_id,
nowait=True,
project_id=TEST_GCP_PROJECT_ID,
retry=DEFAULT_RETRY,
timeout=None,
Expand All @@ -1038,7 +1080,7 @@ def test_execute_no_force_rerun(self, mock_hook):
}
}

mock_hook.return_value.insert_job.return_value.result.side_effect = Conflict("any")
mock_hook.return_value.insert_job.side_effect = Conflict("any")
mock_hook.return_value.generate_job_id.return_value = real_job_id
job = MagicMock(
job_id=real_job_id,
Expand Down

0 comments on commit e84d753

Please sign in to comment.