Skip to content

Commit

Permalink
CloudRunExecuteJobOperator: Add project_id to hook.get_job calls (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
freyrsae committed Feb 22, 2024
1 parent c0e30cb commit 5fc866a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/cloud_run.py
Expand Up @@ -302,7 +302,7 @@ def execute(self, context: Context):
if not self.deferrable:
result: Execution = self._wait_for_operation(self.operation)
self._fail_if_execution_failed(result)
job = hook.get_job(job_name=result.job, region=self.region)
job = hook.get_job(job_name=result.job, region=self.region, project_id=self.project_id)
return Job.to_dict(job)
else:
self.defer(
Expand Down Expand Up @@ -333,7 +333,7 @@ def execute_complete(self, context: Context, event: dict):

hook: CloudRunHook = CloudRunHook(self.gcp_conn_id, self.impersonation_chain)

job = hook.get_job(job_name=event["job_name"], region=self.region)
job = hook.get_job(job_name=event["job_name"], region=self.region, project_id=self.project_id)
return Job.to_dict(job)

def _fail_if_execution_failed(self, execution: Execution):
Expand Down
12 changes: 12 additions & 0 deletions tests/providers/google/cloud/operators/test_cloud_run.py
Expand Up @@ -102,6 +102,10 @@ def test_execute_success(self, hook_mock):

operator.execute(context=mock.MagicMock())

hook_mock.return_value.get_job.assert_called_once_with(
job_name=mock.ANY, region=REGION, project_id=PROJECT_ID
)

hook_mock.return_value.execute_job.assert_called_once_with(
job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=None
)
Expand Down Expand Up @@ -214,6 +218,10 @@ def test_execute_deferrable_execute_complete_method_success(self, hook_mock):
event = {"status": RunJobStatus.SUCCESS.value, "job_name": JOB_NAME}

result = operator.execute_complete(mock.MagicMock(), event)

hook_mock.return_value.get_job.assert_called_once_with(
job_name=mock.ANY, region=REGION, project_id=PROJECT_ID
)
assert result["name"] == JOB_NAME

@mock.patch(CLOUD_RUN_HOOK_PATH)
Expand All @@ -233,6 +241,10 @@ def test_execute_overrides(self, hook_mock):

operator.execute(context=mock.MagicMock())

hook_mock.return_value.get_job.assert_called_once_with(
job_name=mock.ANY, region=REGION, project_id=PROJECT_ID
)

hook_mock.return_value.execute_job.assert_called_once_with(
job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=overrides
)
Expand Down

0 comments on commit 5fc866a

Please sign in to comment.