merge BigQueryTableExistenceAsyncSensor into BigQueryTableExistenceSensor (#30235)

* feat(providers/google): move the async execution logic from BigQueryTableExistenceAsyncSensor to BigQueryTableExistenceSensor

* docs(providers/google): update the doc for BigQueryTableExistenceSensor deferrable mode and BigQueryTableExistenceAsyncSensor deprecation

* test(providers/google): add test cases for BigQueryTableExistenceSensor when its deferrable attribute is set to True

* refactor(providers/google): deprecate polling_interval argument

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
Lee-W and uranusjr committed Mar 30, 2023
1 parent f08296d commit 540a076
Showing 4 changed files with 164 additions and 67 deletions.
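
In practice, the merged sensor is used as below — a minimal sketch based on the example DAG changed in this commit; the DAG id, schedule, and the PROJECT_ID/DATASET_NAME/TABLE_NAME values are illustrative, not part of the commit:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

# Illustrative identifiers; substitute your own project, dataset, and table.
PROJECT_ID = "my-project"
DATASET_NAME = "my_dataset"
TABLE_NAME = "my_table"

with DAG(
    dag_id="example_bigquery_sensors",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    # With deferrable=True the worker slot is released while a
    # BigQueryTableExistenceTrigger polls BigQuery on the triggerer.
    check_table_exists_deferred = BigQueryTableExistenceSensor(
        task_id="check_table_exists_deferred",
        project_id=PROJECT_ID,
        dataset_id=DATASET_NAME,
        table_id=TABLE_NAME,
        deferrable=True,
    )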
96 changes: 56 additions & 40 deletions airflow/providers/google/cloud/sensors/bigquery.py
@@ -75,8 +75,22 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
        if deferrable and "poke_interval" not in kwargs:
            # TODO: Remove once deprecated
            if "polling_interval" in kwargs:
                # Map the legacy argument onto poke_interval and drop it so it
                # is not forwarded to BaseSensorOperator as an unknown kwarg.
                kwargs["poke_interval"] = kwargs.pop("polling_interval")
                warnings.warn(
                    "Argument `polling_interval` is deprecated and will be removed "
                    "in a future release. Please use `poke_interval` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            else:
                kwargs["poke_interval"] = 5

super().__init__(**kwargs)

self.project_id = project_id
@@ -90,6 +104,8 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain

self.deferrable = deferrable

def poke(self, context: Context) -> bool:
table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
self.log.info("Sensor checks existence of table: %s", table_uri)
@@ -102,6 +118,38 @@ def poke(self, context: Context) -> bool:
project_id=self.project_id, dataset_id=self.dataset_id, table_id=self.table_id
)

    def execute(self, context: Context) -> None:
        """Airflow runs this method on the worker and defers using the trigger."""
        if not self.deferrable:
            super().execute(context)
        else:
            self.defer(
                timeout=timedelta(seconds=self.timeout),
                trigger=BigQueryTableExistenceTrigger(
                    dataset_id=self.dataset_id,
                    table_id=self.table_id,
                    project_id=self.project_id,
                    poll_interval=self.poke_interval,
                    gcp_conn_id=self.gcp_conn_id,
                    hook_params={
                        "delegate_to": self.delegate_to,
                        "impersonation_chain": self.impersonation_chain,
                    },
                ),
                method_name="execute_complete",
            )

    def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
        """
        Callback for when the trigger fires; returns immediately.
        Relies on the trigger to throw an exception; otherwise it assumes execution was
        successful.
        """
table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
self.log.info("Sensor checks existence of table: %s", table_uri)
if event:
if event["status"] == "success":
return event["message"]
raise AirflowException(event["message"])
raise AirflowException("No event received in trigger callback")


class BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
"""
@@ -249,47 +297,15 @@ class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor):
:param polling_interval: The interval in seconds to wait between checks for table existence.
"""

-    def __init__(
-        self,
-        gcp_conn_id: str = "google_cloud_default",
-        polling_interval: float = 5.0,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.polling_interval = polling_interval
-        self.gcp_conn_id = gcp_conn_id
-
-    def execute(self, context: Context) -> None:
-        """Airflow runs this method on the worker and defers using the trigger."""
-        self.defer(
-            timeout=timedelta(seconds=self.timeout),
-            trigger=BigQueryTableExistenceTrigger(
-                dataset_id=self.dataset_id,
-                table_id=self.table_id,
-                project_id=self.project_id,
-                poll_interval=self.polling_interval,
-                gcp_conn_id=self.gcp_conn_id,
-                hook_params={
-                    "delegate_to": self.delegate_to,
-                    "impersonation_chain": self.impersonation_chain,
-                },
-            ),
-            method_name="execute_complete",
-        )
-
-    def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
-        self.log.info("Sensor checks existence of table: %s", table_uri)
-        if event:
-            if event["status"] == "success":
-                return event["message"]
-            raise AirflowException(event["message"])
-        raise AirflowException("No event received in trigger callback")
+    def __init__(self, **kwargs):
+        warnings.warn(
+            "Class `BigQueryTableExistenceAsyncSensor` is deprecated and "
+            "will be removed in a future release. "
+            "Please use `BigQueryTableExistenceSensor` and "
+            "set `deferrable` attribute to `True` instead",
+            DeprecationWarning,
+        )
+        super().__init__(deferrable=True, **kwargs)


class BigQueryTableExistencePartitionAsyncSensor(BigQueryTablePartitionExistenceSensor):
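The polling_interval shim in BigQueryTableExistenceSensor.__init__ above keeps old call sites working while steering users toward poke_interval. A hypothetical snippet (not part of the commit; identifiers are illustrative) showing what the mapping does, assuming the legacy argument is popped before reaching BaseSensorOperator as in the code above:

import warnings

from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sensor = BigQueryTableExistenceSensor(
        task_id="check_table_exists",
        project_id="my-project",  # illustrative values
        dataset_id="my_dataset",
        table_id="my_table",
        deferrable=True,
        polling_interval=10.0,  # legacy argument, mapped onto poke_interval
    )

# The legacy value lands on poke_interval and a DeprecationWarning is emitted.
assert sensor.poke_interval == 10.0
assert any(issubclass(w.category, DeprecationWarning) for w in caught)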
@@ -484,10 +484,15 @@ use the ``{{ ds_nodash }}`` macro as the table name suffix.
    :start-after: [START howto_sensor_bigquery_table]
    :end-before: [END howto_sensor_bigquery_table]

-Use the :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceAsyncSensor`
-(deferrable version) if you would like to free up the worker slots while the sensor is running.
+You can also use deferrable mode in this operator if you would like to free up the worker slots while the sensor is running.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_bigquery_table_deferred]
+    :end-before: [END howto_sensor_bigquery_table_deferred]
+
-:class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceAsyncSensor`.
+:class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceAsyncSensor` is deprecated and will be removed in a future release. Please use :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor` with deferrable mode instead.

.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
    :language: python
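A minimal before/after migration sketch matching the guidance above (project, dataset, and table identifiers are illustrative):

from airflow.providers.google.cloud.sensors.bigquery import (
    BigQueryTableExistenceAsyncSensor,  # deprecated
    BigQueryTableExistenceSensor,
)

# Before: emits a DeprecationWarning and internally sets deferrable=True.
check_async = BigQueryTableExistenceAsyncSensor(
    task_id="check_table_exists_async",
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
)

# After: the same behaviour via the merged sensor.
check_deferred = BigQueryTableExistenceSensor(
    task_id="check_table_exists_deferred",
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
    deferrable=True,
)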
114 changes: 90 additions & 24 deletions tests/providers/google/cloud/sensors/test_bigquery.py
@@ -67,6 +67,61 @@ def test_passing_arguments_to_hook(self, mock_hook):
project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID
)

    def test_execute_deferred(self):
        """
        Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
        when BigQueryTableExistenceSensor is executed with deferrable=True.
        """
task = BigQueryTableExistenceSensor(
task_id="check_table_exists",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
deferrable=True,
)
with pytest.raises(TaskDeferred) as exc:
task.execute(context={})
assert isinstance(
exc.value.trigger, BigQueryTableExistenceTrigger
), "Trigger is not a BigQueryTableExistenceTrigger"

    def test_execute_deferred_failure(self):
"""Tests that an AirflowException is raised in case of error event"""
task = BigQueryTableExistenceSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
deferrable=True,
)
with pytest.raises(AirflowException):
task.execute_complete(context={}, event={"status": "error", "message": "test failure message"})

def test_execute_complete(self):
"""Asserts that logging occurs as expected"""
task = BigQueryTableExistenceSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
deferrable=True,
)
table_uri = f"{TEST_PROJECT_ID}:{TEST_DATASET_ID}.{TEST_TABLE_ID}"
with mock.patch.object(task.log, "info") as mock_log_info:
task.execute_complete(context={}, event={"status": "success", "message": "Job completed"})
mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri)

    def test_execute_deferred_complete_event_none(self):
        """Asserts that an AirflowException is raised when no event is received"""
task = BigQueryTableExistenceSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
)
with pytest.raises(AirflowException):
task.execute_complete(context={}, event=None)


class TestBigqueryTablePartitionExistenceSensor:
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
@@ -171,17 +226,25 @@ def context():


class TestBigQueryTableExistenceAsyncSensor:
    deprecation_message = (
"Class `BigQueryTableExistenceAsyncSensor` is deprecated and "
"will be removed in a future release. "
"Please use `BigQueryTableExistenceSensor` and "
"set `deferrable` attribute to `True` instead"
)

def test_big_query_table_existence_sensor_async(self):
"""
Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
when the BigQueryTableExistenceAsyncSensor is executed.
"""
-        task = BigQueryTableExistenceAsyncSensor(
-            task_id="check_table_exists",
-            project_id=TEST_PROJECT_ID,
-            dataset_id=TEST_DATASET_ID,
-            table_id=TEST_TABLE_ID,
-        )
+        with pytest.warns(DeprecationWarning, match=self.deprecation_message):
+            task = BigQueryTableExistenceAsyncSensor(
+                task_id="check_table_exists",
+                project_id=TEST_PROJECT_ID,
+                dataset_id=TEST_DATASET_ID,
+                table_id=TEST_TABLE_ID,
+            )
with pytest.raises(TaskDeferred) as exc:
task.execute(context={})
assert isinstance(
@@ -190,36 +253,39 @@ def test_big_query_table_existence_sensor_async(self):

def test_big_query_table_existence_sensor_async_execute_failure(self):
"""Tests that an AirflowException is raised in case of error event"""
-        task = BigQueryTableExistenceAsyncSensor(
-            task_id="task-id",
-            project_id=TEST_PROJECT_ID,
-            dataset_id=TEST_DATASET_ID,
-            table_id=TEST_TABLE_ID,
-        )
+        with pytest.warns(DeprecationWarning, match=self.deprecation_message):
+            task = BigQueryTableExistenceAsyncSensor(
+                task_id="task-id",
+                project_id=TEST_PROJECT_ID,
+                dataset_id=TEST_DATASET_ID,
+                table_id=TEST_TABLE_ID,
+            )
with pytest.raises(AirflowException):
task.execute_complete(context={}, event={"status": "error", "message": "test failure message"})

def test_big_query_table_existence_sensor_async_execute_complete(self):
"""Asserts that logging occurs as expected"""
-        task = BigQueryTableExistenceAsyncSensor(
-            task_id="task-id",
-            project_id=TEST_PROJECT_ID,
-            dataset_id=TEST_DATASET_ID,
-            table_id=TEST_TABLE_ID,
-        )
+        with pytest.warns(DeprecationWarning, match=self.deprecation_message):
+            task = BigQueryTableExistenceAsyncSensor(
+                task_id="task-id",
+                project_id=TEST_PROJECT_ID,
+                dataset_id=TEST_DATASET_ID,
+                table_id=TEST_TABLE_ID,
+            )
table_uri = f"{TEST_PROJECT_ID}:{TEST_DATASET_ID}.{TEST_TABLE_ID}"
with mock.patch.object(task.log, "info") as mock_log_info:
task.execute_complete(context={}, event={"status": "success", "message": "Job completed"})
mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri)

    def test_big_query_sensor_async_execute_complete_event_none(self):
        """Asserts that an AirflowException is raised when no event is received"""
-        task = BigQueryTableExistenceAsyncSensor(
-            task_id="task-id",
-            project_id=TEST_PROJECT_ID,
-            dataset_id=TEST_DATASET_ID,
-            table_id=TEST_TABLE_ID,
-        )
+        with pytest.warns(DeprecationWarning, match=self.deprecation_message):
+            task = BigQueryTableExistenceAsyncSensor(
+                task_id="task-id",
+                project_id=TEST_PROJECT_ID,
+                dataset_id=TEST_DATASET_ID,
+                table_id=TEST_TABLE_ID,
+            )
with pytest.raises(AirflowException):
task.execute_complete(context={}, event=None)

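The deferral round trip these tests exercise: execute() raises TaskDeferred carrying a BigQueryTableExistenceTrigger, the triggerer fires an event dict, and execute_complete() consumes it. A hypothetical snippet (identifiers illustrative) showing the event contract in isolation:

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

sensor = BigQueryTableExistenceSensor(
    task_id="check_table_exists",
    project_id="my-project",  # illustrative values
    dataset_id="my_dataset",
    table_id="my_table",
    deferrable=True,
)

# A success event returns the trigger's message.
print(sensor.execute_complete(context={}, event={"status": "success", "message": "Job completed"}))

# An error event, or a missing event, raises AirflowException.
try:
    sensor.execute_complete(context={}, event={"status": "error", "message": "table not found"})
except AirflowException as err:
    print(err)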
@@ -89,6 +89,16 @@
)
# [END howto_sensor_bigquery_table]

    # [START howto_sensor_bigquery_table_deferred]
    check_table_exists_deferred: BaseOperator = BigQueryTableExistenceSensor(
        task_id="check_table_exists_deferred",
        project_id=PROJECT_ID,
        dataset_id=DATASET_NAME,
        table_id=TABLE_NAME,
        deferrable=True,
    )
    # [END howto_sensor_bigquery_table_deferred]

# [START howto_sensor_async_bigquery_table]
check_table_exists_async = BigQueryTableExistenceAsyncSensor(
task_id="check_table_exists_async",