Skip to content

Commit

Permalink
Add protocol to define methods relied upon by KubernetesPodOperator (#…
Browse files Browse the repository at this point in the history
…31298)

Subclasses of KubernetesPodOperator, such as GKEStartPodOperator, may use
hooks that don't extend KubernetesHook.  We use this protocol to document the
methods used by KPO and ensure that these methods exist on such other hooks.
  • Loading branch information
dstandish committed May 16, 2023
1 parent 1bd538b commit caeca2d
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 73 deletions.
4 changes: 3 additions & 1 deletion airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.hooks.base import BaseHook
from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
from airflow.models import Connection
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
from airflow.utils import yaml

LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."
Expand All @@ -47,7 +48,7 @@ def _load_body_to_dict(body: str) -> dict:
return body_dict


class KubernetesHook(BaseHook):
class KubernetesHook(BaseHook, PodOperatorHookProtocol):
"""
Creates Kubernetes API connection.
Expand Down Expand Up @@ -439,6 +440,7 @@ def get_pod_logs(
)

def get_pod(self, name: str, namespace: str) -> V1Pod:
"""Read pod object from kubernetes API."""
return self.core_v1_client.read_namespaced_pod(
name=name,
namespace=namespace,
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
PodLaunchFailedException,
PodManager,
PodOperatorHookProtocol,
PodPhase,
get_container_termination_message,
)
Expand Down Expand Up @@ -463,7 +464,7 @@ def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client)

@cached_property
def hook(self) -> KubernetesHook:
def hook(self) -> PodOperatorHookProtocol:
hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
Expand Down
35 changes: 35 additions & 0 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.kubernetes.pod_generator import PodDefaults
from airflow.typing_compat import Protocol
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timezone import utcnow

Expand Down Expand Up @@ -72,6 +73,40 @@ class PodPhase:
terminal_states = {FAILED, SUCCEEDED}


class PodOperatorHookProtocol(Protocol):
"""
Protocol to define methods relied upon by KubernetesPodOperator
Subclasses of KubernetesPodOperator, such as GKEStartPodOperator, may use
hooks that don't extend KubernetesHook. We use this protocol to document the
methods used by KPO and ensure that these methods exist on such other hooks.
"""

@property
def core_v1_client(self) -> client.CoreV1Api:
"""Get authenticated CoreV1Api object."""

@property
def is_in_cluster(self) -> bool:
"""Expose whether the hook is configured with ``load_incluster_config`` or not"""

def get_pod(self, name: str, namespace: str) -> V1Pod:
"""Read pod object from kubernetes API."""

def _get_namespace(self) -> str | None:
"""
Returns the namespace that defined in the connection
TODO: in provider version 6.0, get rid of this method and make it the behavior of get_namespace.
"""

def get_xcom_sidecar_container_image(self) -> str | None:
"""Returns the xcom sidecar image that defined in the connection"""

def get_xcom_sidecar_container_resources(self) -> str | None:
"""Returns the xcom sidecar resources that defined in the connection"""


def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus | None:
"""Retrieves container status"""
container_statuses = pod.status.container_statuses if pod and pod.status else None
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from airflow import version
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
Expand Down Expand Up @@ -346,7 +347,7 @@ async def get_operation(
)


class GKEPodHook(GoogleBaseHook):
class GKEPodHook(GoogleBaseHook, PodOperatorHookProtocol):
"""Hook for managing Google Kubernetes Engine pod APIs."""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def cluster_hook(self) -> GKEHook:
)

@cached_property
def hook(self) -> GKEPodHook: # type: ignore[override]
def hook(self) -> GKEPodHook:
if self._cluster_url is None or self._ssl_ca_cert is None:
raise AttributeError(
"Cluster url and ssl_ca_cert should be defined before using self.hook method. "
Expand Down
3 changes: 3 additions & 0 deletions airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,9 @@ additional-extras:
- name: apache.beam
dependencies:
- apache-beam[gcp]
- name: cncf.kubernetes
dependencies:
- apache-airflow-providers-cncf-kubernetes>=6.2.0
- name: leveldb
dependencies:
- plyvel
Expand Down
36 changes: 0 additions & 36 deletions tests/ast_helpers.py

This file was deleted.

33 changes: 0 additions & 33 deletions tests/providers/google/cloud/hooks/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import ast
import sys
from asyncio import Future

Expand All @@ -31,11 +30,8 @@
GKEAsyncHook,
GKEHook,
GKEPodAsyncHook,
GKEPodHook,
)
from airflow.providers.google.common.consts import CLIENT_INFO
from tests import REPO_ROOT
from tests.ast_helpers import extract_ast_class_def_by_name
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id

if sys.version_info < (3, 8):
Expand Down Expand Up @@ -416,32 +412,3 @@ async def test_get_operation(self, mock_get_client, async_gke_hook, mock_async_g
mock_async_gke_cluster_client.get_operation.assert_called_once_with(
name=operation_path,
)


def test_hook_has_methods_required_by_kpo():
kpo_file = REPO_ROOT / "airflow/providers/cncf/kubernetes/operators/pod.py"
class_def = extract_ast_class_def_by_name(kpo_file.read_text(), "KubernetesPodOperator")

def get_hook_attr_refs(tree, attr):
for node in ast.walk(tree):
if isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Attribute) and node.value.attr == attr:
yield node.attr

# use AST to find all hook attr references in KPO
methods = set(get_hook_attr_refs(class_def, "hook"))
# actually verify that GKE has all attrs referenced by KPO
assert methods.intersection(GKEPodHook.__dict__) == methods

# sanity check below
# the list here is not strictly required but it's helpful to verify that the test is working
# will need to be updated when new hook method / attr references added to KPO
expected = {
"core_v1_client",
"get_pod",
"_get_namespace",
"is_in_cluster",
"get_xcom_sidecar_container_image",
"get_xcom_sidecar_container_resources",
}
assert methods == expected

0 comments on commit caeca2d

Please sign in to comment.