Skip to content

Commit

Permalink
fix: respect connection ID and impersonation in GKEStartPodOperator (#…
Browse files Browse the repository at this point in the history
…36861)

The GKEStartPodOperator accepts `gcp_conn_id` and `impersonation_chain`
as parameters.

This PR ensures that those values are passed to and supported by the corresponding
hooks and triggers in deferrable and non-deferrable mode.
  • Loading branch information
m1racoli committed Jan 20, 2024
1 parent d3b4a91 commit f1758fd
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 5 deletions.
26 changes: 22 additions & 4 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,15 @@ def __init__(
self,
cluster_url: str,
ssl_ca_cert: str,
*args,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
**kwargs,
)
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert

Expand Down Expand Up @@ -440,10 +445,23 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
sync_hook_class = GKEPodHook
scopes = ["https://www.googleapis.com/auth/cloud-platform"]

def __init__(self, cluster_url: str, ssl_ca_cert: str, **kwargs) -> None:
def __init__(
self,
cluster_url: str,
ssl_ca_cert: str,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert
super().__init__(cluster_url=cluster_url, ssl_ca_cert=ssl_ca_cert, **kwargs)
super().__init__(
cluster_url=cluster_url,
ssl_ca_cert=ssl_ca_cert,
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
**kwargs,
)

@contextlib.asynccontextmanager
async def get_conn(self, token: Token) -> async_client.ApiClient: # type: ignore[override]
Expand Down
3 changes: 3 additions & 0 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ def hook(self) -> GKEPodHook:
gcp_conn_id=self.gcp_conn_id,
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
impersonation_chain=self.impersonation_chain,
)
return hook

Expand Down Expand Up @@ -577,6 +578,8 @@ def invoke_defer_method(self):
in_cluster=self.in_cluster,
base_container_name=self.base_container_name,
on_finish_action=self.on_finish_action,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert},
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/google/cloud/triggers/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def __init__(
startup_timeout: int = 120,
on_finish_action: str = "delete_pod",
should_delete_pod: bool | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
*args,
**kwargs,
):
Expand All @@ -96,6 +98,8 @@ def __init__(
self.in_cluster = in_cluster
self.get_logs = get_logs
self.startup_timeout = startup_timeout
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

if should_delete_pod is not None:
warnings.warn(
Expand Down Expand Up @@ -131,6 +135,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"base_container_name": self.base_container_name,
"should_delete_pod": self.should_delete_pod,
"on_finish_action": self.on_finish_action.value,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
},
)

Expand All @@ -139,6 +145,8 @@ def hook(self) -> GKEPodAsyncHook: # type: ignore[override]
return GKEPodAsyncHook(
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)


Expand Down
9 changes: 8 additions & 1 deletion tests/providers/google/cloud/hooks/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def async_hook(self):
return GKEPodAsyncHook(
cluster_url=CLUSTER_URL,
ssl_ca_cert=SSL_CA_CERT,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATE_CHAIN,
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -405,7 +407,12 @@ def setup_method(self):
with mock.patch(
BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
):
self.gke_hook = GKEPodHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None)
self.gke_hook = GKEPodHook(
gcp_conn_id="test",
impersonation_chain=IMPERSONATE_CHAIN,
ssl_ca_cert=None,
cluster_url=None,
)
self.gke_hook._client = mock.Mock()

def refresh_token(request):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def trigger():
cluster_url=CLUSTER_URL,
ssl_ca_cert=SSL_CA_CERT,
base_container_name=BASE_CONTAINER_NAME,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)


Expand Down Expand Up @@ -101,6 +103,8 @@ def test_serialize_should_execute_successfully(self, trigger):
"base_container_name": BASE_CONTAINER_NAME,
"on_finish_action": ON_FINISH_ACTION,
"should_delete_pod": SHOULD_DELETE_POD,
"gcp_conn_id": GCP_CONN_ID,
"impersonation_chain": IMPERSONATION_CHAIN,
}

@pytest.mark.asyncio
Expand Down

0 comments on commit f1758fd

Please sign in to comment.