Skip to content

Commit

Permalink
Add job labels to bigquery check operators. (#14685)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarp committed Mar 11, 2021
1 parent 60373eb commit 943baff
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
bigquery_conn_id: Optional[str] = None,
api_resource_configs: Optional[Dict] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
labels: Optional[Dict] = None,
) -> None:
# To preserve backward compatibility
# TODO: remove one day
Expand All @@ -101,6 +102,7 @@ def __init__(
self.location = location
self.running_job_id = None # type: Optional[str]
self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict
self.labels = labels

def get_conn(self) -> "BigQueryConnection":
"""Returns a BigQuery PEP 249 connection object."""
Expand Down Expand Up @@ -2060,6 +2062,7 @@ def run_query(
if not self.project_id:
raise ValueError("The project_id should be set")

labels = labels or self.labels
schema_update_options = list(schema_update_options or [])

if time_partitioning is None:
Expand Down Expand Up @@ -2258,6 +2261,7 @@ def __init__(
api_resource_configs: Optional[Dict] = None,
location: Optional[str] = None,
num_retries: int = 5,
labels: Optional[Dict] = None,
) -> None:

super().__init__()
Expand All @@ -2270,6 +2274,7 @@ def __init__(
self.running_job_id = None # type: Optional[str]
self.location = location
self.num_retries = num_retries
self.labels = labels
self.hook = hook

def create_empty_table(self, *args, **kwargs) -> None:
Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def get_db_hook(self) -> BigQueryHook:
use_legacy_sql=self.use_legacy_sql,
location=self.location,
impersonation_chain=self.impersonation_chain,
labels=self.labels,
)


Expand Down Expand Up @@ -152,12 +153,15 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
"""

template_fields = (
'sql',
'gcp_conn_id',
'impersonation_chain',
'labels',
)
template_ext = ('.sql',)
ui_color = BigQueryUIColors.CHECK.value
Expand All @@ -172,6 +176,7 @@ def __init__(
use_legacy_sql: bool = True,
location: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
labels: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(sql=sql, **kwargs)
Expand All @@ -184,6 +189,7 @@ def __init__(
self.use_legacy_sql = use_legacy_sql
self.location = location
self.impersonation_chain = impersonation_chain
self.labels = labels


class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
Expand Down Expand Up @@ -216,13 +222,16 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
"""

template_fields = (
'sql',
'gcp_conn_id',
'pass_value',
'impersonation_chain',
'labels',
)
template_ext = ('.sql',)
ui_color = BigQueryUIColors.CHECK.value
Expand All @@ -239,6 +248,7 @@ def __init__(
use_legacy_sql: bool = True,
location: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
labels: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs)
Expand All @@ -251,6 +261,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.use_legacy_sql = use_legacy_sql
self.impersonation_chain = impersonation_chain
self.labels = labels


class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator):
Expand Down Expand Up @@ -296,6 +307,8 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
"""

template_fields = (
Expand All @@ -304,6 +317,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
'sql1',
'sql2',
'impersonation_chain',
'labels',
)
ui_color = BigQueryUIColors.CHECK.value

Expand All @@ -320,6 +334,7 @@ def __init__(
use_legacy_sql: bool = True,
location: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
labels: Optional[Dict] = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -338,6 +353,7 @@ def __init__(
self.use_legacy_sql = use_legacy_sql
self.location = location
self.impersonation_chain = impersonation_chain
self.labels = labels


class BigQueryGetDataOperator(BaseOperator):
Expand Down

0 comments on commit 943baff

Please sign in to comment.