Add operator to diagnose cluster (#36899)
1. Make the diagnose_cluster hook return an Operation object, just like the rest of the hooks.
2. Rename DataprocWorkflowTrigger to DataprocOperationTrigger so it handles all types of
   operations for get_operation.
flacode committed Jan 25, 2024
1 parent 67c774e commit 10ad8d9
Showing 12 changed files with 565 additions and 49 deletions.
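The headline change is the new DataprocDiagnoseClusterOperator. Below is a minimal DAG sketch of how the operator might be wired up, assuming the Google provider is installed and a Dataproc cluster already exists; the DAG id, project, region, and cluster name are placeholders, not values from this commit.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocDiagnoseClusterOperator

with DAG(
    dag_id="example_dataproc_diagnose_cluster",  # placeholder DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # Collect diagnostic output for an existing cluster; on completion the operator
    # logs the Cloud Storage URI of the diagnostic report.
    diagnose_cluster = DataprocDiagnoseClusterOperator(
        task_id="diagnose_cluster",
        project_id="my-project",      # placeholder
        region="europe-west1",        # placeholder
        cluster_name="my-cluster",    # placeholder
        deferrable=True,              # hand the polling off to the triggerer
    )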
58 changes: 44 additions & 14 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -20,6 +20,7 @@

import time
import uuid
from collections.abc import MutableSequence
from typing import TYPE_CHECKING, Any, Sequence

from google.api_core.client_options import ClientOptions
@@ -54,6 +55,7 @@
from google.api_core.retry_async import AsyncRetry
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask
from google.type.interval_pb2 import Interval


class DataProcJobBuilder:
@@ -386,17 +388,25 @@ def diagnose_cluster(
region: str,
cluster_name: str,
project_id: str,
tarball_gcs_dir: str | None = None,
diagnosis_interval: dict | Interval | None = None,
jobs: MutableSequence[str] | None = None,
yarn_application_ids: MutableSequence[str] | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> str:
) -> Operation:
"""Get cluster diagnostic information.
After the operation completes, the GCS URI to diagnose is returned.
After the operation completes, the response contains the Cloud Storage URI of the diagnostic output report containing a summary of collected diagnostics.
:param project_id: Google Cloud project ID that the cluster belongs to.
:param region: Cloud Dataproc region in which to handle the request.
:param cluster_name: Name of the cluster.
:param tarball_gcs_dir: The output Cloud Storage directory for the diagnostic tarball. If not specified, a task-specific directory in the cluster's staging bucket will be used.
:param diagnosis_interval: Time interval in which diagnosis should be carried out on the cluster.
:param jobs: Specifies a list of jobs on which diagnosis is to be performed. Format: `projects/{project}/regions/{region}/jobs/{job}`
:param yarn_application_ids: Specifies a list of yarn applications on which diagnosis is to be performed.
:param retry: A retry object used to retry requests. If *None*, requests
will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request
@@ -405,15 +415,21 @@ def diagnose_cluster(
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_cluster_client(region=region)
operation = client.diagnose_cluster(
request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
result = client.diagnose_cluster(
request={
"project_id": project_id,
"region": region,
"cluster_name": cluster_name,
"tarball_gcs_dir": tarball_gcs_dir,
"diagnosis_interval": diagnosis_interval,
"jobs": jobs,
"yarn_application_ids": yarn_application_ids,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
operation.result()
gcs_uri = str(operation.operation.response.value)
return gcs_uri
return result

@GoogleBaseHook.fallback_to_default_project_id
def get_cluster(
@@ -1243,17 +1259,25 @@ async def diagnose_cluster(
region: str,
cluster_name: str,
project_id: str,
tarball_gcs_dir: str | None = None,
diagnosis_interval: dict | Interval | None = None,
jobs: MutableSequence[str] | None = None,
yarn_application_ids: MutableSequence[str] | None = None,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> str:
) -> AsyncOperation:
"""Get cluster diagnostic information.
After the operation completes, the GCS URI to diagnose is returned.
After the operation completes, the response contains the Cloud Storage URI of the diagnostic output report containing a summary of collected diagnostics.
:param project_id: Google Cloud project ID that the cluster belongs to.
:param region: Cloud Dataproc region in which to handle the request.
:param cluster_name: Name of the cluster.
:param tarball_gcs_dir: The output Cloud Storage directory for the diagnostic tarball. If not specified, a task-specific directory in the cluster's staging bucket will be used.
:param diagnosis_interval: Time interval in which diagnosis should be carried out on the cluster.
:param jobs: Specifies a list of jobs on which diagnosis is to be performed. Format: `projects/{project}/regions/{region}/jobs/{job}`
:param yarn_application_ids: Specifies a list of yarn applications on which diagnosis is to be performed.
:param retry: A retry object used to retry requests. If *None*, requests
will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request
@@ -1262,15 +1286,21 @@ async def diagnose_cluster(
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_cluster_client(region=region)
operation = await client.diagnose_cluster(
request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
result = await client.diagnose_cluster(
request={
"project_id": project_id,
"region": region,
"cluster_name": cluster_name,
"tarball_gcs_dir": tarball_gcs_dir,
"diagnosis_interval": diagnosis_interval,
"jobs": jobs,
"yarn_application_ids": yarn_application_ids,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
operation.result()
gcs_uri = str(operation.operation.response.value)
return gcs_uri
return result

@GoogleBaseHook.fallback_to_default_project_id
async def get_cluster(
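With this change, callers of the hook receive the Operation itself and resolve the diagnostic URI on their side, as the cluster-create operator now does in _handle_error_state below. A minimal consumption sketch, using placeholder identifiers:

from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook(gcp_conn_id="google_cloud_default")

# diagnose_cluster now returns a google.api_core Operation rather than a GCS URI string.
operation = hook.diagnose_cluster(
    project_id="my-project",    # placeholder
    region="europe-west1",      # placeholder
    cluster_name="my-cluster",  # placeholder
)

# Block until the long-running operation finishes, then read the output URI
# from the operation response, mirroring the code removed from the hook.
operation.result()
gcs_uri = str(operation.operation.response.value)
print(f"Diagnostic output available at: {gcs_uri}")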
150 changes: 146 additions & 4 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -25,6 +25,7 @@
import time
import uuid
import warnings
from collections.abc import MutableSequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
@@ -56,16 +57,18 @@
DataprocBatchTrigger,
DataprocClusterTrigger,
DataprocDeleteClusterTrigger,
DataprocOperationTrigger,
DataprocSubmitTrigger,
DataprocWorkflowTrigger,
)
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.utils import timezone

if TYPE_CHECKING:
from google.api_core import operation
from google.api_core.retry_async import AsyncRetry
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask
from google.type.interval_pb2 import Interval

from airflow.utils.context import Context

@@ -681,10 +684,13 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
return
self.log.info("Cluster is in ERROR state")
self.log.info("Gathering diagnostic information.")
gcs_uri = hook.diagnose_cluster(
operation = hook.diagnose_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
operation.result()
gcs_uri = str(operation.operation.response.value)
self.log.info("Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri)

if self.delete_on_error:
self._delete_cluster(hook)
# The delete op is asynchronous and can cause further failure if the cluster finishes
@@ -2054,7 +2060,7 @@ def execute(self, context: Context):
self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
trigger=DataprocOperationTrigger(
name=operation_name,
project_id=self.project_id,
region=self.region,
@@ -2196,7 +2202,7 @@ def execute(self, context: Context):
self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
trigger=DataprocOperationTrigger(
name=operation_name,
project_id=self.project_id or hook.project_id,
region=self.region,
@@ -2530,6 +2536,142 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
self.log.info("%s completed successfully.", self.task_id)


class DataprocDiagnoseClusterOperator(GoogleCloudBaseOperator):
"""Diagnose a cluster in a project.
After the operation completes, the response contains the Cloud Storage URI of the diagnostic output report containing a summary of collected diagnostics.
:param region: Required. The Cloud Dataproc region in which to handle the request (templated).
:param project_id: Optional. The ID of the Google Cloud project that the cluster belongs to (templated).
:param cluster_name: Required. The cluster name (templated).
:param tarball_gcs_dir: The output Cloud Storage directory for the diagnostic tarball. If not specified, a task-specific directory in the cluster's staging bucket will be used.
:param diagnosis_interval: Time interval in which diagnosis should be carried out on the cluster.
:param jobs: Specifies a list of jobs on which diagnosis is to be performed. Format: `projects/{project}/regions/{region}/jobs/{job}`
:param yarn_application_ids: Specifies a list of yarn applications on which diagnosis is to be performed.
:param metadata: Additional metadata that is provided to the method.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the cluster status.
"""

template_fields: Sequence[str] = (
"project_id",
"region",
"cluster_name",
"impersonation_chain",
"tarball_gcs_dir",
"diagnosis_interval",
"jobs",
"yarn_application_ids",
)

def __init__(
self,
*,
region: str,
cluster_name: str,
project_id: str | None = None,
tarball_gcs_dir: str | None = None,
diagnosis_interval: dict | Interval | None = None,
jobs: MutableSequence[str] | None = None,
yarn_application_ids: MutableSequence[str] | None = None,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float = 1 * 60 * 60,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
**kwargs,
):
super().__init__(**kwargs)
if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.project_id = project_id
self.region = region
self.cluster_name = cluster_name
self.tarball_gcs_dir = tarball_gcs_dir
self.diagnosis_interval = diagnosis_interval
self.jobs = jobs
self.yarn_application_ids = yarn_application_ids
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
self.log.info("Collecting diagnostic tarball for cluster: %s", self.cluster_name)
operation = hook.diagnose_cluster(
region=self.region,
cluster_name=self.cluster_name,
project_id=self.project_id,
tarball_gcs_dir=self.tarball_gcs_dir,
diagnosis_interval=self.diagnosis_interval,
jobs=self.jobs,
yarn_application_ids=self.yarn_application_ids,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

if not self.deferrable:
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.retry, operation=operation
)
self.log.info(
"The diagnostic output for cluster %s is available at: %s",
self.cluster_name,
result.output_uri,
)
else:
self.defer(
trigger=DataprocOperationTrigger(
name=operation.operation.name,
operation_type=DataprocOperationType.DIAGNOSE.value,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
"""Callback for when the trigger fires.
This returns immediately. It relies on the trigger to throw an exception;
otherwise it assumes execution was successful.
"""
if event:
status = event.get("status")
if status in ("failed", "error"):
self.log.exception("Unexpected error in the operation.")
raise AirflowException(event.get("message"))

self.log.info(
"The diagnostic output for cluster %s is available at: %s",
self.cluster_name,
event.get("output_uri"),
)


class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
"""Create a batch workload.
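For the deferrable path, the new operator's execute_complete consumes the event emitted by DataprocOperationTrigger. The payload shapes below are inferred from that handler (a status, an optional message, and the output_uri of the diagnostic report); they are illustrative only, since the trigger implementation is not shown in this diff.

# Example of an event execute_complete would treat as success:
success_event = {
    "status": "success",
    "output_uri": "gs://my-staging-bucket/diagnostic.tar.gz",  # placeholder URI
}

# Events with status "failed" or "error" make execute_complete raise AirflowException:
failure_event = {
    "status": "error",
    "message": "Diagnose cluster operation failed.",
}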
