Dataproc batches (#29136)
* Make Dataproc batches reattach to running jobs.

* Dataproc Batches - moved reattach wait time up to the constructor

* rebasing on deferrable work
kristopherkane committed Feb 20, 2023
1 parent a770edf commit 7e3a9fc
Showing 4 changed files with 206 additions and 29 deletions.
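The reattach behavior matters most when a DAG submits a batch with an explicit batch_id: if the Airflow worker is restarted or the task process is killed mid-run, the retried task now waits on the already-running batch instead of failing on AlreadyExists. A minimal usage sketch (project, bucket, and batch_id values are illustrative, not taken from this commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

    with DAG(dag_id="example_dataproc_batch", start_date=datetime(2023, 1, 1), schedule=None) as dag:
        create_batch = DataprocCreateBatchOperator(
            task_id="create_batch",
            project_id="example-project",  # hypothetical project
            region="us-central1",
            batch={"pyspark_batch": {"main_python_file_uri": "gs://example-bucket/job.py"}},
            batch_id="example-batch-id",  # an explicit id is what makes reattachment possible
        )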
62 changes: 62 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -986,6 +986,68 @@ def list_batches(
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def wait_for_batch(
self,
batch_id: str,
region: str,
project_id: str,
wait_check_interval: int = 10,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> Batch:
"""
Wait for a Batch job to complete.
After Batch job submission, the operator normally waits for the job to complete; however, this method
is useful when Airflow is restarted or the task pid is killed for any reason. In that case the batch
creation is attempted again, AlreadyExists is raised and caught, and control should fall through to this
function to wait for completion.
:param batch_id: Required. The ID to use for the batch, which will become the final component
of the batch's resource name.
This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
:param region: Required. The Cloud Dataproc region in which to handle the request.
:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
:param wait_check_interval: The amount of time to pause between checks for job completion
:param retry: A retry object used to retry requests to get_batch.
If ``None`` is specified, requests will not be retried.
:param timeout: The amount of time, in seconds, to wait for the create_batch request to complete.
Note that if ``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
state = None
first_loop: bool = True
while state not in [
Batch.State.CANCELLED,
Batch.State.FAILED,
Batch.State.SUCCEEDED,
Batch.State.STATE_UNSPECIFIED,
]:
try:
if not first_loop:
time.sleep(wait_check_interval)
first_loop = False
self.log.debug("Waiting for batch %s", batch_id)
result = self.get_batch(
batch_id=batch_id,
region=region,
project_id=project_id,
retry=retry,
timeout=timeout,
metadata=metadata,
)
state = result.state
except ServerError as err:
self.log.info(
"Retrying. Dataproc API returned server error when waiting for batch id %s: %s",
batch_id,
err,
)

return result


class DataprocAsyncHook(GoogleBaseHook):
"""
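The new hook method polls get_batch every wait_check_interval seconds until the batch reaches a terminal state (SUCCEEDED, FAILED, CANCELLED, or STATE_UNSPECIFIED), logging and retrying through transient ServerError responses. A rough sketch of calling it directly, outside the operator (identifiers are illustrative):

    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
    from google.cloud.dataproc_v1 import Batch

    hook = DataprocHook(gcp_conn_id="google_cloud_default")
    batch = hook.wait_for_batch(
        batch_id="example-batch-id",  # hypothetical; the batch must already exist
        region="us-central1",
        project_id="example-project",
        wait_check_interval=10,  # seconds between get_batch polls
    )
    if batch.state != Batch.State.SUCCEEDED:
        raise RuntimeError(f"Batch ended in state {batch.state.name}")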
65 changes: 39 additions & 26 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -2267,7 +2267,15 @@ def __init__(

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
self.log.info("Creating batch")
# batch_id might not be set and will be generated
if self.batch_id:
link = DATAPROC_BATCH_LINK.format(
region=self.region, project_id=self.project_id, resource=self.batch_id
)
self.log.info("Creating batch %s", self.batch_id)
self.log.info("Once started, the batch job will be available at %s", link)
else:
self.log.info("Starting batch job. The batch ID will be generated since it was not provided.")
if self.region is None:
raise AirflowException("Region should be set here")
try:
@@ -2309,32 +2317,37 @@ def execute(self, context: Context):

except AlreadyExists:
self.log.info("Batch with given id already exists")
if self.batch_id is None:
raise AirflowException("Batch Id should be set here")
result = hook.get_batch(
batch_id=self.batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
# The existing batch may be a number of states other than 'SUCCEEDED'
if result.state != Batch.State.SUCCEEDED:
if result.state == Batch.State.FAILED or result.state == Batch.State.CANCELLED:
raise AirflowException(
f"Existing Batch {self.batch_id} failed or cancelled. "
f"Error: {result.state_message}"
)
else:
# Batch state is either: RUNNING, PENDING, CANCELLING, or UNSPECIFIED
self.log.info(
f"Batch {self.batch_id} is in state {result.state.name}."
"Waiting for state change..."
)
result = hook.wait_for_operation(timeout=self.timeout, operation=result)

# This is only likely to happen if batch_id was provided
# Could be running if Airflow was restarted after task started
# poll until a final state is reached
if self.batch_id:
self.log.info("Attaching to the job (%s) if it is still running.", self.batch_id)
result = hook.wait_for_batch(
batch_id=self.batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
wait_check_interval=self.polling_interval_seconds,
)
# It is possible we don't have a result in the case where batch_id was not provided, one was generated
# by chance, AlreadyExists was caught, but we can't reattach because we don't have the generated id
if result is None:
raise AirflowException("The job could not be reattached because the id was generated.")

# The existing batch may be in a number of states other than 'SUCCEEDED'
# wait_for_operation doesn't fail if the job is cancelled, so we will check for it here which also
# finds a cancelling|canceled|unspecified job from wait_for_batch
batch_id = self.batch_id or result.name.split("/")[-1]
link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, resource=batch_id)
if result.state == Batch.State.FAILED:
raise AirflowException(f"Batch job {batch_id} failed. Driver Logs: {link}")
if result.state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
raise AirflowException(f"Batch job {batch_id} was cancelled. Driver logs: {link}")
if result.state == Batch.State.STATE_UNSPECIFIED:
raise AirflowException(f"Batch job {batch_id} unspecified. Driver logs: {link}")
self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)
DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
return Batch.to_dict(result)

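With this change, the operator's AlreadyExists handler no longer re-fetches the batch and waits on the original operation; it reattaches via wait_for_batch when batch_id was supplied, and fails fast when the id was server-generated, since there is nothing known to reattach to. A condensed paraphrase of the control flow introduced above (a sketch, not the operator's literal code; the helper name is invented):

    from airflow.exceptions import AirflowException
    from google.cloud.dataproc_v1 import Batch

    def _resolve_existing_batch(hook, batch_id, region, project_id, polling_interval_seconds):
        # Simplified sketch of the reattach path added in this commit.
        if not batch_id:
            raise AirflowException("The job could not be reattached because the id was generated.")
        result = hook.wait_for_batch(
            batch_id=batch_id,
            region=region,
            project_id=project_id,
            wait_check_interval=polling_interval_seconds,
        )
        terminal_failures = (
            Batch.State.FAILED,
            Batch.State.CANCELLED,
            Batch.State.CANCELLING,
            Batch.State.STATE_UNSPECIFIED,
        )
        if result.state in terminal_failures:
            raise AirflowException(f"Batch {batch_id} ended in state {result.state.name}")
        return result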
23 changes: 23 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataproc.py
@@ -23,6 +23,7 @@
import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.dataproc_v1 import (
Batch,
BatchControllerAsyncClient,
ClusterControllerAsyncClient,
JobControllerAsyncClient,
@@ -419,6 +420,28 @@ def test_create_batch(self, mock_client):
timeout=None,
)

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch"))
def test_wait_for_batch(self, mock_batch):
mock_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
result: Batch = self.hook.wait_for_batch(
batch_id=BATCH_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
wait_check_interval=1,
retry=DEFAULT,
timeout=None,
metadata=(),
)
mock_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
retry=DEFAULT,
timeout=None,
metadata=(),
)
assert result.state == Batch.State.SUCCEEDED

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch_client"))
def test_delete_batch(self, mock_client):
self.hook.delete_batch(
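The added hook test only covers the happy path, where the first poll already reports SUCCEEDED. A sketch of a companion test for the ServerError retry branch, assuming the same test class and module fixtures (mock, DATAPROC_STRING, BATCH_ID, GCP_LOCATION, GCP_PROJECT) plus an extra import of InternalServerError from google.api_core.exceptions; this test is not part of the commit:

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_batch"))
    def test_wait_for_batch_retries_on_server_error(self, mock_batch):
        # Hypothetical test: the first poll fails with a 5xx error, the second
        # returns a terminal state, so wait_for_batch should call get_batch twice.
        mock_batch.side_effect = [
            InternalServerError("transient backend error"),
            Batch(state=Batch.State.SUCCEEDED),
        ]
        result = self.hook.wait_for_batch(
            batch_id=BATCH_ID,
            region=GCP_LOCATION,
            project_id=GCP_PROJECT,
            wait_check_interval=1,
        )
        assert mock_batch.call_count == 2
        assert result.state == Batch.State.SUCCEEDED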
85 changes: 82 additions & 3 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -1955,6 +1955,7 @@ def test_execute(self, mock_hook, to_dict_mock):
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_batch.assert_called_once_with(
@@ -1985,6 +1986,7 @@ def test_execute_with_result_retry(self, mock_hook, to_dict_mock):
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_batch.assert_called_once_with(
@@ -2014,14 +2016,91 @@ def test_execute_batch_failed(self, mock_hook, to_dict_mock):
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
mock_hook.return_value.get_batch.return_value.state = Batch.State.FAILED
mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.FAILED)
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
mock_hook.return_value.get_batch.assert_called_once_with(

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_batch_already_exists_succeeds(self, mock_hook):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
region=GCP_REGION,
project_id=GCP_PROJECT,
batch=BATCH,
batch_id=BATCH_ID,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
wait_check_interval=5,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_batch_already_exists_fails(self, mock_hook):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
region=GCP_REGION,
project_id=GCP_PROJECT,
batch=BATCH,
batch_id=BATCH_ID,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED)
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
wait_check_interval=10,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_batch_already_exists_cancelled(self, mock_hook):
op = DataprocCreateBatchOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
region=GCP_REGION,
project_id=GCP_PROJECT,
batch=BATCH,
batch_id=BATCH_ID,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.CANCELLED)
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
wait_check_interval=10,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
