Add deferrable mode to DataprocCreateBatchOperator (#28457)
bkossakowska committed Jan 30, 2023
1 parent 5503587 commit 9d93517
Showing 6 changed files with 385 additions and 12 deletions.
62 changes: 54 additions & 8 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -50,7 +50,11 @@
DataprocLink,
DataprocListLink,
)
from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
from airflow.providers.google.cloud.triggers.dataproc import (
DataprocBatchTrigger,
DataprocClusterTrigger,
DataprocSubmitTrigger,
)
from airflow.utils import timezone

if TYPE_CHECKING:
@@ -2134,6 +2138,8 @@ class DataprocCreateBatchOperator(BaseOperator):
:param asynchronous: Flag to return immediately after submitting the batch to the Dataproc API.
This is useful for submitting long-running batches and
waiting on them asynchronously using the DataprocBatchSensor.
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
"""

template_fields: Sequence[str] = (
@@ -2151,7 +2157,7 @@ def __init__(
region: str | None = None,
project_id: str | None = None,
batch: dict | Batch,
batch_id: str | None = None,
batch_id: str,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
@@ -2160,9 +2166,13 @@
impersonation_chain: str | Sequence[str] | None = None,
result_retry: Retry | _MethodDefault = DEFAULT,
asynchronous: bool = False,
deferrable: bool = False,
polling_interval_seconds: int = 5,
**kwargs,
):
super().__init__(**kwargs)
if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.region = region
self.project_id = project_id
self.batch = batch
@@ -2176,6 +2186,8 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.operation: operation.Operation | None = None
self.asynchronous = asynchronous
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -2195,13 +2207,30 @@ def execute(self, context: Context):
)
if self.operation is None:
raise RuntimeError("The operation should be set here!")
if not self.asynchronous:
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)

if not self.deferrable:
if not self.asynchronous:
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)

else:
return self.operation.operation.name

else:
return self.operation.operation.name
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)

except AlreadyExists:
self.log.info("Batch with given id already exists")
if self.batch_id is None:
@@ -2233,6 +2262,23 @@ def execute(self, context: Context):
DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
return Batch.to_dict(result)

def execute_complete(self, context, event=None) -> None:
"""
Callback for when the trigger fires; returns immediately.
Relies on the trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event is None:
raise AirflowException("Batch failed.")
batch_state = event["batch_state"]
batch_id = event["batch_id"]

if batch_state == Batch.State.FAILED:
raise AirflowException(f"Batch failed:\n{batch_id}")
if batch_state == Batch.State.CANCELLED:
raise AirflowException(f"Batch was cancelled:\n{batch_id}")
self.log.info("%s completed successfully.", self.task_id)

def on_kill(self):
if self.operation:
self.operation.cancel()
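For orientation, here is a minimal sketch of how the new deferrable mode could be used in a DAG. It is not part of this commit; the project, region, bucket, and batch configuration below are placeholder values.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

# Placeholder batch configuration (values are illustrative only).
BATCH_CONFIG = {
    "pyspark_batch": {
        "main_python_file_uri": "gs://my-bucket/jobs/job.py",
    },
}

with DAG(
    dag_id="example_dataproc_batch_deferrable_sketch",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # With deferrable=True the task submits the batch, frees its worker slot,
    # and defers to the triggerer, which polls the batch state every
    # polling_interval_seconds via DataprocBatchTrigger.
    create_batch = DataprocCreateBatchOperator(
        task_id="create_batch_deferrable",
        project_id="my-project",
        region="us-central1",
        batch=BATCH_CONFIG,
        batch_id="example-batch-001",
        deferrable=True,
        polling_interval_seconds=10,
    )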
72 changes: 71 additions & 1 deletion airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,7 +22,7 @@
import warnings
from typing import Any, AsyncIterator, Sequence

from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
@@ -149,3 +149,73 @@ def _get_hook(self) -> DataprocAsyncHook:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)


class DataprocBatchTrigger(BaseTrigger):
"""
DataprocBatchTrigger runs on the trigger worker and polls a Dataproc batch until it reaches a terminal state.

:param batch_id: The ID of the batch.
:param project_id: Google Cloud Project where the job is running
:param region: The Cloud Dataproc region in which to handle the request.
:param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval_seconds: polling period in seconds to check for the status
"""

def __init__(
self,
batch_id: str,
region: str,
project_id: str | None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: float = 5.0,
):
super().__init__()
self.batch_id = batch_id
self.project_id = project_id
self.region = region
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_interval_seconds = polling_interval_seconds

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes DataprocBatchTrigger arguments and classpath."""
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger",
{
"batch_id": self.batch_id,
"project_id": self.project_id,
"region": self.region,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
},
)

async def run(self):
hook = DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

while True:
batch = await hook.get_batch(
project_id=self.project_id, region=self.region, batch_id=self.batch_id
)
state = batch.state

if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
break
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
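To illustrate the new trigger's behaviour in isolation, its async generator can be driven directly. This is a sketch only, assuming valid Google Cloud credentials and an existing batch; the IDs below are placeholders.

import asyncio

from airflow.providers.google.cloud.triggers.dataproc import DataprocBatchTrigger


async def wait_for_batch() -> dict:
    # Placeholder IDs; gcp_conn_id defaults to "google_cloud_default".
    trigger = DataprocBatchTrigger(
        batch_id="example-batch-001",
        project_id="my-project",
        region="us-central1",
        polling_interval_seconds=10,
    )
    # run() polls get_batch() until the batch reaches FAILED, SUCCEEDED, or
    # CANCELLED, then yields a single TriggerEvent with batch_id and batch_state.
    async for event in trigger.run():
        return event.payload
    raise RuntimeError("Trigger exited without emitting an event")


print(asyncio.run(wait_for_batch()))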
@@ -305,6 +305,14 @@ To check if operation succeeded you can use
:start-after: [START how_to_cloud_dataproc_batch_async_sensor]
:end-before: [END how_to_cloud_dataproc_batch_async_sensor]

All of these actions can also be performed with the operator in deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_create_batch_operator_async]
:end-before: [END how_to_cloud_dataproc_create_batch_operator_async]

Get a Batch
-----------

47 changes: 46 additions & 1 deletion tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,7 +54,11 @@
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
)
from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
from airflow.providers.google.cloud.triggers.dataproc import (
DataprocBatchTrigger,
DataprocClusterTrigger,
DataprocSubmitTrigger,
)
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
@@ -2032,3 +2036,44 @@ def test_execute(self, mock_hook):
timeout=TIMEOUT,
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID

op = DataprocCreateBatchOperator(
task_id=TASK_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
batch=BATCH,
batch_id="batch_id",
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
request_id=REQUEST_ID,
impersonation_chain=IMPERSONATION_CHAIN,
deferrable=True,
)
with pytest.raises(TaskDeferred) as exc:
op.execute(mock.MagicMock())

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_REGION,
project_id=GCP_PROJECT,
batch_id="batch_id",
batch=BATCH,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_job.assert_not_called()

assert isinstance(exc.value.trigger, DataprocBatchTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

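A follow-on test for the new execute_complete callback could look roughly like the sketch below. It is not part of this commit, and it assumes the same module-level constants (TASK_ID, GCP_REGION, GCP_PROJECT, BATCH) used by the existing tests.

import pytest
from google.cloud.dataproc_v1 import Batch

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator


def test_execute_complete_raises_on_failed_batch():
    op = DataprocCreateBatchOperator(
        task_id=TASK_ID,
        region=GCP_REGION,
        project_id=GCP_PROJECT,
        batch=BATCH,
        batch_id="batch_id",
        deferrable=True,
    )
    # execute_complete should raise when the trigger reports a FAILED terminal state.
    event = {"batch_id": "batch_id", "batch_state": Batch.State.FAILED}
    with pytest.raises(AirflowException):
        op.execute_complete(context=None, event=event)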