Skip to content

Commit

Permalink
Add location support to BigQueryDataTransferServiceTransferRunSensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
tirkarthi authored and potiuk committed Apr 25, 2022
1 parent 692a089 commit 967140e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions airflow/providers/google/cloud/sensors/bigquery_dts.py
Expand Up @@ -84,6 +84,7 @@ def __init__(
retry: Union[Retry, _MethodDefault] = DEFAULT,
request_timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
location: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
) -> None:
Expand All @@ -97,6 +98,7 @@ def __init__(
self.project_id = project_id
self.gcp_cloud_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.location = location

def _normalize_state_list(self, states) -> Set[TransferState]:
states = {states} if isinstance(states, (str, TransferState, int)) else states
Expand All @@ -122,6 +124,7 @@ def poke(self, context: 'Context') -> bool:
hook = BiqQueryDataTransferServiceHook(
gcp_conn_id=self.gcp_cloud_conn_id,
impersonation_chain=self.impersonation_chain,
location=self.location,
)
run = hook.get_transfer_run(
run_id=self.run_id,
Expand Down
9 changes: 9 additions & 0 deletions tests/providers/google/cloud/sensors/test_bigquery_dts.py
Expand Up @@ -30,6 +30,8 @@
TRANSFER_CONFIG_ID = "config_id"
RUN_ID = "run_id"
PROJECT_ID = "project_id"
LOCATION = "europe"
GCP_CONN_ID = "google_cloud_default"


class TestBigQueryDataTransferServiceTransferRunSensor(unittest.TestCase):
Expand All @@ -48,6 +50,8 @@ def test_poke_returns_false(self, mock_hook):

with pytest.raises(AirflowException, match="Transfer"):
op.poke({})

mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, location=None)
mock_hook.return_value.get_transfer_run.assert_called_once_with(
transfer_config_id=TRANSFER_CONFIG_ID,
run_id=RUN_ID,
Expand All @@ -68,10 +72,15 @@ def test_poke_returns_true(self, mock_hook):
task_id="id",
project_id=PROJECT_ID,
expected_statuses={"SUCCEEDED"},
location=LOCATION,
)
result = op.poke({})

assert result is True

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, location=LOCATION
)
mock_hook.return_value.get_transfer_run.assert_called_once_with(
transfer_config_id=TRANSFER_CONFIG_ID,
run_id=RUN_ID,
Expand Down

0 comments on commit 967140e

Please sign in to comment.