Skip to content

Commit

Permalink
Add retry param in GCSObjectExistenceSensor (#27943)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Dec 4, 2022
1 parent 2b107e6 commit 5cdff50
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
7 changes: 5 additions & 2 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
from urllib.parse import urlsplit

from google.api_core.exceptions import NotFound
from google.api_core.retry import Retry

# not sure why but mypy complains on missing `storage` but it is clearly there and is importable
from google.cloud import storage # type: ignore[attr-defined]
from google.cloud.exceptions import GoogleCloudError
from google.cloud.storage.retry import DEFAULT_RETRY

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
Expand Down Expand Up @@ -533,18 +535,19 @@ def _call_with_retry(f: Callable[[], None]) -> None:
else:
raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")

def exists(self, bucket_name: str, object_name: str) -> bool:
def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool:
"""
Checks for the existence of a file in Google Cloud Storage.
:param bucket_name: The Google Cloud Storage bucket where the object is.
:param object_name: The name of the blob_name to check in the Google cloud
storage bucket.
:param retry: (Optional) How to retry the RPC
"""
client = self.get_conn()
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
return blob.exists()
return blob.exists(retry=retry)

def get_blob_update_time(self, bucket_name: str, object_name: str):
"""
Expand Down
8 changes: 7 additions & 1 deletion airflow/providers/google/cloud/sensors/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from datetime import datetime
from typing import TYPE_CHECKING, Callable, Sequence

from google.api_core.retry import Retry
from google.cloud.storage.retry import DEFAULT_RETRY

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
Expand Down Expand Up @@ -51,6 +54,7 @@ class GCSObjectExistenceSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param retry: (Optional) How to retry the RPC
"""

template_fields: Sequence[str] = (
Expand All @@ -68,6 +72,7 @@ def __init__(
google_cloud_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
retry: Retry = DEFAULT_RETRY,
**kwargs,
) -> None:

Expand All @@ -77,6 +82,7 @@ def __init__(
self.google_cloud_conn_id = google_cloud_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.retry = retry

def poke(self, context: Context) -> bool:
self.log.info("Sensor checks existence of : %s, %s", self.bucket, self.object)
Expand All @@ -85,7 +91,7 @@ def poke(self, context: Context) -> bool:
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
return hook.exists(self.bucket, self.object)
return hook.exists(self.bucket, self.object, self.retry)


def ts_function(context):
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/google/cloud/hooks/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

# dynamic storage type in google.cloud needs to be type-ignored
from google.cloud import exceptions, storage # type: ignore[attr-defined]
from google.cloud.storage.retry import DEFAULT_RETRY

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks import gcs
Expand Down Expand Up @@ -156,7 +157,7 @@ def test_exists(self, mock_service):
assert response
bucket_mock.assert_called_once_with(test_bucket)
blob_object.assert_called_once_with(blob_name=test_object)
exists_method.assert_called_once_with()
exists_method.assert_called_once_with(retry=DEFAULT_RETRY)

@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_exists_nonexisting_object(self, mock_service):
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/google/cloud/sensors/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pendulum
import pytest
from google.cloud.storage.retry import DEFAULT_RETRY

from airflow.exceptions import AirflowSensorTimeout
from airflow.models.dag import DAG, AirflowException
Expand Down Expand Up @@ -84,7 +85,7 @@ def test_should_pass_argument_to_hook(self, mock_hook):
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.exists.assert_called_once_with(TEST_BUCKET, TEST_OBJECT)
mock_hook.return_value.exists.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, DEFAULT_RETRY)


class TestTsFunction(TestCase):
Expand Down

0 comments on commit 5cdff50

Please sign in to comment.