Skip to content

Commit

Permalink
Added wait mechanism to the DataprocJobSensor to avoid 509 errors whe…
Browse files Browse the repository at this point in the history
…n Job is not available (#19740)
  • Loading branch information
Łukasz Wyszomirski committed Nov 22, 2021
1 parent 56bdfe7 commit 0b2e1a8
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
37 changes: 33 additions & 4 deletions airflow/providers/google/cloud/sensors/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
# under the License.
"""This module contains a Dataproc Job sensor."""
# pylint: disable=C0302
import time
import warnings
from typing import Optional
from typing import Dict, Optional

from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import JobStatus

from airflow.exceptions import AirflowException
Expand All @@ -42,6 +44,8 @@ class DataprocJobSensor(BaseSensorOperator):
:type location: str
:param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
:type gcp_conn_id: str
:param wait_timeout: How many seconds to wait for the job to be ready.
:type wait_timeout: int
"""

template_fields = ('project_id', 'region', 'dataproc_job_id')
Expand All @@ -55,6 +59,7 @@ def __init__(
region: str = None,
location: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
wait_timeout: Optional[int] = None,
**kwargs,
) -> None:
if region is None:
Expand All @@ -73,12 +78,36 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.dataproc_job_id = dataproc_job_id
self.region = region
self.wait_timeout = wait_timeout
self.start_sensor_time = None

def poke(self, context: dict) -> bool:
def execute(self, context: Dict):
    """Record the sensor start time, then delegate to the base sensor loop.

    The start time is captured with ``time.monotonic()`` so that
    ``_duration`` can compute how long the sensor has been polling,
    which ``poke`` compares against ``wait_timeout``.
    """
    self.start_sensor_time = time.monotonic()
    # Propagate whatever the base sensor returns instead of discarding it,
    # so this override stays compatible with base implementations that
    # return a value.
    return super().execute(context)

def _duration(self):
    """Return the number of seconds elapsed since ``execute`` started the sensor."""
    elapsed = time.monotonic() - self.start_sensor_time
    return elapsed

def poke(self, context: Dict) -> bool:
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)
state = job.status.state
if self.wait_timeout:
try:
job = hook.get_job(
job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id
)
except ServerError as err:
self.log.info(f"DURATION RUN: {self._duration()}")
if self._duration() > self.wait_timeout:
raise AirflowException(
f"Timeout: dataproc job {self.dataproc_job_id} "
f"is not ready after {self.wait_timeout}s"
)
self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
return False
else:
job = hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)

state = job.status.state
if state == JobStatus.State.ERROR:
raise AirflowException(f'Job failed:\n{job}')
elif state in {
Expand Down
44 changes: 44 additions & 0 deletions tests/providers/google/cloud/sensors/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import unittest
from unittest import mock
from unittest.mock import Mock

import pytest
from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import JobStatus

from airflow import AirflowException
Expand Down Expand Up @@ -164,3 +166,45 @@ def test_location_deprecation_warning(self, mock_hook):
timeout=TIMEOUT,
)
sensor.poke(context={})

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_wait_timeout(self, mock_hook):
    # While the Dataproc API keeps raising ServerError and the elapsed
    # time is still below wait_timeout, poke must report "not ready yet"
    # (False) so the sensor retries instead of failing.
    job_id = "job_id"
    mock_hook.return_value.get_job.side_effect = ServerError("Job are not ready")

    sensor = DataprocJobSensor(
        task_id=TASK_ID,
        region=GCP_LOCATION,
        project_id=GCP_PROJECT,
        dataproc_job_id=job_id,
        gcp_conn_id=GCP_CONN_ID,
        timeout=TIMEOUT,
        wait_timeout=300,
    )
    # Simulate being 200s into a 300s wait window.
    sensor._duration = Mock(return_value=200)

    assert not sensor.poke(context={})

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_wait_timeout_raise_exception(self, mock_hook):
    # Once the elapsed time exceeds wait_timeout while the API is still
    # raising ServerError, poke must give up with an AirflowException.
    job_id = "job_id"
    mock_hook.return_value.get_job.side_effect = ServerError("Job are not ready")

    sensor = DataprocJobSensor(
        task_id=TASK_ID,
        region=GCP_LOCATION,
        project_id=GCP_PROJECT,
        dataproc_job_id=job_id,
        gcp_conn_id=GCP_CONN_ID,
        timeout=TIMEOUT,
        wait_timeout=300,
    )
    # Simulate having waited 301s against a 300s limit.
    sensor._duration = Mock(return_value=301)

    expected = "Timeout: dataproc job job_id is not ready after 300s"
    with pytest.raises(AirflowException, match=expected):
        sensor.poke(context={})

0 comments on commit 0b2e1a8

Please sign in to comment.