Skip to content

Commit

Permalink
Improve handling server errors in DataprocSubmitJobOperator (#11947)
Browse files Browse the repository at this point in the history
* Improve handling server errors in DataprocSubmitJobOperator

* fixup! Improve handling server errors in DataprocSubmitJobOperator
  • Loading branch information
turbaszek committed Nov 2, 2020
1 parent 1c7fbaf commit 6071fdd
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
18 changes: 15 additions & 3 deletions airflow/providers/google/cloud/hooks/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from cached_property import cached_property
from google.api_core.exceptions import ServerError
from google.api_core.retry import Retry
from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module
ClusterControllerClient,
Expand Down Expand Up @@ -704,7 +705,9 @@ def instantiate_inline_workflow_template(
return operation

@GoogleBaseHook.fallback_to_default_project_id
def wait_for_job(self, job_id: str, location: str, project_id: str, wait_time: int = 10) -> None:
def wait_for_job(
self, job_id: str, location: str, project_id: str, wait_time: int = 10, timeout: Optional[int] = None
) -> None:
"""
Helper method which polls a job to check if it finishes.
Expand All @@ -716,12 +719,21 @@ def wait_for_job(self, job_id: str, location: str, project_id: str, wait_time: i
:type location: str
:param wait_time: Number of seconds between checks
:type wait_time: int
:param timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False
:type timeout: int
"""
state = None
start = time.monotonic()
while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED):
if timeout and start + timeout < time.monotonic():
raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
time.sleep(wait_time)
job = self.get_job(location=location, job_id=job_id, project_id=project_id)
state = job.status.state
try:
job = self.get_job(location=location, job_id=job_id, project_id=project_id)
state = job.status.state
except ServerError as err:
self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)

if state == JobStatus.ERROR:
raise AirflowException('Job failed:\n{}'.format(job))
if state == JobStatus.CANCELLED:
Expand Down
9 changes: 8 additions & 1 deletion airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,9 +1785,12 @@ class DataprocSubmitJobOperator(BaseOperator):
:type asynchronous: bool
:param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called
:type cancel_on_kill: bool
:param wait_timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False
:type wait_timeout: int
"""

template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
template_fields_renderers = {"job": "json"}

@apply_defaults
def __init__(
Expand All @@ -1804,6 +1807,7 @@ def __init__(
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
cancel_on_kill: bool = True,
wait_timeout: Optional[int] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1820,6 +1824,7 @@ def __init__(
self.cancel_on_kill = cancel_on_kill
self.hook: Optional[DataprocHook] = None
self.job_id: Optional[str] = None
self.wait_timeout = wait_timeout

def execute(self, context: Dict):
self.log.info("Submitting job")
Expand All @@ -1838,7 +1843,9 @@ def execute(self, context: Dict):

if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
self.hook.wait_for_job(job_id=job_id, location=self.location, project_id=self.project_id)
self.hook.wait_for_job(
job_id=job_id, location=self.location, project_id=self.project_id, timeout=self.wait_timeout
)
self.log.info('Job %s completed successfully.', job_id)

self.job_id = job_id
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def test_execute(self, mock_hook):
metadata=METADATA,
)
mock_hook.return_value.wait_for_job.assert_called_once_with(
job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION
job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down

0 comments on commit 6071fdd

Please sign in to comment.