Suppress hook warnings from the Bigquery transfers (#20119)
kazanzhy committed Feb 28, 2022
1 parent e782b37 commit 0c55ca2
Showing 4 changed files with 130 additions and 139 deletions.
21 changes: 11 additions & 10 deletions airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -127,13 +127,14 @@ def execute(self, context: 'Context') -> None:
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
-        conn = hook.get_conn()
-        cursor = conn.cursor()
-        cursor.run_copy(
-            source_project_dataset_tables=self.source_project_dataset_tables,
-            destination_project_dataset_table=self.destination_project_dataset_table,
-            write_disposition=self.write_disposition,
-            create_disposition=self.create_disposition,
-            labels=self.labels,
-            encryption_configuration=self.encryption_configuration,
-        )
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", DeprecationWarning)
+            hook.run_copy(
+                source_project_dataset_tables=self.source_project_dataset_tables,
+                destination_project_dataset_table=self.destination_project_dataset_table,
+                write_disposition=self.write_disposition,
+                create_disposition=self.create_disposition,
+                labels=self.labels,
+                encryption_configuration=self.encryption_configuration,
+            )
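The hunk above is plain stdlib `warnings` machinery, nothing Airflow-specific. A minimal, self-contained sketch of the suppression pattern follows; `legacy_copy` is a hypothetical stand-in for the deprecated hook method and is not part of the provider:

```python
import warnings


def legacy_copy(source: str, dest: str) -> None:
    # Hypothetical stand-in for a hook method that warns on every call.
    warnings.warn(
        "legacy_copy is deprecated; call the hook method instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    print(f"copied {source} -> {dest}")


def execute() -> None:
    # catch_warnings() snapshots the global filter state; simplefilter()
    # then mutes only DeprecationWarning until the block exits, at which
    # point the caller's filters are restored untouched.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        legacy_copy("proj.ds.src", "proj.ds.dst")


if __name__ == "__main__":
    # Even with DeprecationWarning escalated to an error, execute() stays quiet.
    warnings.simplefilter("error", DeprecationWarning)
    execute()
```

The point of the scoped filter is that the operator keeps calling the deprecated hook method without spamming task logs, while user code that calls the method directly still sees the warning.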
133 changes: 72 additions & 61 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -18,6 +18,7 @@
 """This module contains a Google Cloud Storage to BigQuery operator."""

 import json
+import warnings
 from typing import TYPE_CHECKING, Optional, Sequence, Union

 from airflow.models import BaseOperator
@@ -163,8 +164,9 @@ def __init__(
         allow_jagged_rows=False,
         encoding="UTF-8",
         max_id_key=None,
-        bigquery_conn_id='google_cloud_default',
-        google_cloud_storage_conn_id='google_cloud_default',
+        gcp_conn_id='google_cloud_default',
+        bigquery_conn_id=None,
+        google_cloud_storage_conn_id=None,
         delegate_to=None,
         schema_update_options=(),
         src_fmt_configs=None,
@@ -179,6 +181,15 @@
         description=None,
         **kwargs,
     ):
+        # To preserve backward compatibility. Remove one day
+        if bigquery_conn_id or google_cloud_storage_conn_id:
+            warnings.warn(
+                "The bigquery_conn_id and google_cloud_storage_conn_id parameters have been deprecated. "
+                "You should pass only gcp_conn_id parameter. "
+                "Will be used bigquery_conn_id or google_cloud_storage_conn_id if gcp_conn_id not passed.",
+                DeprecationWarning,
+                stacklevel=2,
+            )

         super().__init__(**kwargs)
@@ -209,8 +220,7 @@ def __init__(
         self.encoding = encoding

         self.max_id_key = max_id_key
-        self.bigquery_conn_id = bigquery_conn_id
-        self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
+        self.gcp_conn_id = gcp_conn_id or bigquery_conn_id or google_cloud_storage_conn_id
         self.delegate_to = delegate_to

         self.schema_update_options = schema_update_options
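Taken together, the two hunks above implement a standard kwarg-deprecation dance: accept the legacy ids, warn once at construction time, and coalesce everything into `gcp_conn_id`. A sketch of that dance in isolation, with a hypothetical class name and a trimmed message; only the warning and the coalescing assignment mirror the committed diff:

```python
import warnings
from typing import Optional


class TransferOperatorSketch:
    # Illustrative stand-in for the operator, not the real class.
    def __init__(
        self,
        gcp_conn_id: str = 'google_cloud_default',
        bigquery_conn_id: Optional[str] = None,
        google_cloud_storage_conn_id: Optional[str] = None,
    ) -> None:
        if bigquery_conn_id or google_cloud_storage_conn_id:
            warnings.warn(
                "bigquery_conn_id and google_cloud_storage_conn_id are "
                "deprecated; pass gcp_conn_id instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
        # First truthy value wins, exactly as in the diff.
        self.gcp_conn_id = gcp_conn_id or bigquery_conn_id or google_cloud_storage_conn_id


op = TransferOperatorSketch(bigquery_conn_id='bq_legacy')  # emits DeprecationWarning
print(op.gcp_conn_id)
```

One subtlety of the committed coalescing: because `gcp_conn_id` defaults to the truthy `'google_cloud_default'`, a legacy id only actually takes effect when the caller also passes a falsy `gcp_conn_id`; with the defaults, the example above still resolves to `'google_cloud_default'`.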
@@ -227,7 +237,7 @@

     def execute(self, context: 'Context'):
         bq_hook = BigQueryHook(
-            bigquery_conn_id=self.bigquery_conn_id,
+            gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             location=self.location,
             impersonation_chain=self.impersonation_chain,
@@ -236,7 +246,7 @@ def execute(self, context: 'Context'):
         if not self.schema_fields:
             if self.schema_object and self.source_format != 'DATASTORE_BACKUP':
                 gcs_hook = GCSHook(
-                    gcp_conn_id=self.google_cloud_storage_conn_id,
+                    gcp_conn_id=self.gcp_conn_id,
                     delegate_to=self.delegate_to,
                     impersonation_chain=self.impersonation_chain,
                 )
@@ -247,72 +257,73 @@ def execute(self, context: 'Context'):
                 schema_fields = json.loads(blob.decode("utf-8"))
             else:
                 schema_fields = None

         else:
             schema_fields = self.schema_fields

         self.source_objects = (
             self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
         )
         source_uris = [f'gs://{self.bucket}/{source_object}' for source_object in self.source_objects]
-        conn = bq_hook.get_conn()
-        cursor = conn.cursor()

         if self.external_table:
-            cursor.create_external_table(
-                external_project_dataset_table=self.destination_project_dataset_table,
-                schema_fields=schema_fields,
-                source_uris=source_uris,
-                source_format=self.source_format,
-                compression=self.compression,
-                skip_leading_rows=self.skip_leading_rows,
-                field_delimiter=self.field_delimiter,
-                max_bad_records=self.max_bad_records,
-                quote_character=self.quote_character,
-                ignore_unknown_values=self.ignore_unknown_values,
-                allow_quoted_newlines=self.allow_quoted_newlines,
-                allow_jagged_rows=self.allow_jagged_rows,
-                encoding=self.encoding,
-                src_fmt_configs=self.src_fmt_configs,
-                encryption_configuration=self.encryption_configuration,
-                labels=self.labels,
-                description=self.description,
-            )
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", DeprecationWarning)
+                bq_hook.create_external_table(
+                    external_project_dataset_table=self.destination_project_dataset_table,
+                    schema_fields=schema_fields,
+                    source_uris=source_uris,
+                    source_format=self.source_format,
+                    compression=self.compression,
+                    skip_leading_rows=self.skip_leading_rows,
+                    field_delimiter=self.field_delimiter,
+                    max_bad_records=self.max_bad_records,
+                    quote_character=self.quote_character,
+                    ignore_unknown_values=self.ignore_unknown_values,
+                    allow_quoted_newlines=self.allow_quoted_newlines,
+                    allow_jagged_rows=self.allow_jagged_rows,
+                    encoding=self.encoding,
+                    src_fmt_configs=self.src_fmt_configs,
+                    encryption_configuration=self.encryption_configuration,
+                    labels=self.labels,
+                    description=self.description,
+                )
         else:
-            cursor.run_load(
-                destination_project_dataset_table=self.destination_project_dataset_table,
-                schema_fields=schema_fields,
-                source_uris=source_uris,
-                source_format=self.source_format,
-                autodetect=self.autodetect,
-                create_disposition=self.create_disposition,
-                skip_leading_rows=self.skip_leading_rows,
-                write_disposition=self.write_disposition,
-                field_delimiter=self.field_delimiter,
-                max_bad_records=self.max_bad_records,
-                quote_character=self.quote_character,
-                ignore_unknown_values=self.ignore_unknown_values,
-                allow_quoted_newlines=self.allow_quoted_newlines,
-                allow_jagged_rows=self.allow_jagged_rows,
-                encoding=self.encoding,
-                schema_update_options=self.schema_update_options,
-                src_fmt_configs=self.src_fmt_configs,
-                time_partitioning=self.time_partitioning,
-                cluster_fields=self.cluster_fields,
-                encryption_configuration=self.encryption_configuration,
-                labels=self.labels,
-                description=self.description,
-            )
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", DeprecationWarning)
+                bq_hook.run_load(
+                    destination_project_dataset_table=self.destination_project_dataset_table,
+                    schema_fields=schema_fields,
+                    source_uris=source_uris,
+                    source_format=self.source_format,
+                    autodetect=self.autodetect,
+                    create_disposition=self.create_disposition,
+                    skip_leading_rows=self.skip_leading_rows,
+                    write_disposition=self.write_disposition,
+                    field_delimiter=self.field_delimiter,
+                    max_bad_records=self.max_bad_records,
+                    quote_character=self.quote_character,
+                    ignore_unknown_values=self.ignore_unknown_values,
+                    allow_quoted_newlines=self.allow_quoted_newlines,
+                    allow_jagged_rows=self.allow_jagged_rows,
+                    encoding=self.encoding,
+                    schema_update_options=self.schema_update_options,
+                    src_fmt_configs=self.src_fmt_configs,
+                    time_partitioning=self.time_partitioning,
+                    cluster_fields=self.cluster_fields,
+                    encryption_configuration=self.encryption_configuration,
+                    labels=self.labels,
+                    description=self.description,
+                )

-        if cursor.use_legacy_sql:
-            escaped_table_name = f'[{self.destination_project_dataset_table}]'
-        else:
-            escaped_table_name = f'`{self.destination_project_dataset_table}`'
-
         if self.max_id_key:
-            select_command = f'SELECT MAX({self.max_id_key}) FROM {escaped_table_name}'
-            cursor.execute(select_command)
-            row = cursor.fetchone()
+            select_command = f'SELECT MAX({self.max_id_key}) FROM `{self.destination_project_dataset_table}`'
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", DeprecationWarning)
+                job_id = bq_hook.run_query(
+                    sql=select_command,
+                    use_legacy_sql=False,
+                )
+            row = list(bq_hook.get_job(job_id).result())
             if row:
                 max_id = row[0] if row[0] else 0
                 self.log.info(
@@ -322,4 +333,4 @@ def execute(self, context: 'Context'):
                     max_id,
                 )
             else:
-                raise RuntimeError(f"The f{select_command} returned no rows!")
+                raise RuntimeError(f"The {select_command} returned no rows!")
tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py

@@ -47,7 +47,7 @@ def test_execute(self, mock_hook):
         )

         operator.execute(None)
-        mock_hook.return_value.get_conn.return_value.cursor.return_value.run_copy.assert_called_once_with(
+        mock_hook.return_value.run_copy.assert_called_once_with(
             source_project_dataset_tables=source_project_dataset_tables,
             destination_project_dataset_table=destination_project_dataset_table,
             write_disposition=write_disposition,
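Because the operator now reaches the hook method directly, the test mock is one attribute deep instead of chained through `get_conn().cursor()`. A sketch of the updated test, assuming the operator's usual defaults (`WRITE_EMPTY`, `CREATE_IF_NEEDED`, `labels=None`, `encryption_configuration=None`) and hypothetical table names:

```python
from unittest import mock

from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import (
    BigQueryToBigQueryOperator,
)


@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryHook")
def test_execute_calls_run_copy_on_hook(mock_hook):
    operator = BigQueryToBigQueryOperator(
        task_id="copy",
        source_project_dataset_tables="proj.ds.src",
        destination_project_dataset_table="proj.ds.dst",
    )
    operator.execute(None)
    # Before this commit the assertion target was
    # mock_hook.return_value.get_conn.return_value.cursor.return_value.run_copy.
    mock_hook.return_value.run_copy.assert_called_once_with(
        source_project_dataset_tables="proj.ds.src",
        destination_project_dataset_table="proj.ds.dst",
        write_disposition="WRITE_EMPTY",
        create_disposition="CREATE_IF_NEEDED",
        labels=None,
        encryption_configuration=None,
    )
```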