Skip to content

Commit

Permalink
Add on_finish_action to KubernetesPodOperator (#30718)
Browse files Browse the repository at this point in the history
* Add a new arg for KPO to only delete the pod when it doesn't fail

* deprecate is_delete_operator_pod and add on_finish_action

* Add deprecated properties and fix unit tests

* add missing attribute

* Apply suggestions from code review

Co-authored-by: Jed Cunningham <[email protected]>

* update GKEStartPodOperator to be consistent with KPO

* update EksPodOperator to be consistent with KPO

* update unit tests and the method used to check the kpo compatibility

* Fix a bug and add a new unit test for each provider

* warn with AirflowProviderDeprecationWarning instead of DeprecationWarning

* Bump KPO min version in GCP provider and add a new one to AWS provider

* Add the new param to the GKE trigger

* Apply suggestions from code review

Co-authored-by: Jarek Potiuk <[email protected]>

---------

Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Jarek Potiuk <[email protected]>
  • Loading branch information
3 people committed Jun 30, 2023
1 parent 64a3787 commit dd937e5
Show file tree
Hide file tree
Showing 21 changed files with 430 additions and 76 deletions.
34 changes: 26 additions & 8 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
EksNodegroupTrigger,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction

try:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
Expand Down Expand Up @@ -854,10 +855,15 @@ class EksPodOperator(KubernetesPodOperator):
running Airflow in a distributed manner and aws_conn_id is None or
empty, then the default boto3 configuration would be used (and must be
maintained on each worker node).
:param on_finish_action: What to do when the pod reaches its final state, or the execution is interrupted.
If "delete_pod", the pod will be deleted regardless of its state; if "delete_succeeded_pod",
only succeeded pods will be deleted. You can set it to "keep_pod" to keep the pod.
Current default is `keep_pod`, but this will be changed in the next major release of this provider.
:param is_delete_operator_pod: What to do when the pod reaches its final
state, or the execution is interrupted. If True, delete the
pod; if False, leave the pod. Current default is False, but this will be
pod; if False, leave the pod. Current default is False, but this will be
changed in the next major release of this provider.
Deprecated - use `on_finish_action` instead.
"""

Expand Down Expand Up @@ -885,19 +891,32 @@ def __init__(
pod_username: str | None = None,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
on_finish_action: str | None = None,
is_delete_operator_pod: bool | None = None,
**kwargs,
) -> None:
if is_delete_operator_pod is None:
if is_delete_operator_pod is not None:
warnings.warn(
f"You have not set parameter `is_delete_operator_pod` in class {self.__class__.__name__}. "
"Currently the default for this parameter is `False` but in a future release the default "
"will be changed to `True`. To ensure pods are not deleted in the future you will need to "
"set `is_delete_operator_pod=False` explicitly.",
"`is_delete_operator_pod` parameter is deprecated, please use `on_finish_action`",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
is_delete_operator_pod = False
kwargs["on_finish_action"] = (
OnFinishAction.DELETE_POD if is_delete_operator_pod else OnFinishAction.KEEP_POD
)
else:
if on_finish_action is not None:
kwargs["on_finish_action"] = OnFinishAction(on_finish_action)
else:
warnings.warn(
f"You have not set parameter `on_finish_action` in class {self.__class__.__name__}. "
"Currently the default for this parameter is `keep_pod` but in a future release"
" the default will be changed to `delete_pod`. To ensure pods are not deleted in"
" the future you will need to set `on_finish_action=keep_pod` explicitly.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
kwargs["on_finish_action"] = OnFinishAction.KEEP_POD

self.cluster_name = cluster_name
self.in_cluster = in_cluster
Expand All @@ -909,7 +928,6 @@ def __init__(
in_cluster=self.in_cluster,
namespace=self.namespace,
name=self.pod_name,
is_delete_operator_pod=is_delete_operator_pod,
**kwargs,
)
# There is no need to manage the kube_config file, as it will be generated automatically.
Expand Down
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -667,3 +667,6 @@ additional-extras:
- name: aiobotocore
dependencies:
- aiobotocore[boto3]>=2.2.0
- name: cncf.kubernetes
dependencies:
- apache-airflow-providers-cncf-kubernetes>=7.2.0
43 changes: 34 additions & 9 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import re
import secrets
import string
import warnings
from collections.abc import Container
from contextlib import AbstractContextManager
from functools import cached_property
Expand All @@ -32,7 +33,7 @@
from slugify import slugify
from urllib3.exceptions import HTTPError

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
Expand All @@ -52,6 +53,7 @@
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils import xcom_sidecar # type: ignore[attr-defined]
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
OnFinishAction,
PodLaunchFailedException,
PodManager,
PodOperatorHookProtocol,
Expand Down Expand Up @@ -188,9 +190,6 @@ class KubernetesPodOperator(BaseOperator):
If more than one secret is required, provide a
comma separated list: secret_a,secret_b
:param service_account_name: Name of the service account
:param is_delete_operator_pod: What to do when the pod reaches its final
state, or the execution is interrupted. If True (default), delete the
pod; if False, leave the pod.
:param hostnetwork: If True enable host networking on the pod.
:param tolerations: A list of kubernetes tolerations.
:param security_context: security options the pod should run with (PodSecurityContext).
Expand Down Expand Up @@ -226,6 +225,13 @@ class KubernetesPodOperator(BaseOperator):
:param deferrable: Run operator in the deferrable mode.
:param poll_interval: Polling period in seconds to check for the status. Used only in deferrable mode.
:param log_pod_spec_on_failure: Log the pod's specification if a failure occurs
:param on_finish_action: What to do when the pod reaches its final state, or the execution is interrupted.
If "delete_pod", the pod will be deleted regardless of its state; if "delete_succeeded_pod",
only succeeded pods will be deleted. You can set it to "keep_pod" to keep the pod.
:param is_delete_operator_pod: What to do when the pod reaches its final
state, or the execution is interrupted. If True (default), delete the
pod; if False, leave the pod.
Deprecated - use `on_finish_action` instead.
"""

# This field can be overloaded at the instance level via base_container_name
Expand Down Expand Up @@ -279,7 +285,6 @@ def __init__(
node_selector: dict | None = None,
image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
service_account_name: str | None = None,
is_delete_operator_pod: bool = True,
hostnetwork: bool = False,
tolerations: list[k8s.V1Toleration] | None = None,
security_context: dict | None = None,
Expand All @@ -303,6 +308,8 @@ def __init__(
deferrable: bool = False,
poll_interval: float = 2,
log_pod_spec_on_failure: bool = True,
on_finish_action: str = "delete_pod",
is_delete_operator_pod: None | bool = None,
**kwargs,
) -> None:
# TODO: remove in provider 6.0.0 release. This is a mitigation step to advise users to switch to the
Expand Down Expand Up @@ -350,7 +357,6 @@ def __init__(
self.config_file = config_file
self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else []
self.service_account_name = service_account_name
self.is_delete_operator_pod = is_delete_operator_pod
self.hostnetwork = hostnetwork
self.tolerations = (
[convert_toleration(toleration) for toleration in tolerations] if tolerations else []
Expand Down Expand Up @@ -384,6 +390,20 @@ def __init__(
self.poll_interval = poll_interval
self.remote_pod: k8s.V1Pod | None = None
self.log_pod_spec_on_failure = log_pod_spec_on_failure
if is_delete_operator_pod is not None:
warnings.warn(
"`is_delete_operator_pod` parameter is deprecated, please use `on_finish_action`",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.on_finish_action = (
OnFinishAction.DELETE_POD if is_delete_operator_pod else OnFinishAction.KEEP_POD
)
self.is_delete_operator_pod = is_delete_operator_pod
else:
self.on_finish_action = OnFinishAction(on_finish_action)
self.is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict

@cached_property
Expand Down Expand Up @@ -595,10 +615,10 @@ def invoke_defer_method(self):
config_file=self.config_file,
in_cluster=self.in_cluster,
poll_interval=self.poll_interval,
should_delete_pod=self.is_delete_operator_pod,
get_logs=self.get_logs,
startup_timeout=self.startup_timeout_seconds,
base_container_name=self.base_container_name,
on_finish_action=self.on_finish_action.value,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -669,7 +689,8 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs):
def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
pod_phase = remote_pod.status.phase if hasattr(remote_pod, "status") else None

if pod_phase != PodPhase.SUCCEEDED or not self.is_delete_operator_pod:
# if the pod did not succeed, or it succeeded but we don't want to delete it
if pod_phase != PodPhase.SUCCEEDED or self.on_finish_action == OnFinishAction.KEEP_POD:
self.patch_already_checked(remote_pod, reraise=False)

if pod_phase != PodPhase.SUCCEEDED:
Expand Down Expand Up @@ -722,7 +743,11 @@ def _read_pod_events(self, pod, *, reraise=True):
def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True):
with _optionally_suppress(reraise=reraise):
if pod is not None:
if self.is_delete_operator_pod:
should_delete_pod = (self.on_finish_action == OnFinishAction.DELETE_POD) or (
self.on_finish_action == OnFinishAction.DELETE_SUCCEEDED_POD
and pod.status.phase == PodPhase.SUCCEEDED
)
if should_delete_pod:
self.log.info("Deleting pod: %s", pod.metadata.name)
self.pod_manager.delete_pod(pod)
else:
Expand Down
34 changes: 27 additions & 7 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
import warnings
from asyncio import CancelledError
from datetime import datetime
from enum import Enum
Expand All @@ -25,8 +26,9 @@
import pytz
from kubernetes_asyncio.client.models import V1Pod

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodPhase
from airflow.triggers.base import BaseTrigger, TriggerEvent


Expand Down Expand Up @@ -57,11 +59,15 @@ class KubernetesPodTrigger(BaseTrigger):
:param poll_interval: Polling period in seconds to check for the status.
:param trigger_start_time: time in Datetime format when the trigger was started
:param in_cluster: run kubernetes client with in_cluster configuration.
:param get_logs: get the stdout of the container as logs of the tasks.
:param startup_timeout: timeout in seconds to start up the pod.
:param on_finish_action: What to do when the pod reaches its final state, or the execution is interrupted.
If "delete_pod", the pod will be deleted regardless of its state; if "delete_succeeded_pod",
only succeeded pods will be deleted. You can set it to "keep_pod" to keep the pod.
:param should_delete_pod: What to do when the pod reaches its final
state, or the execution is interrupted. If True (default), delete the
pod; if False, leave the pod.
:param get_logs: get the stdout of the container as logs of the tasks.
:param startup_timeout: timeout in seconds to start up the pod.
Deprecated - use `on_finish_action` instead.
"""

def __init__(
Expand All @@ -75,9 +81,10 @@ def __init__(
cluster_context: str | None = None,
config_file: str | None = None,
in_cluster: bool | None = None,
should_delete_pod: bool = True,
get_logs: bool = True,
startup_timeout: int = 120,
on_finish_action: str = "delete_pod",
should_delete_pod: bool | None = None,
):
super().__init__()
self.pod_name = pod_name
Expand All @@ -89,10 +96,22 @@ def __init__(
self.cluster_context = cluster_context
self.config_file = config_file
self.in_cluster = in_cluster
self.should_delete_pod = should_delete_pod
self.get_logs = get_logs
self.startup_timeout = startup_timeout

if should_delete_pod is not None:
warnings.warn(
"`should_delete_pod` parameter is deprecated, please use `on_finish_action`",
AirflowProviderDeprecationWarning,
)
self.on_finish_action = (
OnFinishAction.DELETE_POD if should_delete_pod else OnFinishAction.KEEP_POD
)
self.should_delete_pod = should_delete_pod
else:
self.on_finish_action = OnFinishAction(on_finish_action)
self.should_delete_pod = self.on_finish_action == OnFinishAction.DELETE_POD

self._hook: AsyncKubernetesHook | None = None
self._since_time = None

Expand All @@ -109,10 +128,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"in_cluster": self.in_cluster,
"should_delete_pod": self.should_delete_pod,
"get_logs": self.get_logs,
"startup_timeout": self.startup_timeout,
"trigger_start_time": self.trigger_start_time,
"should_delete_pod": self.should_delete_pod,
"on_finish_action": self.on_finish_action.value,
},
)

Expand Down Expand Up @@ -191,7 +211,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
name=self.pod_name,
namespace=self.pod_namespace,
)
if self.should_delete_pod:
if self.on_finish_action == OnFinishAction.DELETE_POD:
self.log.info("Deleting pod...")
await self._get_async_hook().delete_pod(
name=self.pod_name,
Expand Down
9 changes: 9 additions & 0 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Launches PODs."""
from __future__ import annotations

import enum
import json
import logging
import math
Expand Down Expand Up @@ -585,3 +586,11 @@ def _exec_pod_command(self, resp, command: str) -> str | None:
if res:
return res
return res


class OnFinishAction(enum.Enum):
    """Action to take when the pod finishes.

    The string values are what callers pass as ``on_finish_action``; they are
    converted back to members with ``OnFinishAction(value)``.
    """

    # Leave the pod in place whatever its final state.
    KEEP_POD = "keep_pod"
    # Delete the pod regardless of its final state.
    DELETE_POD = "delete_pod"
    # Delete the pod only if it finished in the succeeded phase.
    DELETE_SUCCEEDED_POD = "delete_succeeded_pod"

0 comments on commit dd937e5

Please sign in to comment.