Skip to content

Commit

Permalink
Handle the project location param on the async BigQuery DTS trigger (#29786)
Browse files Browse the repository at this point in the history
  • Loading branch information
okayhooni committed Mar 15, 2023
1 parent 53afba2 commit 5a3be72
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
13 changes: 12 additions & 1 deletion airflow/providers/google/cloud/hooks/bigquery_dts.py
Expand Up @@ -304,11 +304,16 @@ async def _get_project_id(self) -> str:
sync_hook = await self.get_sync_hook()
return sync_hook.project_id

async def _get_project_location(self) -> str:
    """Return the location configured on the synchronous BigQuery DTS hook.

    The sync hook is obtained lazily via ``get_sync_hook`` so the async
    client can reuse the connection's configured location.
    """
    return (await self.get_sync_hook()).location

async def get_transfer_run(
self,
config_id: str,
run_id: str,
project_id: str | None,
location: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
Expand All @@ -321,6 +326,7 @@ async def get_transfer_run(
:param project_id: The BigQuery project id where the transfer configuration should be
created. If set to None or missing, the default project_id from the Google Cloud connection
is used.
:param location: BigQuery Transfer Service location for regional transfers.
:param retry: A retry object used to retry requests. If `None` is
specified, requests will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request to
Expand All @@ -330,8 +336,13 @@ async def get_transfer_run(
:return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance.
"""
project_id = project_id or (await self._get_project_id())
location = location or (await self._get_project_location())
name = f"projects/{project_id}"
if location:
name += f"/locations/{location}"
name += f"/transferConfigs/{config_id}/runs/{run_id}"

client = await self._get_conn()
name = f"projects/{project_id}/transferConfigs/{config_id}/runs/{run_id}"
transfer_run = await client.get_transfer_run(
name=name,
retry=retry,
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery_dts.py
Expand Up @@ -306,6 +306,10 @@ def hook(self) -> BiqQueryDataTransferServiceHook:

def execute(self, context: Context):
self.log.info("Submitting manual transfer for %s", self.transfer_config_id)

if self.requested_run_time and isinstance(self.requested_run_time.get("seconds"), str):
self.requested_run_time["seconds"] = int(self.requested_run_time["seconds"])

response = self.hook.start_manual_transfer_runs(
transfer_config_id=self.transfer_config_id,
requested_time_range=self.requested_time_range,
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/triggers/bigquery_dts.py
Expand Up @@ -101,6 +101,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
project_id=self.project_id,
config_id=self.config_id,
run_id=self.run_id,
location=self.location,
)
state = transfer_run.state
self.log.info("Current state is %s", state)
Expand Down

0 comments on commit 5a3be72

Please sign in to comment.