Fix BigQuery transfer operators to respect project_id arguments (#32232)
avinashpandeshwar committed Jul 6, 2023
1 parent b3db4de commit 2d690de
Showing 4 changed files with 86 additions and 55 deletions.
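In every operator touched here, the job's project now resolves the same way: an explicit project_id argument on the operator wins, otherwise the project configured on the GCP connection (hook.project_id) is used. A project parsed out of a fully-qualified table string only addresses the table; it no longer decides where the job runs. A minimal sketch of this precedence rule (the helper function is illustrative; the real code inlines the expression at each call site):

from __future__ import annotations

# Sketch of the precedence applied throughout this commit.
# resolve_job_project_id is a hypothetical helper; the diff inlines
# "self.project_id or hook.project_id" wherever a job is submitted.
def resolve_job_project_id(operator_project_id: str | None, hook_project_id: str) -> str:
    # Explicit operator-level argument wins; fall back to the connection's project.
    return operator_project_id or hook_project_id

assert resolve_job_project_id("job-project", "conn-project") == "job-project"
assert resolve_job_project_id(None, "conn-project") == "conn-project"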
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -146,7 +146,7 @@ def _handle_job_error(job: BigQueryJob | UnknownJob) -> None:
     def _prepare_configuration(self):
         source_project, source_dataset, source_table = self.hook.split_tablename(
             table_input=self.source_project_dataset_table,
-            default_project_id=self.project_id or self.hook.project_id,
+            default_project_id=self.hook.project_id,
             var_name="source_project_dataset_table",
         )
 
@@ -184,7 +184,7 @@ def _submit_job(
 
         return hook.insert_job(
             configuration=configuration,
-            project_id=configuration["extract"]["sourceTable"]["projectId"],
+            project_id=self.project_id or hook.project_id,
             location=self.location,
             job_id=job_id,
             timeout=self.result_timeout,
@@ -255,7 +255,7 @@ def execute(self, context: Context):
                 trigger=BigQueryInsertJobTrigger(
                     conn_id=self.gcp_conn_id,
                     job_id=job_id,
-                    project_id=self.hook.project_id,
+                    project_id=self.project_id or self.hook.project_id,
                 ),
                 method_name="execute_complete",
             )
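With this change an explicit project_id on BigQueryToGCSOperator is honored when the extract job is submitted, so the job can run in a different project than the source table. A hedged usage sketch (all IDs, bucket names, and table names below are illustrative assumptions, not taken from this commit):

from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator

# Hypothetical task: the source table lives in "data-project", while the
# extract job itself is submitted under "job-project".
export_to_gcs = BigQueryToGCSOperator(
    task_id="export_to_gcs",
    source_project_dataset_table="data-project.my_dataset.my_table",
    destination_cloud_storage_uris=["gs://my-bucket/export/part-*.csv"],
    project_id="job-project",  # respected when the extract job is submitted after this fix
)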
31 changes: 14 additions & 17 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -301,7 +301,7 @@ def _submit_job(
         # Submit a new job without waiting for it to complete.
         return hook.insert_job(
             configuration=self.configuration,
-            project_id=self.project_id,
+            project_id=self.project_id or hook.project_id,
             location=self.location,
             job_id=job_id,
             timeout=self.result_timeout,
@@ -359,7 +359,7 @@ def execute(self, context: Context):
 
         if self.external_table:
             self.log.info("Creating a new BigQuery table for storing data...")
-            table_obj_api_repr = self._create_empty_table()
+            table_obj_api_repr = self._create_external_table()
 
             BigQueryTableLink.persist(
                 context=context,
@@ -381,7 +381,7 @@ def execute(self, context: Context):
             except Conflict:
                 # If the job already exists retrieve it
                 job = self.hook.get_job(
-                    project_id=self.hook.project_id,
+                    project_id=self.project_id or self.hook.project_id,
                     location=self.location,
                     job_id=job_id,
                 )
@@ -414,12 +414,12 @@ def execute(self, context: Context):
                 persist_kwargs = {
                     "context": context,
                     "task_instance": self,
-                    "project_id": self.hook.project_id,
                     "table_id": table,
                 }
                 if not isinstance(table, str):
                     persist_kwargs["table_id"] = table["tableId"]
                     persist_kwargs["dataset_id"] = table["datasetId"]
+                    persist_kwargs["project_id"] = table["projectId"]
                 BigQueryTableLink.persist(**persist_kwargs)
 
             self.job_id = job.job_id
@@ -430,7 +430,7 @@ def execute(self, context: Context):
                 trigger=BigQueryInsertJobTrigger(
                     conn_id=self.gcp_conn_id,
                     job_id=self.job_id,
-                    project_id=self.hook.project_id,
+                    project_id=self.project_id or self.hook.project_id,
                 ),
                 method_name="execute_complete",
             )
@@ -475,7 +475,9 @@ def _find_max_value_in_column(self):
             }
         }
         try:
-            job_id = hook.insert_job(configuration=self.configuration, project_id=hook.project_id)
+            job_id = hook.insert_job(
+                configuration=self.configuration, project_id=self.project_id or hook.project_id
+            )
             rows = list(hook.get_job(job_id=job_id, location=self.location).result())
         except BadRequest as e:
             if "Unrecognized name:" in e.message:
@@ -498,12 +500,7 @@ def _find_max_value_in_column(self):
         else:
             raise RuntimeError(f"The {select_command} returned no rows!")
 
-    def _create_empty_table(self):
-        self.project_id, dataset_id, table_id = self.hook.split_tablename(
-            table_input=self.destination_project_dataset_table,
-            default_project_id=self.project_id or self.hook.project_id,
-        )
-
+    def _create_external_table(self):
         external_config_api_repr = {
             "autodetect": self.autodetect,
             "sourceFormat": self.source_format,
@@ -549,7 +546,7 @@ def _create_empty_table(self):
 
         # build table definition
         table = Table(
-            table_ref=TableReference.from_string(self.destination_project_dataset_table, self.project_id)
+            table_ref=TableReference.from_string(self.destination_project_dataset_table, self.hook.project_id)
         )
         table.external_data_configuration = external_config
         if self.labels:
@@ -567,17 +564,17 @@ def _create_empty_table(self):
         self.log.info("Creating external table: %s", self.destination_project_dataset_table)
         self.hook.create_empty_table(
             table_resource=table_obj_api_repr,
-            project_id=self.project_id,
+            project_id=self.project_id or self.hook.project_id,
             location=self.location,
             exists_ok=True,
         )
         self.log.info("External table created successfully: %s", self.destination_project_dataset_table)
         return table_obj_api_repr
 
     def _use_existing_table(self):
-        self.project_id, destination_dataset, destination_table = self.hook.split_tablename(
+        destination_project_id, destination_dataset, destination_table = self.hook.split_tablename(
             table_input=self.destination_project_dataset_table,
-            default_project_id=self.project_id or self.hook.project_id,
+            default_project_id=self.hook.project_id,
             var_name="destination_project_dataset_table",
         )
 
@@ -597,7 +594,7 @@ def _use_existing_table(self):
                 "autodetect": self.autodetect,
                 "createDisposition": self.create_disposition,
                 "destinationTable": {
-                    "projectId": self.project_id,
+                    "projectId": destination_project_id,
                     "datasetId": destination_dataset,
                     "tableId": destination_table,
                 },
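The same precedence now applies to GCSToBigQueryOperator: insert_job, get_job, and external-table creation all try the operator's project_id before falling back to the hook's. A hedged usage sketch (bucket, object, and project names are illustrative assumptions):

from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

# Hypothetical task: load CSVs into a table in "data-project" while the load
# job itself runs under "job-project".
load_to_bq = GCSToBigQueryOperator(
    task_id="load_to_bq",
    bucket="my-bucket",
    source_objects=["export/part-000000000000.csv"],
    destination_project_dataset_table="data-project.my_dataset.my_table",
    source_format="CSV",
    project_id="job-project",  # used for insert_job/get_job after this fix
)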
12 changes: 7 additions & 5 deletions tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -31,6 +31,7 @@
 TEST_DATASET = "test-dataset"
 TEST_TABLE_ID = "test-table-id"
 PROJECT_ID = "test-project-id"
+JOB_PROJECT_ID = "job-project-id"
 
 
 class TestBigQueryToGCSOperator:
@@ -66,7 +67,7 @@ def test_execute(self, mock_hook):
         mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID)
         mock_hook.return_value.generate_job_id.return_value = real_job_id
         mock_hook.return_value.insert_job.return_value = MagicMock(job_id="real_job_id", error_result=False)
-        mock_hook.return_value.project_id = PROJECT_ID
+        mock_hook.return_value.project_id = JOB_PROJECT_ID
 
         operator = BigQueryToGCSOperator(
             task_id=TASK_ID,
@@ -77,13 +78,14 @@ def test_execute(self, mock_hook):
             field_delimiter=field_delimiter,
             print_header=print_header,
             labels=labels,
+            project_id=JOB_PROJECT_ID,
         )
         operator.execute(context=mock.MagicMock())
 
         mock_hook.return_value.insert_job.assert_called_once_with(
             job_id="123456_hash",
             configuration=expected_configuration,
-            project_id=PROJECT_ID,
+            project_id=JOB_PROJECT_ID,
             location=None,
             timeout=None,
             retry=DEFAULT_RETRY,
@@ -122,10 +124,10 @@ def test_execute_deferrable_mode(self, mock_hook):
         mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID)
         mock_hook.return_value.generate_job_id.return_value = real_job_id
         mock_hook.return_value.insert_job.return_value = MagicMock(job_id="real_job_id", error_result=False)
-        mock_hook.return_value.project_id = PROJECT_ID
+        mock_hook.return_value.project_id = JOB_PROJECT_ID
 
         operator = BigQueryToGCSOperator(
-            project_id=PROJECT_ID,
+            project_id=JOB_PROJECT_ID,
             task_id=TASK_ID,
             source_project_dataset_table=source_project_dataset_table,
             destination_cloud_storage_uris=destination_cloud_storage_uris,
@@ -146,7 +148,7 @@ def test_execute_deferrable_mode(self, mock_hook):
         mock_hook.return_value.insert_job.assert_called_once_with(
             configuration=expected_configuration,
             job_id="123456_hash",
-            project_id=PROJECT_ID,
+            project_id=JOB_PROJECT_ID,
             location=None,
             timeout=None,
             retry=DEFAULT_RETRY,
