GCS to BigQuery Transfer Operator with Labels and Description parameter (#14881)

This adds the following optional parameters to the GCS to BQ transfer operator:

- labels
- description

Users can set these to update the labels and/or description of the destination BigQuery table.
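For illustration, a minimal sketch of a DAG that sets both new parameters; the bucket, object paths, table name, and task id below are hypothetical:

from airflow import DAG
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
from airflow.utils.dates import days_ago

with DAG("gcs_to_bq_labels_example", start_date=days_ago(1), schedule_interval=None) as dag:
    load_sales = GCSToBigQueryOperator(
        task_id="gcs_to_bq_with_labels",
        bucket="my-bucket",  # hypothetical bucket
        source_objects=["data/sales/*.csv"],  # hypothetical source objects
        destination_project_dataset_table="my_dataset.sales",  # hypothetical table
        write_disposition="WRITE_TRUNCATE",
        labels={"team": "analytics", "env": "prod"},  # new optional parameter
        description="Daily sales data loaded from GCS",  # new optional parameter
    )
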
fhoda committed Mar 22, 2021
1 parent 5d96eb0 commit 72ea841
Showing 4 changed files with 243 additions and 1 deletion.
23 changes: 22 additions & 1 deletion airflow/providers/google/cloud/hooks/bigquery.py
@@ -531,6 +531,7 @@ def create_external_table( # pylint: disable=too-many-locals,too-many-arguments
encoding: str = "UTF-8",
src_fmt_configs: Optional[Dict] = None,
labels: Optional[Dict] = None,
description: Optional[str] = None,
encryption_configuration: Optional[Dict] = None,
location: Optional[str] = None,
project_id: Optional[str] = None,
@@ -599,8 +600,10 @@ def create_external_table( # pylint: disable=too-many-locals,too-many-arguments
:type encoding: str
:param src_fmt_configs: configure optional fields specific to the source format
:type src_fmt_configs: dict
:param labels: a dictionary containing labels for the table, passed to BigQuery
:param labels: A dictionary containing labels for the BigQuery table.
:type labels: dict
:param description: A string containing the description for the BigQuery table.
:type description: str
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).
**Example**: ::
@@ -669,6 +672,9 @@ def create_external_table( # pylint: disable=too-many-locals,too-many-arguments
if labels:
table.labels = labels

if description:
table.description = description

if encryption_configuration:
table.encryption_configuration = EncryptionConfiguration.from_api_repr(encryption_configuration)

@@ -1560,6 +1566,8 @@ def run_load( # pylint: disable=too-many-locals,too-many-arguments,invalid-name
cluster_fields: Optional[List] = None,
autodetect: bool = False,
encryption_configuration: Optional[Dict] = None,
labels: Optional[Dict] = None,
description: Optional[str] = None,
) -> str:
"""
Executes a BigQuery load command to load data from Google Cloud Storage
@@ -1642,6 +1650,10 @@ def run_load( # pylint: disable=too-many-locals,too-many-arguments,invalid-name
"kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key"
}
:type encryption_configuration: dict
:param labels: A dictionary containing labels for the BigQuery table.
:type labels: dict
:param description: A string containing the description for the BigQuery table.
:type description: str
"""
warnings.warn(
"This method is deprecated. Please use `BigQueryHook.insert_job` method.", DeprecationWarning
@@ -1742,6 +1754,15 @@ def run_load( # pylint: disable=too-many-locals,too-many-arguments,invalid-name
if encryption_configuration:
configuration["load"]["destinationEncryptionConfiguration"] = encryption_configuration

if labels or description:
configuration['load'].update({'destinationTableProperties': {}})

if labels:
configuration['load']['destinationTableProperties']['labels'] = labels

if description:
configuration['load']['destinationTableProperties']['description'] = description

src_fmt_to_configs_mapping = {
'CSV': [
'allowJaggedRows',
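Taken together, the run_load changes above nest a destinationTableProperties object inside the load configuration that is handed to insert_job. A standalone sketch of that branch, with illustrative values for labels, description, and the base configuration:

labels = {"team": "analytics"}
description = "Daily sales data"
configuration = {"load": {"writeDisposition": "WRITE_TRUNCATE"}}  # illustrative base config

# Mirrors the new branch in run_load: create the container only when needed,
# then fill in whichever of the two properties was supplied.
if labels or description:
    configuration["load"].update({"destinationTableProperties": {}})
    if labels:
        configuration["load"]["destinationTableProperties"]["labels"] = labels
    if description:
        configuration["load"]["destinationTableProperties"]["description"] = description
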
13 changes: 13 additions & 0 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -157,6 +157,10 @@ class GCSToBigQueryOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param labels: [Optional] Labels for the BigQuery table.
:type labels: dict
:param description: [Optional] Description for the BigQuery table.
:type description: str
"""

template_fields = (
@@ -204,6 +208,8 @@ def __init__(
encryption_configuration=None,
location=None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
labels=None,
description=None,
**kwargs,
):

@@ -249,6 +255,9 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain

self.labels = labels
self.description = description

def execute(self, context):
bq_hook = BigQueryHook(
bigquery_conn_id=self.bigquery_conn_id,
@@ -300,6 +309,8 @@ def execute(self, context):
encoding=self.encoding,
src_fmt_configs=self.src_fmt_configs,
encryption_configuration=self.encryption_configuration,
labels=self.labels,
description=self.description,
)
else:
cursor.run_load(
@@ -323,6 +334,8 @@ def execute(self, context):
time_partitioning=self.time_partitioning,
cluster_fields=self.cluster_fields,
encryption_configuration=self.encryption_configuration,
labels=self.labels,
description=self.description,
)

if cursor.use_legacy_sql:
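With external_table=True, execute() forwards the same two arguments to create_external_table instead of run_load. A usage sketch under that assumption; the bucket, objects, and table name are hypothetical:

from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

create_events_table = GCSToBigQueryOperator(
    task_id="gcs_to_bq_external_with_description",
    bucket="my-bucket",  # hypothetical bucket
    source_objects=["data/events/*.json"],  # hypothetical source objects
    source_format="NEWLINE_DELIMITED_JSON",
    destination_project_dataset_table="my_dataset.events",  # hypothetical table
    external_table=True,  # routes through create_external_table instead of run_load
    labels={"env": "dev"},
    description="External table over GCS JSON exports",
)
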
58 changes: 58 additions & 0 deletions tests/providers/google/cloud/hooks/test_bigquery.py
@@ -1779,3 +1779,61 @@ def test_deprecation_warning(self, func_name, mock_bq_hook):

mocked_func.assert_called_once_with(*args, **kwargs)
assert re.search(f".*{new_path}.*", func.__doc__)


class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_load_labels(self, mock_insert):

labels = {'label1': 'test1', 'label2': 'test2'}
self.hook.run_load(
destination_project_dataset_table='my_dataset.my_table',
schema_fields=[],
source_uris=[],
labels=labels,
)

_, kwargs = mock_insert.call_args
assert kwargs["configuration"]['load']['destinationTableProperties']['labels'] is labels

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_load_description(self, mock_insert):

description = "Test Description"
self.hook.run_load(
destination_project_dataset_table='my_dataset.my_table',
schema_fields=[],
source_uris=[],
description=description,
)

_, kwargs = mock_insert.call_args
assert kwargs["configuration"]['load']['destinationTableProperties']['description'] is description

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table")
def test_create_external_table_labels(self, mock_create):

labels = {'label1': 'test1', 'label2': 'test2'}
self.hook.create_external_table(
external_project_dataset_table='my_dataset.my_table',
schema_fields=[],
source_uris=[],
labels=labels,
)

_, kwargs = mock_create.call_args
self.assertDictEqual(kwargs['table_resource']['labels'], labels)

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table")
def test_create_external_table_description(self, mock_create):

description = "Test Description"
self.hook.create_external_table(
external_project_dataset_table='my_dataset.my_table',
schema_fields=[],
source_uris=[],
description=description,
)

_, kwargs = mock_create.call_args
assert kwargs['table_resource']['description'] is description
150 changes: 150 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -26,6 +26,8 @@
TEST_BUCKET = 'test-bucket'
MAX_ID_KEY = 'id'
TEST_SOURCE_OBJECTS = ['test/objects/*']
LABELS = {'k1': 'v1'}
DESCRIPTION = "Test Description"


class TestGoogleCloudStorageToBigQueryOperator(unittest.TestCase):
@@ -66,3 +68,151 @@ def test_execute_explicit_project(self, bq_hook):
bq_hook.return_value.get_conn.return_value.cursor.return_value.execute.assert_called_once_with(
"SELECT MAX(id) FROM `test-project.dataset.table`"
)

@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
def test_labels(self, bq_hook):

operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
labels=LABELS,
)

operator.execute(None)

bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
destination_project_dataset_table=mock.ANY,
schema_fields=mock.ANY,
source_uris=mock.ANY,
source_format=mock.ANY,
autodetect=mock.ANY,
create_disposition=mock.ANY,
skip_leading_rows=mock.ANY,
write_disposition=mock.ANY,
field_delimiter=mock.ANY,
max_bad_records=mock.ANY,
quote_character=mock.ANY,
ignore_unknown_values=mock.ANY,
allow_quoted_newlines=mock.ANY,
allow_jagged_rows=mock.ANY,
encoding=mock.ANY,
schema_update_options=mock.ANY,
src_fmt_configs=mock.ANY,
time_partitioning=mock.ANY,
cluster_fields=mock.ANY,
encryption_configuration=mock.ANY,
labels=LABELS,
description=mock.ANY,
)

@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
def test_description(self, bq_hook):

operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
description=DESCRIPTION,
)

operator.execute(None)

bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
destination_project_dataset_table=mock.ANY,
schema_fields=mock.ANY,
source_uris=mock.ANY,
source_format=mock.ANY,
autodetect=mock.ANY,
create_disposition=mock.ANY,
skip_leading_rows=mock.ANY,
write_disposition=mock.ANY,
field_delimiter=mock.ANY,
max_bad_records=mock.ANY,
quote_character=mock.ANY,
ignore_unknown_values=mock.ANY,
allow_quoted_newlines=mock.ANY,
allow_jagged_rows=mock.ANY,
encoding=mock.ANY,
schema_update_options=mock.ANY,
src_fmt_configs=mock.ANY,
time_partitioning=mock.ANY,
cluster_fields=mock.ANY,
encryption_configuration=mock.ANY,
labels=mock.ANY,
description=DESCRIPTION,
)

@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
def test_labels_external_table(self, bq_hook):

operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
labels=LABELS,
external_table=True,
)

operator.execute(None)
# fmt: off
bq_hook.return_value.get_conn.return_value.cursor.return_value.create_external_table. \
assert_called_once_with(
external_project_dataset_table=mock.ANY,
schema_fields=mock.ANY,
source_uris=mock.ANY,
source_format=mock.ANY,
compression=mock.ANY,
skip_leading_rows=mock.ANY,
field_delimiter=mock.ANY,
max_bad_records=mock.ANY,
quote_character=mock.ANY,
ignore_unknown_values=mock.ANY,
allow_quoted_newlines=mock.ANY,
allow_jagged_rows=mock.ANY,
encoding=mock.ANY,
src_fmt_configs=mock.ANY,
encryption_configuration=mock.ANY,
labels=LABELS,
description=mock.ANY,
)
# fmt: on

@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
def test_description_external_table(self, bq_hook):

operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
description=DESCRIPTION,
external_table=True,
)

operator.execute(None)
# fmt: off
bq_hook.return_value.get_conn.return_value.cursor.return_value.create_external_table. \
assert_called_once_with(
external_project_dataset_table=mock.ANY,
schema_fields=mock.ANY,
source_uris=mock.ANY,
source_format=mock.ANY,
compression=mock.ANY,
skip_leading_rows=mock.ANY,
field_delimiter=mock.ANY,
max_bad_records=mock.ANY,
quote_character=mock.ANY,
ignore_unknown_values=mock.ANY,
allow_quoted_newlines=mock.ANY,
allow_jagged_rows=mock.ANY,
encoding=mock.ANY,
src_fmt_configs=mock.ANY,
encryption_configuration=mock.ANY,
labels=mock.ANY,
description=DESCRIPTION,
)
# fmt: on
