Skip to content

Commit

Permalink
Fix BigQueryValueCheckOperator deferrable mode optimisation (#34018)
Browse files Browse the repository at this point in the history
PR #31872 tried to optimise the deferrable mode in BigQueryValueCheckOperator.
However, when deciding whether to defer, it only checked the
job status and did not actually verify the query result against
the pass value, so it returned success prematurely.
This PR adds the missing logic to the optimisation so that the result
is checked and compared against the pass value and tolerance.

closes: #34010
  • Loading branch information
pankajkoti committed Sep 3, 2023
1 parent 6ef80e8 commit d757f6a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/common/sql/operators/sql.py
Expand Up @@ -827,10 +827,7 @@ def __init__(
self.tol = tol if isinstance(tol, float) else None
self.has_tolerance = self.tol is not None

def execute(self, context: Context):
self.log.info("Executing SQL check: %s", self.sql)
records = self.get_db_hook().get_first(self.sql)

def check_value(self, records):
if not records:
self._raise_exception(f"The following query returned zero rows: {self.sql}")

Expand Down Expand Up @@ -862,6 +859,11 @@ def execute(self, context: Context):
if not all(tests):
self._raise_exception(error_msg)

def execute(self, context: Context):
self.log.info("Executing SQL check: %s", self.sql)
records = self.get_db_hook().get_first(self.sql)
self.check_value(records)

def _to_float(self, records):
return [float(record) for record in records]

Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -443,6 +443,10 @@ def execute(self, context: Context) -> None: # type: ignore[override]
method_name="execute_complete",
)
self._handle_job_error(job)
# job.result() returns a RowIterator. Mypy expects an instance of SupportsNext[Any] for
# the next() call, which the RowIterator does not match. Hence, ignore the arg-type error.
records = next(job.result()) # type: ignore[arg-type]
self.check_value(records)
self.log.info("Current state of job %s is %s", job.job_id, job.state)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/operators/test_bigquery.py
Expand Up @@ -1919,11 +1919,11 @@ def test_bigquery_value_check_async(self, mock_hook, create_task_instance_of_ope
exc.value.trigger, BigQueryValueCheckTrigger
), "Trigger is not a BigQueryValueCheckTrigger"

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute")
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.defer")
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.check_value")
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_value_check_operator_async_finish_before_deferred(
self, mock_hook, mock_defer, mock_execute, create_task_instance_of_operator
self, mock_hook, mock_check_value, mock_defer, create_task_instance_of_operator
):
job_id = "123456"
hash_ = "hash"
Expand All @@ -1944,7 +1944,7 @@ def test_bigquery_value_check_operator_async_finish_before_deferred(

ti.task.execute(MagicMock())
assert not mock_defer.called
assert mock_execute.called
assert mock_check_value.called

@pytest.mark.parametrize(
"kwargs, expected",
Expand Down

0 comments on commit d757f6a

Please sign in to comment.