Skip to content

Commit

Permalink
Refactor BigQuery check operators (#8813)
Browse files Browse the repository at this point in the history
* Refactor BigQuery check operators

This commit applies some code formatting to existing BigQuery
check operators. It also adds location parameter to
BigQueryIntervalCheckOperator and BigQueryValueCheckOperator.

* fixup! Refactor BigQuery check operators
  • Loading branch information
turbaszek committed May 12, 2020
1 parent 7533378 commit 1d12c34
Showing 1 changed file with 79 additions and 50 deletions.
129 changes: 79 additions & 50 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@

BIGQUERY_JOB_DETAILS_LINK_FMT = 'https://console.cloud.google.com/bigquery?j={job_id}'

_DEPRECATION_MSG = "The bigquery_conn_id parameter has been deprecated. " \
"You should pass the gcp_conn_id parameter."


class BigQueryUIColors(enum.Enum):
"""Hex colors for BigQuery operators"""
Expand Down Expand Up @@ -120,8 +123,7 @@ class BigQueryCheckOperator(CheckOperator):
:param use_legacy_sql: Whether to use legacy SQL (true)
or standard SQL (false).
:type use_legacy_sql: bool
:param location: The geographic location of the job. Required except for
US and EU. See details at
:param location: The geographic location of the job. See details at:
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
:type location: str
"""
Expand All @@ -131,29 +133,32 @@ class BigQueryCheckOperator(CheckOperator):
ui_color = BigQueryUIColors.CHECK.value

@apply_defaults
def __init__(self,
sql: str,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
use_legacy_sql: bool = True,
location=None,
*args, **kwargs) -> None:
def __init__(
self,
sql: str,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
use_legacy_sql: bool = True,
location: Optional[str] = None,
*args,
**kwargs,
) -> None:
super().__init__(sql=sql, *args, **kwargs)
if bigquery_conn_id:
warnings.warn(
"The bigquery_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
gcp_conn_id = bigquery_conn_id # type: ignore

self.gcp_conn_id = gcp_conn_id
self.sql = sql
self.use_legacy_sql = use_legacy_sql
self.location = location

def get_db_hook(self):
return BigQueryHook(bigquery_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql,
location=self.location)
def get_db_hook(self) -> BigQueryHook:
return BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql,
location=self.location
)


class BigQueryValueCheckOperator(ValueCheckOperator):
Expand All @@ -170,36 +175,49 @@ class BigQueryValueCheckOperator(ValueCheckOperator):
:param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform.
This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:type bigquery_conn_id: str
:param location: The geographic location of the job. See details at:
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
:type location: str
"""

template_fields = ('sql', 'gcp_conn_id', 'pass_value',)
template_ext = ('.sql',)
ui_color = BigQueryUIColors.CHECK.value

@apply_defaults
def __init__(self, sql: str,
pass_value: Any,
tolerance: Any = None,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
use_legacy_sql: bool = True,
*args, **kwargs) -> None:
def __init__(
self,
sql: str,
pass_value: Any,
tolerance: Any = None,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
use_legacy_sql: bool = True,
location: Optional[str] = None,
*args,
**kwargs,
) -> None:
super().__init__(
sql=sql, pass_value=pass_value, tolerance=tolerance,
*args, **kwargs)
sql=sql,
pass_value=pass_value,
tolerance=tolerance,
*args, **kwargs
)

if bigquery_conn_id:
warnings.warn(
"The bigquery_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
gcp_conn_id = bigquery_conn_id

self.location = location
self.gcp_conn_id = gcp_conn_id
self.use_legacy_sql = use_legacy_sql

def get_db_hook(self):
return BigQueryHook(bigquery_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql)
def get_db_hook(self) -> BigQueryHook:
return BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql,
location=self.location
)


class BigQueryIntervalCheckOperator(IntervalCheckOperator):
Expand Down Expand Up @@ -229,39 +247,50 @@ class BigQueryIntervalCheckOperator(IntervalCheckOperator):
:param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform.
This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:type bigquery_conn_id: str
:param location: The geographic location of the job. See details at:
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
:type location: str
"""

template_fields = ('table', 'gcp_conn_id', 'sql1', 'sql2')
ui_color = BigQueryUIColors.CHECK.value

@apply_defaults
def __init__(self,
table: str,
metrics_thresholds: dict,
date_filter_column: str = 'ds',
days_back: SupportsAbs[int] = -7,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
use_legacy_sql: bool = True,
*args,
**kwargs) -> None:
def __init__(
self,
table: str,
metrics_thresholds: dict,
date_filter_column: str = 'ds',
days_back: SupportsAbs[int] = -7,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
use_legacy_sql: bool = True,
location: Optional[str] = None,
*args,
**kwargs,
) -> None:
super().__init__(
table=table, metrics_thresholds=metrics_thresholds,
date_filter_column=date_filter_column, days_back=days_back,
*args, **kwargs)
table=table,
metrics_thresholds=metrics_thresholds,
date_filter_column=date_filter_column,
days_back=days_back,
*args, **kwargs
)

if bigquery_conn_id:
warnings.warn(
"The bigquery_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
gcp_conn_id = bigquery_conn_id

self.gcp_conn_id = gcp_conn_id
self.use_legacy_sql = use_legacy_sql
self.location = location

def get_db_hook(self):
return BigQueryHook(bigquery_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql)
def get_db_hook(self) -> BigQueryHook:
return BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql,
location=self.location,
)


class BigQueryGetDataOperator(BaseOperator):
Expand Down

0 comments on commit 1d12c34

Please sign in to comment.