Add BigQuery Column and Table Check Operators (#26368)
* Add Column and Table Check Operators

Add two new operators based on the SQLColumnCheckOperator and
SQLTableCheckOperator that also provide job_ids so results
of the queries can be pulled and parsed, and so OpenLineage
can parse datasets and provide lineage information.
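
Because both new operators push the id of the BigQuery job they run to XCom under
key="job_id", a downstream task can look the job up and read the check results. A
minimal sketch of such a consumer task (the fetch_check_results helper and the
fetch_results task id are illustrative, not part of this codebase):

from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

def fetch_check_results(ti=None):
    # Pull the job id pushed by the check operator, then re-fetch the job.
    job_id = ti.xcom_pull(task_ids="column_check", key="job_id")
    hook = BigQueryHook(gcp_conn_id="google_cloud_default")
    job = hook.get_job(job_id=job_id)
    for row in job.result():  # one row of aggregated check results per query
        print(row)

fetch_results = PythonOperator(task_id="fetch_results", python_callable=fetch_check_results)
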
denimalpaca committed Sep 21, 2022
1 parent 3cd4df1 commit c4256ca
Showing 3 changed files with 285 additions and 0 deletions.
239 changes: 239 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -34,8 +34,12 @@
from airflow.models.xcom import XCom
from airflow.providers.common.sql.operators.sql import (
    SQLCheckOperator,
    SQLColumnCheckOperator,
    SQLIntervalCheckOperator,
    SQLTableCheckOperator,
    SQLValueCheckOperator,
    _get_failed_checks,
    parse_boolean,
)
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
@@ -520,6 +524,241 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        )


class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
    """
    BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator
    in order to provide a job id for OpenLineage to parse. See base class
    docstring for usage.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BigQueryColumnCheckOperator`

    :param table: the table name
    :param column_mapping: a dictionary relating columns to their checks
    :param partition_clause: a string SQL statement added to a WHERE clause
        to partition data
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
    :param use_legacy_sql: Whether to use legacy SQL (true)
        or standard SQL (false).
    :param location: The geographic location of the job. See details at:
        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param labels: a dictionary containing labels for the table, passed to BigQuery
    """

    def __init__(
        self,
        *,
        table: str,
        column_mapping: dict,
        partition_clause: str | None = None,
        gcp_conn_id: str = "google_cloud_default",
        use_legacy_sql: bool = True,
        location: str | None = None,
        impersonation_chain: str | Sequence[str] | None = None,
        labels: dict | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            table=table, column_mapping=column_mapping, partition_clause=partition_clause, **kwargs
        )
        self.table = table
        self.column_mapping = column_mapping
        self.partition_clause = partition_clause
        self.gcp_conn_id = gcp_conn_id
        self.use_legacy_sql = use_legacy_sql
        self.location = location
        self.impersonation_chain = impersonation_chain
        self.labels = labels
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = ""

    def _submit_job(
        self,
        hook: BigQueryHook,
        job_id: str,
    ) -> BigQueryJob:
        """Submit a new job and get the job id for polling the status using Trigger."""
        configuration = {"query": {"query": self.sql}}

        return hook.insert_job(
            configuration=configuration,
            project_id=hook.project_id,
            location=self.location,
            job_id=job_id,
            nowait=False,
        )

    def execute(self, context: Context):
        """Perform checks on the given columns."""
        hook = self.get_db_hook()
        failed_tests = []
        for column in self.column_mapping:
            checks = [*self.column_mapping[column]]
            checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"

            job_id = hook.generate_job_id(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=context["logical_date"],
                # the query configuration is hashed into the generated job id
                configuration={"query": {"query": self.sql}},
            )
            job = self._submit_job(hook, job_id=job_id)
            context["ti"].xcom_push(key="job_id", value=job.job_id)
            records = list(job.result().to_dataframe().values.flatten())

            if not records:
                raise AirflowException(f"The following query returned zero rows: {self.sql}")

            self.log.info("Record: %s", records)

            for idx, result in enumerate(records):
                tolerance = self.column_mapping[column][checks[idx]].get("tolerance")

                self.column_mapping[column][checks[idx]]["result"] = result
                self.column_mapping[column][checks[idx]]["success"] = self._get_match(
                    self.column_mapping[column][checks[idx]], result, tolerance
                )

            failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{''.join(failed_tests)}"
            )

        self.log.info("All tests have passed")
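
    # Sketch of the SQL built for one column, assuming the "null_check" template from
    # SQLColumnCheckOperator.column_checks (table and column names are illustrative):
    #
    #   SELECT SUM(CASE WHEN value IS NULL THEN 1 ELSE 0 END) AS value_null_check
    #   FROM dataset.table1 ;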


class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
    """
    BigQueryTableCheckOperator subclasses the SQLTableCheckOperator
    in order to provide a job id for OpenLineage to parse. See base class
    docstring for usage.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BigQueryTableCheckOperator`

    :param table: the table name
    :param checks: a dictionary of check names and boolean SQL statements
    :param partition_clause: a string SQL statement added to a WHERE clause
        to partition data
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
    :param use_legacy_sql: Whether to use legacy SQL (true)
        or standard SQL (false).
    :param location: The geographic location of the job. See details at:
        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param labels: a dictionary containing labels for the table, passed to BigQuery
    """

    def __init__(
        self,
        *,
        table: str,
        checks: dict,
        partition_clause: str | None = None,
        gcp_conn_id: str = "google_cloud_default",
        use_legacy_sql: bool = True,
        location: str | None = None,
        impersonation_chain: str | Sequence[str] | None = None,
        labels: dict | None = None,
        **kwargs,
    ) -> None:
        super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs)
        self.table = table
        self.checks = checks
        self.partition_clause = partition_clause
        self.gcp_conn_id = gcp_conn_id
        self.use_legacy_sql = use_legacy_sql
        self.location = location
        self.impersonation_chain = impersonation_chain
        self.labels = labels
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = ""

    def _submit_job(
        self,
        hook: BigQueryHook,
        job_id: str,
    ) -> BigQueryJob:
        """Submit a new job and get the job id for polling the status using Trigger."""
        configuration = {"query": {"query": self.sql}}

        return hook.insert_job(
            configuration=configuration,
            project_id=hook.project_id,
            location=self.location,
            job_id=job_id,
            nowait=False,
        )

    def execute(self, context: Context):
        """Execute the given checks on the table."""
        hook = self.get_db_hook()
        checks_sql = " UNION ALL ".join(
            [
                self.sql_check_template.replace("check_statement", value["check_statement"])
                .replace("_check_name", check_name)
                .replace("table", self.table)
                for check_name, value in self.checks.items()
            ]
        )
        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
        self.sql = (
            f"SELECT check_name, check_result FROM ({checks_sql}) "
            f"AS check_table {partition_clause_statement};"
        )

        job_id = hook.generate_job_id(
            dag_id=self.dag_id,
            task_id=self.task_id,
            logical_date=context["logical_date"],
            # the query configuration is hashed into the generated job id
            configuration={"query": {"query": self.sql}},
        )
        job = self._submit_job(hook, job_id=job_id)
        context["ti"].xcom_push(key="job_id", value=job.job_id)
        records = job.result().to_dataframe()

        if records.empty:
            raise AirflowException(f"The following query returned zero rows: {self.sql}")

        records.columns = records.columns.str.lower()
        self.log.info("Record:\n%s", records)

        for _, row in records.iterrows():
            check = row.get("check_name")
            result = row.get("check_result")
            self.checks[check]["success"] = parse_boolean(str(result))

        failed_tests = _get_failed_checks(self.checks)
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{', '.join(failed_tests)}"
            )

        self.log.info("All tests have passed")
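
    # Sketch of the final query shape, assuming sql_check_template renders each check
    # as one (check_name, check_result) row:
    #
    #   SELECT check_name, check_result
    #   FROM (<row_count_check subquery> UNION ALL <other_check subquery>) AS check_table
    #   WHERE <partition_clause>;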


class BigQueryGetDataOperator(BaseOperator):
    """
    Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
28 changes: 28 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -438,6 +438,34 @@ Also you can use deferrable mode in this operator
    :start-after: [START howto_operator_bigquery_interval_check_async]
    :end-before: [END howto_operator_bigquery_interval_check_async]

.. _howto/operator:BigQueryColumnCheckOperator:

Check columns with predefined tests
"""""""""""""""""""""""""""""""""""

To check that columns pass user-configurable tests you can use
:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryColumnCheckOperator`

.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_bigquery_column_check]
    :end-before: [END howto_operator_bigquery_column_check]

.. _howto/operator:BigQueryTableCheckOperator:

Check table level data quality
""""""""""""""""""""""""""""""

To check that tables pass user-defined tests you can use
:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryTableCheckOperator`

.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_bigquery_table_check]
    :end-before: [END howto_operator_bigquery_table_check]
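
Both operators push the id of the BigQuery job they run to XCom under the ``job_id``
key, so a downstream task can retrieve the underlying check results. A minimal sketch,
using the task id from the example above:

.. code-block:: python

    job_id = ti.xcom_pull(task_ids="column_check", key="job_id")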

Sensors
^^^^^^^

18 changes: 18 additions & 0 deletions tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
@@ -27,12 +27,14 @@
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.bigquery import (
    BigQueryCheckOperator,
    BigQueryColumnCheckOperator,
    BigQueryCreateEmptyDatasetOperator,
    BigQueryCreateEmptyTableOperator,
    BigQueryDeleteDatasetOperator,
    BigQueryGetDataOperator,
    BigQueryInsertJobOperator,
    BigQueryIntervalCheckOperator,
    BigQueryTableCheckOperator,
    BigQueryValueCheckOperator,
)
from airflow.utils.trigger_rule import TriggerRule
@@ -209,6 +211,22 @@
    )
    # [END howto_operator_bigquery_interval_check]

    # [START howto_operator_bigquery_column_check]
    column_check = BigQueryColumnCheckOperator(
        task_id="column_check",
        table=f"{DATASET}.{TABLE_1}",
        column_mapping={"value": {"null_check": {"equal_to": 0}}},
    )
    # [END howto_operator_bigquery_column_check]

    # [START howto_operator_bigquery_table_check]
    table_check = BigQueryTableCheckOperator(
        task_id="table_check",
        table=f"{DATASET}.{TABLE_1}",
        checks={"row_count_check": {"check_statement": "COUNT(*) = 4"}},
    )
    # [END howto_operator_bigquery_table_check]
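
    # Both check operators push their BigQuery job id to XCom; a sketch of pulling it
    # from a later task through a templated field:
    #   "{{ ti.xcom_pull(task_ids='table_check', key='job_id') }}"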

    delete_dataset = BigQueryDeleteDatasetOperator(
        task_id="delete_dataset",
        dataset_id=DATASET,