Add deferrable capability to existing ``DataprocDeleteClusterOperator`` (#29349)

* Add deferrable capability to existing DataprocDeleteClusterOperator

Using param deferrable=True, add support for deleting a Google
Dataproc cluster asynchronously.
pankajkoti committed Feb 3, 2023
1 parent d37ef06 commit 872df12
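For context, a minimal usage sketch of what this commit enables (the project, region, and cluster values below are hypothetical; the system test at the bottom of this diff is the canonical example):

from airflow.providers.google.cloud.operators.dataproc import DataprocDeleteClusterOperator

# With deferrable=True the operator submits the delete request and then defers
# to the triggerer, which polls the cluster every polling_interval_seconds
# until it disappears, freeing the worker slot in the meantime.
delete_cluster = DataprocDeleteClusterOperator(
    task_id="delete_cluster",
    project_id="my-project",    # hypothetical
    region="us-central1",       # hypothetical
    cluster_name="my-cluster",  # hypothetical
    deferrable=True,
    polling_interval_seconds=10,
)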
Showing 5 changed files with 190 additions and 4 deletions.
50 changes: 46 additions & 4 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -53,6 +53,7 @@
from airflow.providers.google.cloud.triggers.dataproc import (
    DataprocBatchTrigger,
    DataprocClusterTrigger,
    DataprocDeleteClusterTrigger,
    DataprocSubmitTrigger,
)
from airflow.utils import timezone
@@ -822,6 +823,8 @@ class DataprocDeleteClusterOperator(BaseOperator):
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param deferrable: Run the operator in deferrable mode.
    :param polling_interval_seconds: Time (seconds) to wait between calls to check the cluster status.
    """

    template_fields: Sequence[str] = ("project_id", "region", "cluster_name", "impersonation_chain")
@@ -835,13 +838,17 @@ def __init__(
        cluster_uuid: str | None = None,
        request_id: str | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
        timeout: float = 1 * 60 * 60,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        deferrable: bool = False,
        polling_interval_seconds: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if deferrable and polling_interval_seconds <= 0:
            raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
        self.project_id = project_id
        self.region = region
        self.cluster_name = cluster_name
@@ -852,11 +859,48 @@ def __init__(
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.deferrable = deferrable
        self.polling_interval_seconds = polling_interval_seconds

    def execute(self, context: Context) -> None:
        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
        operation = self._delete_cluster(hook)
        if not self.deferrable:
            hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
            self.log.info("Cluster deleted.")
        else:
            end_time: float = time.time() + self.timeout
            self.defer(
                trigger=DataprocDeleteClusterTrigger(
                    gcp_conn_id=self.gcp_conn_id,
                    project_id=self.project_id,
                    region=self.region,
                    cluster_name=self.cluster_name,
                    request_id=self.request_id,
                    retry=self.retry,
                    end_time=end_time,
                    metadata=self.metadata,
                    impersonation_chain=self.impersonation_chain,
                    polling_interval=self.polling_interval_seconds,
                ),
                method_name="execute_complete",
            )
    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> Any:
        """
        Callback for when the trigger fires; returns immediately.

        Relies on the trigger to raise an exception; otherwise execution is
        assumed to have been successful.
        """
        if event and event["status"] == "error":
            raise AirflowException(event["message"])
        elif event is None:
            raise AirflowException("No event received in trigger callback")
        self.log.info("Cluster deleted.")

    def _delete_cluster(self, hook: DataprocHook):
        self.log.info("Deleting cluster: %s", self.cluster_name)
-        operation = hook.delete_cluster(
        return hook.delete_cluster(
            project_id=self.project_id,
            region=self.region,
            cluster_name=self.cluster_name,
@@ -866,8 +910,6 @@ def execute(self, context: Context) -> None:
            timeout=self.timeout,
            metadata=self.metadata,
        )
-        operation.result()
-        self.log.info("Cluster deleted.")


class DataprocJobBaseOperator(BaseOperator):
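To make the resume path above concrete, here is a hedged sketch (not part of the commit; the task and GCP values are hypothetical) of the event contract that execute_complete expects from the trigger:

from unittest import mock

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.dataproc import DataprocDeleteClusterOperator

op = DataprocDeleteClusterOperator(
    task_id="delete_cluster",
    project_id="my-project",    # hypothetical
    region="us-central1",       # hypothetical
    cluster_name="my-cluster",  # hypothetical
    deferrable=True,
)

# A success event resumes the task and logs "Cluster deleted."
op.execute_complete(context=mock.MagicMock(), event={"status": "success", "message": ""})

# An error event -- including the trigger's "Timeout" -- fails the task.
try:
    op.execute_complete(context=mock.MagicMock(), event={"status": "error", "message": "Timeout"})
except AirflowException as err:
    print(err)  # Timeout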
91 changes: 91 additions & 0 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -19,9 +19,11 @@
from __future__ import annotations

import asyncio
import time
import warnings
from typing import Any, AsyncIterator, Sequence

from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus

from airflow import AirflowException
@@ -219,3 +221,92 @@ async def run(self):
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})


class DataprocDeleteClusterTrigger(BaseTrigger):
"""
Asynchronously checks the status of a cluster.
:param cluster_name: The name of the cluster
:param end_time: Time in second left to check the cluster status
:param project_id: The ID of the Google Cloud project the cluster belongs to
:param region: The Cloud Dataproc region in which to handle the request
:param metadata: Additional metadata that is provided to the method
:param gcp_conn_id: The connection ID to use when fetching connection info.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account.
:param polling_interval: Time in seconds to sleep between checks of cluster status
"""

    def __init__(
        self,
        cluster_name: str,
        end_time: float,
        project_id: str | None = None,
        region: str | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        polling_interval: float = 5.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.cluster_name = cluster_name
        self.end_time = end_time
        self.project_id = project_id
        self.region = region
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.polling_interval = polling_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes DataprocDeleteClusterTrigger arguments and classpath."""
        return (
            "airflow.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger",
            {
                "cluster_name": self.cluster_name,
                "end_time": self.end_time,
                "project_id": self.project_id,
                "region": self.region,
                "metadata": self.metadata,
                "gcp_conn_id": self.gcp_conn_id,
                "impersonation_chain": self.impersonation_chain,
                "polling_interval": self.polling_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """Wait until the cluster is completely deleted."""
        hook = self._get_hook()
        while self.end_time > time.time():
            try:
                cluster = await hook.get_cluster(
                    region=self.region,  # type: ignore[arg-type]
                    cluster_name=self.cluster_name,
                    project_id=self.project_id,  # type: ignore[arg-type]
                    metadata=self.metadata,
                )
                self.log.info(
                    "Cluster status is %s. Sleeping for %s seconds.",
                    cluster.status.state,
                    self.polling_interval,
                )
                await asyncio.sleep(self.polling_interval)
            except NotFound:
                # The cluster is gone, which is the success condition for a delete.
                yield TriggerEvent({"status": "success", "message": ""})
                return
            except Exception as e:
                yield TriggerEvent({"status": "error", "message": str(e)})
                return
        yield TriggerEvent({"status": "error", "message": "Timeout"})

    def _get_hook(self) -> DataprocAsyncHook:
        return DataprocAsyncHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
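As a hedged illustration of how the triggerer consumes this class (assuming valid credentials on the default Google connection; the project, region, and cluster values are hypothetical), the first event yielded by run() can be drained like so:

import asyncio
import time

from airflow.providers.google.cloud.triggers.dataproc import DataprocDeleteClusterTrigger

async def first_event(trigger):
    # The triggerer awaits run() and acts on the first TriggerEvent yielded;
    # "success" means get_cluster() raised NotFound, i.e. the cluster is gone.
    async for event in trigger.run():
        return event.payload

trigger = DataprocDeleteClusterTrigger(
    cluster_name="my-cluster",    # hypothetical
    project_id="my-project",      # hypothetical
    region="us-central1",         # hypothetical
    end_time=time.time() + 3600,  # poll for at most one hour
    polling_interval=10.0,
)
print(asyncio.run(first_event(trigger)))  # e.g. {'status': 'success', 'message': ''}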
8 changes: 8 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -140,6 +140,14 @@ To delete a cluster you can use:
    :start-after: [START how_to_cloud_dataproc_delete_cluster_operator]
    :end-before: [END how_to_cloud_dataproc_delete_cluster_operator]

You can use deferrable mode for this action in order to run the operator asynchronously:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
    :language: python
    :dedent: 4
    :start-after: [START how_to_cloud_dataproc_delete_cluster_operator_async]
    :end-before: [END how_to_cloud_dataproc_delete_cluster_operator_async]

Submit a job to a cluster
-------------------------

42 changes: 42 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -57,6 +57,7 @@
from airflow.providers.google.cloud.triggers.dataproc import (
    DataprocBatchTrigger,
    DataprocClusterTrigger,
    DataprocDeleteClusterTrigger,
    DataprocSubmitTrigger,
)
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
@@ -875,6 +876,47 @@ def test_execute(self, mock_hook):
            metadata=METADATA,
        )

    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
    def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
        mock_hook.return_value.delete_cluster.return_value = None
        operator = DataprocDeleteClusterOperator(
            task_id=TASK_ID,
            region=GCP_REGION,
            project_id=GCP_PROJECT,
            cluster_name=CLUSTER_NAME,
            request_id=REQUEST_ID,
            gcp_conn_id=GCP_CONN_ID,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
            impersonation_chain=IMPERSONATION_CHAIN,
            deferrable=True,
        )

        with pytest.raises(TaskDeferred) as exc:
            operator.execute(mock.MagicMock())

        mock_hook.assert_called_once_with(
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
        )

        mock_hook.return_value.delete_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_REGION,
            cluster_name=CLUSTER_NAME,
            cluster_uuid=None,
            request_id=REQUEST_ID,
            retry=RETRY,
            timeout=TIMEOUT,
            metadata=METADATA,
        )

        mock_hook.return_value.wait_for_operation.assert_not_called()
        assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME


class TestDataprocSubmitJobOperator(DataprocJobTestBase):
    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
3 changes: 3 additions & 0 deletions tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
@@ -98,13 +98,16 @@
    )
    # [END how_to_cloud_dataproc_update_cluster_operator_async]

    # [START how_to_cloud_dataproc_delete_cluster_operator_async]
    delete_cluster = DataprocDeleteClusterOperator(
        task_id="delete_cluster",
        project_id=PROJECT_ID,
        cluster_name=CLUSTER_NAME,
        region=REGION,
        trigger_rule=TriggerRule.ALL_DONE,
        deferrable=True,
    )
    # [END how_to_cloud_dataproc_delete_cluster_operator_async]

    create_cluster >> update_cluster >> delete_cluster

