Fix GCSToBigQueryOperator not respecting schema_obj (#28444)
* Fix GCSToBigQueryOperator not respecting schema_obj
vchiapaikeo committed Dec 20, 2022
1 parent 032a542 commit 9eacf60
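
For readers coming from the changelog: the bug matters when the table schema is supplied as a JSON object in GCS (schema_object, optionally in a separate schema_object_bucket) rather than as inline schema_fields. A minimal, hypothetical usage sketch follows; the DAG boilerplate, bucket names, and table name are illustrative and not part of this commit:

from airflow import DAG
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
import pendulum

with DAG(
    dag_id="example_gcs_to_bq_schema_object",
    start_date=pendulum.datetime(2022, 12, 1, tz="UTC"),
    schedule=None,
) as dag:
    # Load CSVs from a data bucket while reading the table schema from a JSON
    # object that may live in a different bucket. Before this fix, the schema
    # downloaded from schema_object was effectively discarded and the download
    # ignored schema_object_bucket, so the explicit schema never reached the job.
    load_csv = GCSToBigQueryOperator(
        task_id="load_csv",
        bucket="my-data-bucket",                      # where the CSVs live (illustrative)
        source_objects=["data/2022-12-20/part-*.csv"],
        schema_object_bucket="my-schema-bucket",      # where the schema JSON lives (illustrative)
        schema_object="schemas/my_table.json",        # JSON list of BigQuery field dicts
        destination_project_dataset_table="my-project.my_dataset.my_table",
        write_disposition="WRITE_TRUNCATE",
    )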
Showing 2 changed files with 120 additions and 4 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -313,15 +313,17 @@ def execute(self, context: Context):
             self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
         )
         source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects]
-        if not self.schema_fields:
+
+        if not self.schema_fields and self.schema_object and self.source_format != "DATASTORE_BACKUP":
             gcs_hook = GCSHook(
                 gcp_conn_id=self.gcp_conn_id,
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            if self.schema_object and self.source_format != "DATASTORE_BACKUP":
-                schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
-                self.log.info("Autodetected fields from schema object: %s", schema_fields)
+            self.schema_fields = json.loads(
+                gcs_hook.download(self.schema_object_bucket, self.schema_object).decode("utf-8")
+            )
+            self.log.info("Autodetected fields from schema object: %s", self.schema_fields)
 
         if self.external_table:
             self.log.info("Creating a new BigQuery table for storing data...")
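
Context for the change above (inferred from the diff, not stated in the commit message): the old branch parsed the downloaded schema into a local schema_fields variable and always downloaded from self.bucket, so self.schema_fields stayed empty and schema_object_bucket was never honored; the new branch assigns to self.schema_fields and downloads from self.schema_object_bucket, so the explicit schema reaches the external-table / load configuration, which is what the new tests below assert. A rough, hypothetical sketch of the observable difference in the resulting load configuration, with values borrowed from the test constants and unrelated keys omitted:

# Hypothetical illustration only, not code from this commit.
config_before_fix = {
    "load": {
        "autodetect": True,  # schema_object was downloaded but discarded
        # no "schema" key: BigQuery had to infer the schema
    }
}
config_after_fix = {
    "load": {
        "autodetect": True,
        "schema": {
            "fields": [  # parsed from gs://<schema_object_bucket>/<schema_object>
                {"name": "id", "type": "INTEGER", "mode": "NULLABLE"},
                {"name": "name", "type": "STRING", "mode": "NULLABLE"},
            ]
        },
    }
}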
114 changes: 114 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 import unittest
 from unittest import mock
 from unittest.mock import MagicMock, call
@@ -51,6 +52,8 @@
     {"name": "id", "type": "INTEGER", "mode": "NULLABLE"},
     {"name": "name", "type": "STRING", "mode": "NULLABLE"},
 ]
+SCHEMA_BUCKET = "test-schema-bucket"
+SCHEMA_OBJECT = "test/schema/schema.json"
 TEST_SOURCE_OBJECTS = ["test/objects/test.csv"]
 TEST_SOURCE_OBJECTS_AS_STRING = "test/objects/test.csv"
 LABELS = {"k1": "v1"}
@@ -675,6 +678,117 @@ def test_source_objs_as_string_without_external_table_should_execute_successfull
 
         hook.return_value.insert_job.assert_has_calls(calls)
 
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_schema_obj_external_table_should_execute_successfully(self, bq_hook, gcs_hook):
+        bq_hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        gcs_hook.return_value.download.return_value = bytes(json.dumps(SCHEMA_FIELDS), "utf-8")
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_object_bucket=SCHEMA_BUCKET,
+            schema_object=SCHEMA_OBJECT,
+            write_disposition=WRITE_DISPOSITION,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            external_table=True,
+        )
+
+        operator.execute(context=MagicMock())
+
+        bq_hook.return_value.create_empty_table.assert_called_once_with(
+            table_resource={
+                "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                "labels": None,
+                "description": None,
+                "externalDataConfiguration": {
+                    "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                    "source_format": "CSV",
+                    "maxBadRecords": 0,
+                    "autodetect": True,
+                    "compression": "NONE",
+                    "csvOptions": {
+                        "fieldDelimeter": ",",
+                        "skipLeadingRows": None,
+                        "quote": None,
+                        "allowQuotedNewlines": False,
+                        "allowJaggedRows": False,
+                    },
+                },
+                "location": None,
+                "encryptionConfiguration": None,
+                "schema": {"fields": SCHEMA_FIELDS},
+            }
+        )
+        gcs_hook.return_value.download.assert_called_once_with(SCHEMA_BUCKET, SCHEMA_OBJECT)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_schema_obj_without_external_table_should_execute_successfully(self, bq_hook, gcs_hook):
+        bq_hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        gcs_hook.return_value.download.return_value = bytes(json.dumps(SCHEMA_FIELDS), "utf-8")
+
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_object_bucket=SCHEMA_BUCKET,
+            schema_object=SCHEMA_OBJECT,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                    ),
+                },
+                project_id=bq_hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        bq_hook.return_value.insert_job.assert_has_calls(calls)
+        gcs_hook.return_value.download.assert_called_once_with(SCHEMA_BUCKET, SCHEMA_OBJECT)
+
     @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
     def test_all_fields_should_be_present(self, hook):
         hook.return_value.insert_job.side_effect = [
