Skip to content

Commit

Permalink
Refactor SQL/BigQuery/Qubole/Druid Check operators (#12677)
Browse files Browse the repository at this point in the history
closes: #10271
related: #9844 #14184

This PR refactors the SQL/BigQuery Check operators to reduce duplicated code:

Create BaseSQLOperator: it standardizes how the generic SQL operators retrieve their DB hook, via the .get_db_hook() method
Add a database kwarg to the *CheckOperators for a consistent interface
create _BigQueryDbHookMixin to standardize the .get_db_hook() method for BigQuery
create _QuboleCheckOperatorMixin to remove duplicate code
replace <class-name>.template_fields with _get_template_fields in __getattribute__ to avoid hard coding class name, and reduce duplicate code
Remove and deprecate DruidCheckOperator: the same functionality can be achieved with SQLCheckOperator. The deprecation approach is the same as for PrestoCheckOperator
Misc:

Fix docstrings
Update deprecated Operator name and import path
Remove unnecessary if statements that check parameters in BranchSQLOperator
  • Loading branch information
xinbinhuang committed Feb 26, 2021
1 parent aa28e4e commit 33214d9
Show file tree
Hide file tree
Showing 17 changed files with 353 additions and 431 deletions.
2 changes: 1 addition & 1 deletion airflow/operators/druid_check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator # noqa

warnings.warn(
"This module is deprecated. Please use `airflow.providers.apache.druid.operators.druid_check`.",
"This module is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.",
DeprecationWarning,
stacklevel=2,
)
198 changes: 84 additions & 114 deletions airflow/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,59 @@
from distutils.util import strtobool
from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union

from cached_property import cached_property

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.hooks.dbapi import DbApiHook
from airflow.models import BaseOperator, SkipMixin
from airflow.utils.decorators import apply_defaults

ALLOWED_CONN_TYPE = {
"google_cloud_platform",
"jdbc",
"mssql",
"mysql",
"odbc",
"oracle",
"postgres",
"presto",
"snowflake",
"sqlite",
"vertica",
}


class SQLCheckOperator(BaseOperator):

class BaseSQLOperator(BaseOperator):
    """
    Base class for generic SQL operators that need a database hook.

    Subclasses obtain the hook through ``.get_db_hook()``. By default the hook
    is resolved from the connection identified by ``conn_id``; override
    ``.get_db_hook()`` to customize how the hook is obtained.

    :param conn_id: the connection ID used to connect to the database.
    :type conn_id: str
    :param database: name of the database; when set, it is assigned to the
        hook's ``schema`` attribute.
    :type database: str
    """

    @apply_defaults
    def __init__(self, *, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.database = database
        self.conn_id = conn_id

    @cached_property
    def _hook(self):
        """
        Resolve the hook for ``conn_id`` and cache it on the instance.

        :raises AirflowException: if the hook returned by the connection is
            not an instance of ``DbApiHook``.
        """
        self.log.debug("Get connection for %s", self.conn_id)
        hook = BaseHook.get_connection(self.conn_id).get_hook()

        if isinstance(hook, DbApiHook):
            # Only override the hook's schema when a database was explicitly given.
            if self.database:
                hook.schema = self.database
            return hook

        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )

    def get_db_hook(self) -> DbApiHook:
        """
        Return the database hook for this operator's connection.

        :return: the database hook object.
        :rtype: DbApiHook
        """
        return self._hook


class SQLCheckOperator(BaseSQLOperator):
"""
Performs checks against a db. The ``SQLCheckOperator`` expects
a sql query that will return a single row. Each value on that
Expand Down Expand Up @@ -68,6 +100,10 @@ class SQLCheckOperator(BaseOperator):
:param sql: the sql to be executed. (templated)
:type sql: str
:param conn_id: the connection ID used to connect to the database.
:type conn_id: str
:param database: name of database which overwrite the defined one in connection
:type database: str
"""

template_fields: Iterable[str] = ("sql",)
Expand All @@ -78,9 +114,10 @@ class SQLCheckOperator(BaseOperator):
ui_color = "#fff7e6"

@apply_defaults
def __init__(self, *, sql: str, conn_id: Optional[str] = None, **kwargs) -> None:
super().__init__(**kwargs)
self.conn_id = conn_id
def __init__(
self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
self.sql = sql

def execute(self, context=None):
Expand All @@ -95,15 +132,6 @@ def execute(self, context=None):

self.log.info("Success.")

def get_db_hook(self):
"""
Get the database hook for the connection.
:return: the database hook object.
:rtype: DbApiHook
"""
return BaseHook.get_hook(conn_id=self.conn_id)


def _convert_to_float_if_possible(s):
"""
Expand All @@ -120,16 +148,16 @@ def _convert_to_float_if_possible(s):
return ret


class SQLValueCheckOperator(BaseOperator):
class SQLValueCheckOperator(BaseSQLOperator):
"""
Performs a simple value check using sql code.
Note that this is an abstract class and get_db_hook
needs to be defined. Whereas a get_db_hook is hook that gets a
single record from an external source.
:param sql: the sql to be executed. (templated)
:type sql: str
:param conn_id: the connection ID used to connect to the database.
:type conn_id: str
:param database: name of database which overwrite the defined one in connection
:type database: str
"""

__mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
Expand All @@ -151,11 +179,11 @@ def __init__(
pass_value: Any,
tolerance: Any = None,
conn_id: Optional[str] = None,
database: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(conn_id=conn_id, database=database, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.pass_value = str(pass_value)
tol = _convert_to_float_if_possible(tolerance)
self.tol = tol if isinstance(tol, float) else None
Expand Down Expand Up @@ -212,27 +240,18 @@ def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):

return [record == numeric_pass_value_conv for record in numeric_records]

def get_db_hook(self):
"""
Get the database hook for the connection.

:return: the database hook object.
:rtype: DbApiHook
"""
return BaseHook.get_hook(conn_id=self.conn_id)


class SQLIntervalCheckOperator(BaseOperator):
class SQLIntervalCheckOperator(BaseSQLOperator):
"""
Checks that the values of metrics given as SQL expressions are within
a certain tolerance of the ones from days_back before.
Note that this is an abstract class and get_db_hook
needs to be defined. Whereas a get_db_hook is hook that gets a
single record from an external source.
:param table: the table name
:type table: str
:param conn_id: the connection ID used to connect to the database.
:type conn_id: str
:param database: name of database which overwrite the defined one in connection
:type database: str
:param days_back: number of days between ds and the ds we want to check
against. Defaults to 7 days
:type days_back: int
Expand Down Expand Up @@ -275,9 +294,10 @@ def __init__(
ratio_formula: Optional[str] = "max_over_min",
ignore_zero: bool = True,
conn_id: Optional[str] = None,
database: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(conn_id=conn_id, database=database, **kwargs)
if ratio_formula not in self.ratio_formulas:
msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"

Expand All @@ -291,7 +311,6 @@ def __init__(
self.metrics_sorted = sorted(metrics_thresholds.keys())
self.date_filter_column = date_filter_column
self.days_back = -abs(days_back)
self.conn_id = conn_id
sqlexp = ", ".join(self.metrics_sorted)
sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="

Expand Down Expand Up @@ -362,28 +381,19 @@ def execute(self, context=None):

self.log.info("All tests have passed")

def get_db_hook(self):
"""
Get the database hook for the connection.

:return: the database hook object.
:rtype: DbApiHook
"""
return BaseHook.get_hook(conn_id=self.conn_id)


class SQLThresholdCheckOperator(BaseOperator):
class SQLThresholdCheckOperator(BaseSQLOperator):
"""
Performs a value check using sql code against a minimum threshold
and a maximum threshold. Thresholds can be in the form of a numeric
value OR a sql statement that results a numeric.
Note that this is an abstract class and get_db_hook
needs to be defined. Whereas a get_db_hook is hook that gets a
single record from an external source.
:param sql: the sql to be executed. (templated)
:type sql: str
:param conn_id: the connection ID used to connect to the database.
:type conn_id: str
:param database: name of database which overwrite the defined one in connection
:type database: str
:param min_threshold: numerical value or min threshold sql to be executed (templated)
:type min_threshold: numeric or str
:param max_threshold: numerical value or max threshold sql to be executed (templated)
Expand All @@ -404,11 +414,11 @@ def __init__(
min_threshold: Any,
max_threshold: Any,
conn_id: Optional[str] = None,
database: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(conn_id=conn_id, database=database, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.min_threshold = _convert_to_float_if_possible(min_threshold)
self.max_threshold = _convert_to_float_if_possible(max_threshold)

Expand Down Expand Up @@ -456,12 +466,8 @@ def push(self, meta_data):
info = "\n".join([f"""{key}: {item}""" for key, item in meta_data.items()])
self.log.info("Log from %s:\n%s", self.dag_id, info)

def get_db_hook(self):
"""Returns DB hook"""
return BaseHook.get_hook(conn_id=self.conn_id)


class BranchSQLOperator(BaseOperator, SkipMixin):
class BranchSQLOperator(BaseSQLOperator, SkipMixin):
"""
Executes sql code in a specific database
Expand All @@ -474,9 +480,10 @@ class BranchSQLOperator(BaseOperator, SkipMixin):
:type follow_task_ids_if_true: str or list
:param follow_task_ids_if_false: task id or task ids to follow if query return true
:type follow_task_ids_if_false: str or list
:param conn_id: reference to a specific database
:param conn_id: the connection ID used to connect to the database.
:type conn_id: str
:param database: name of database which overwrite defined one in connection
:param database: name of database which overwrite the defined one in connection
:type database: str
:param parameters: (optional) the parameters to render the SQL query with.
:type parameters: mapping or iterable
"""
Expand All @@ -498,57 +505,20 @@ def __init__(
parameters: Optional[Union[Mapping, Iterable]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.conn_id = conn_id
super().__init__(conn_id=conn_id, database=database, **kwargs)
self.sql = sql
self.parameters = parameters
self.follow_task_ids_if_true = follow_task_ids_if_true
self.follow_task_ids_if_false = follow_task_ids_if_false
self.database = database
self._hook = None

def _get_hook(self):
self.log.debug("Get connection for %s", self.conn_id)
conn = BaseHook.get_connection(self.conn_id)

if conn.conn_type not in ALLOWED_CONN_TYPE:
raise AirflowException(
"The connection type is not supported by BranchSQLOperator.\
Supported connection types: {}".format(
list(ALLOWED_CONN_TYPE)
)
)

if not self._hook:
self._hook = conn.get_hook()
if self.database:
self._hook.schema = self.database

return self._hook

def execute(self, context: Dict):
# get supported hook
self._hook = self._get_hook()

if self._hook is None:
raise AirflowException(f"Failed to establish connection to '{self.conn_id}'")

if self.sql is None:
raise AirflowException("Expected 'sql' parameter is missing.")

if self.follow_task_ids_if_true is None:
raise AirflowException("Expected 'follow_task_ids_if_true' parameter is missing.")

if self.follow_task_ids_if_false is None:
raise AirflowException("Expected 'follow_task_ids_if_false' parameter is missing.")

self.log.info(
"Executing: %s (with parameters %s) with connection: %s",
self.sql,
self.parameters,
self._hook,
self.conn_id,
)
record = self._hook.get_first(self.sql, self.parameters)
record = self.get_db_hook().get_first(self.sql, self.parameters)
if not record:
raise AirflowException(
"No rows returned from sql query. Operator expected True or False return value."
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers/apache/druid/hooks/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class DruidDbApiHook(DbApiHook):

conn_name_attr = 'druid_broker_conn_id'
default_conn_name = 'druid_broker_default'
conn_type = 'druid'
hook_name = 'Druid'
supports_autocommit = False

def get_conn(self) -> connect:
Expand Down

0 comments on commit 33214d9

Please sign in to comment.