Add DataprocCancelOperationOperator (#28456)
Co-authored-by: Beata Kossakowska <[email protected]>
bkossakowska and Beata Kossakowska committed Jan 16, 2023
1 parent d24527b commit dc3a3c7
Showing 6 changed files with 298 additions and 10 deletions.
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -252,6 +252,10 @@ def get_batch_client(self, region: str | None = None) -> BatchControllerClient:
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def get_operations_client(self, region: str | None):
"""Returns OperationsClient."""
return self.get_batch_client(region=region).transport.operations_client

def wait_for_operation(
self,
operation: Operation,
73 changes: 68 additions & 5 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -2055,6 +2055,9 @@ class DataprocCreateBatchOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param asynchronous: Flag to return immediately after submitting the batch to the Dataproc API.
This is useful for submitting long-running batches and
waiting on them asynchronously using the DataprocBatchSensor.
"""

template_fields: Sequence[str] = (
@@ -2080,6 +2083,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
result_retry: Retry | _MethodDefault = DEFAULT,
asynchronous: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -2095,6 +2099,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.operation: operation.Operation | None = None
self.asynchronous = asynchronous

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -2114,10 +2119,13 @@ def execute(self, context: Context):
)
if self.operation is None:
raise RuntimeError("The operation should be set here!")
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)
if not self.asynchronous:
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)
else:
return self.operation.operation.name
except AlreadyExists:
self.log.info("Batch with given id already exists")
if self.batch_id is None:
@@ -2130,7 +2138,6 @@
timeout=self.timeout,
metadata=self.metadata,
)

# The existing batch may be a number of states other than 'SUCCEEDED'
if result.state != Batch.State.SUCCEEDED:
if result.state == Batch.State.FAILED or result.state == Batch.State.CANCELLED:
@@ -2355,3 +2362,59 @@ def execute(self, context: Context):
)
DataprocListLink.persist(context=context, task_instance=self, url=DATAPROC_BATCHES_LINK)
return [Batch.to_dict(result) for result in results]


class DataprocCancelOperationOperator(BaseOperator):
"""
Cancel the batch workload resource.

:param operation_name: Required. The name of the operation resource to be cancelled.
:param region: Required. The Cloud Dataproc region in which to handle the request.
:param project_id: Optional. The ID of the Google Cloud project that the cluster belongs to.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = ("operation_name", "region", "project_id", "impersonation_chain")

def __init__(
self,
*,
operation_name: str,
region: str,
project_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.operation_name = operation_name
self.region = region
self.project_id = project_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

self.log.info("Canceling operation: %s", self.operation_name)
hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
self.log.info("Operation canceled.")
76 changes: 75 additions & 1 deletion airflow/providers/google/cloud/sensors/dataproc.py
@@ -22,7 +22,7 @@
from typing import TYPE_CHECKING, Sequence

from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import JobStatus
from google.cloud.dataproc_v1.types import Batch, JobStatus

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
@@ -109,3 +109,77 @@ def poke(self, context: Context) -> bool:

self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
return False


class DataprocBatchSensor(BaseSensorOperator):
"""
Check for the state of a batch.

:param batch_id: The Dataproc batch ID to poll. (templated)
:param region: Required. The Cloud Dataproc region in which to handle the request. (templated)
:param project_id: Optional. The ID of the Google Cloud project that the batch belongs to. (templated)
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param wait_timeout: How many seconds to wait for the batch to be ready.
"""

template_fields: Sequence[str] = ("project_id", "region", "batch_id")
ui_color = "#f0eee4"

def __init__(
self,
*,
batch_id: str,
region: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
wait_timeout: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.batch_id = batch_id
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.region = region
self.wait_timeout = wait_timeout
self.start_sensor_time: float | None = None

def execute(self, context: Context) -> None:
self.start_sensor_time = time.monotonic()
super().execute(context)

def _duration(self):
return time.monotonic() - self.start_sensor_time

def poke(self, context: Context) -> bool:
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
if self.wait_timeout:
try:
batch = hook.get_batch(batch_id=self.batch_id, region=self.region, project_id=self.project_id)
except ServerError as err:
duration = self._duration()
self.log.info("DURATION RUN: %f", duration)

if duration > self.wait_timeout:
raise AirflowException(
f"Timeout: dataproc batch {self.batch_id} is not ready after {self.wait_timeout}s"
)
self.log.info("Retrying. Dataproc API returned server error when waiting for batch: %s", err)
return False
else:
batch = hook.get_batch(batch_id=self.batch_id, region=self.region, project_id=self.project_id)

state = batch.state
if state == Batch.State.FAILED:
raise AirflowException("Batch failed")
elif state in {
Batch.State.CANCELLED,
Batch.State.CANCELLING,
}:
raise AirflowException("Batch was cancelled.")
elif state == Batch.State.SUCCEEDED:
self.log.debug("Batch %s completed successfully.", self.batch_id)
return True

self.log.info("Waiting for the batch %s to complete.", self.batch_id)
return False
22 changes: 22 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -279,6 +279,16 @@ After Cluster was created you should add it to the Batch configuration.
:start-after: [START how_to_cloud_dataproc_create_batch_operator_with_persistent_history_server]
:end-before: [END how_to_cloud_dataproc_create_batch_operator_with_persistent_history_server]

To check if the operation succeeded you can use
:class:`~airflow.providers.google.cloud.sensors.dataproc.DataprocBatchSensor`.

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_batch_async_sensor]
:end-before: [END how_to_cloud_dataproc_batch_async_sensor]
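
For orientation, a minimal sketch of such a sensor task might look like the following (the project, region, batch id and timing values are placeholders, not taken from the example above):

.. code-block:: python

    wait_for_batch = DataprocBatchSensor(
        task_id="wait_for_batch",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        batch_id="example-batch",
        poke_interval=30,
        wait_timeout=600,
    )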

Get a Batch
-----------

@@ -315,6 +325,18 @@ To delete a batch you can use:
:start-after: [START how_to_cloud_dataproc_delete_batch_operator]
:end-before: [END how_to_cloud_dataproc_delete_batch_operator]

Cancel a Batch Operation
------------------------

To cancel an operation you can use
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocCancelOperationOperator`.

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_cancel_operation_operator]
:end-before: [END how_to_cloud_dataproc_cancel_operation_operator]
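
A minimal sketch of such a task (placeholder identifiers; the operation name is assumed to come from a ``DataprocCreateBatchOperator`` run with ``asynchronous=True``, whose return value, the operation name, is pushed to XCom):

.. code-block:: python

    cancel_operation = DataprocCancelOperationOperator(
        task_id="cancel_operation",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        operation_name="{{ task_instance.xcom_pull('create_batch') }}",
    )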

References
^^^^^^^^^^
For further information, take a look at:
73 changes: 71 additions & 2 deletions tests/providers/google/cloud/sensors/test_dataproc.py
@@ -22,10 +22,10 @@

import pytest
from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import JobStatus
from google.cloud.dataproc_v1.types import Batch, JobStatus

from airflow import AirflowException
from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
from airflow.providers.google.cloud.sensors.dataproc import DataprocBatchSensor, DataprocJobSensor
from airflow.version import version as airflow_version

AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")
@@ -184,3 +184,72 @@ def test_wait_timeout_raise_exception(self, mock_hook):

with pytest.raises(AirflowException, match="Timeout: dataproc job job_id is not ready after 300s"):
sensor.poke(context={})


class TestDataprocBatchSensor(unittest.TestCase):
def create_batch(self, state: int):
batch = mock.Mock()
batch.state = state
return batch

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_succeeded(self, mock_hook):
batch = self.create_batch(Batch.State.SUCCEEDED)
mock_hook.return_value.get_batch.return_value = batch

sensor = DataprocBatchSensor(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
batch_id="batch_id",
poke_interval=10,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
ret = sensor.poke(context={})
mock_hook.return_value.get_batch.assert_called_once_with(
batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
)
assert ret

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_cancelled(self, mock_hook):
batch = self.create_batch(Batch.State.CANCELLED)
mock_hook.return_value.get_batch.return_value = batch

sensor = DataprocBatchSensor(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
batch_id="batch_id",
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
with pytest.raises(AirflowException, match="Batch was cancelled."):
sensor.poke(context={})

mock_hook.return_value.get_batch.assert_called_once_with(
batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_error(self, mock_hook):
batch = self.create_batch(Batch.State.FAILED)
mock_hook.return_value.get_batch.return_value = batch

sensor = DataprocBatchSensor(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
batch_id="batch_id",
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)

with pytest.raises(AirflowException, match="Batch failed"):
sensor.poke(context={})

mock_hook.return_value.get_batch.assert_called_once_with(
batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
)
