Make the deferrable version of DataprocCreateBatchOperator handle a batch_id that already exists (#32216)
kristopherkane committed Jun 29, 2023
1 parent f2e2125 commit 7d2ec76
Showing 1 changed file with 52 additions and 35 deletions.
87 changes: 52 additions & 35 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -172,7 +172,6 @@ def __init__(
enable_component_gateway: bool | None = False,
**kwargs,
) -> None:

self.project_id = project_id
self.num_masters = num_masters
self.num_workers = num_workers
@@ -489,7 +488,6 @@ def __init__(
polling_interval_seconds: int = 10,
**kwargs,
) -> None:

# TODO: remove one day
if cluster_config is None and virtual_cluster_config is None:
warnings.warn(
@@ -2333,6 +2331,7 @@ def execute(self, context: Context):
return self.operation.operation.name

else:
# processing ends in execute_complete
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
@@ -2350,35 +2349,35 @@ def execute(self, context: Context):
# This is only likely to happen if batch_id was provided
# Could be running if Airflow was restarted after task started
# poll until a final state is reached
if self.batch_id:
self.log.info("Attaching to the job (%s) if it is still running.", self.batch_id)
result = hook.wait_for_batch(
batch_id=self.batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
wait_check_interval=self.polling_interval_seconds,

self.log.info("Attaching to the job %s if it is still running.", self.batch_id)

# deferrable handling of a batch_id that already exists - processing ends in execute_complete
if self.deferrable:
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
# It is possible we don't have a result in the case where batch_id was not provided, one was generated
# by chance, AlreadyExists was caught, but we can't reattach because we don't have the generated id
if result is None:
raise AirflowException("The job could not be reattached because the id was generated.")

# The existing batch may be in a number of states other than 'SUCCEEDED'.
# wait_for_operation doesn't fail if the job is cancelled, so we will check for it here, which also
# finds a cancelling|canceled|unspecified job from wait_for_batch
# non-deferrable handling of a batch_id that already exists
result = hook.wait_for_batch(
batch_id=self.batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
wait_check_interval=self.polling_interval_seconds,
)
batch_id = self.batch_id or result.name.split("/")[-1]
link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, resource=batch_id)
if result.state == Batch.State.FAILED:
raise AirflowException(f"Batch job {batch_id} failed. Driver Logs: {link}")
if result.state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
raise AirflowException(f"Batch job {batch_id} was cancelled. Driver logs: {link}")
if result.state == Batch.State.STATE_UNSPECIFIED:
raise AirflowException(f"Batch job {batch_id} unspecified. Driver logs: {link}")
self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)
DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
self.handle_batch_status(context, result.state, batch_id)
return Batch.to_dict(result)

def execute_complete(self, context, event=None) -> None:
@@ -2389,19 +2388,37 @@ def execute_complete(self, context, event=None) -> None:
"""
if event is None:
raise AirflowException("Batch failed.")
batch_state = event["batch_state"]
state = event["batch_state"]
batch_id = event["batch_id"]

if batch_state == Batch.State.FAILED:
raise AirflowException(f"Batch failed:\n{batch_id}")
if batch_state == Batch.State.CANCELLED:
raise AirflowException(f"Batch was cancelled:\n{batch_id}")
self.log.info("%s completed successfully.", self.task_id)
self.handle_batch_status(context, state, batch_id)

def on_kill(self):
if self.operation:
self.operation.cancel()

def handle_batch_status(self, context: Context, state: Batch.State, batch_id: str) -> None:
# The existing batch may be in a number of states other than 'SUCCEEDED'.
# wait_for_operation doesn't fail if the job is cancelled, so we will check for it here, which also
# finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, resource=batch_id)
if state == Batch.State.FAILED:
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
)
raise AirflowException("Batch job %s failed. Driver Logs: %s", batch_id, link)
if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
)
raise AirflowException("Batch job %s was cancelled. Driver logs: %s", batch_id, link)
if state == Batch.State.STATE_UNSPECIFIED:
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
)
raise AirflowException("Batch job %s unspecified. Driver logs: %s", batch_id, link)
self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)
DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)


class DataprocDeleteBatchOperator(GoogleCloudBaseOperator):
"""Delete the batch workload resource.
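Usage illustration (not part of the commit): a minimal DAG sketch of the scenario this change addresses, where DataprocCreateBatchOperator runs with deferrable=True and an explicit batch_id, so a retry or an Airflow restart that hits AlreadyExists reattaches to the existing batch instead of failing. The project, region, GCS path, and batch_id below are placeholder values.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

with DAG(
    dag_id="example_dataproc_batch_deferrable",
    start_date=datetime(2023, 6, 1),
    schedule=None,
    catchup=False,
) as dag:
    create_batch = DataprocCreateBatchOperator(
        task_id="create_batch",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        batch={
            # placeholder PySpark workload
            "pyspark_batch": {"main_python_file_uri": "gs://my-bucket/job.py"},
        },
        # With an explicit batch_id, a rerun can hit AlreadyExists; after this change
        # the deferrable operator defers on DataprocBatchTrigger and polls the existing
        # batch to a terminal state, then finishes in execute_complete.
        batch_id="my-batch-20230629",
        deferrable=True,
        polling_interval_seconds=10,
    )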