Improve handling Dataproc cluster creation with ERROR state (#9593)
Handle cluster in DELETING state

Extend tests

turbaszek committed Aug 6, 2020
1 parent 1e36666 commit 0103226
Showing 5 changed files with 235 additions and 40 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -357,8 +357,8 @@ def diagnose_cluster(
metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
"""
Gets cluster diagnostic information. After the operation completes, the Operation.response field
contains ``DiagnoseClusterOutputLocation``.
Gets cluster diagnostic information. After the operation completes, the GCS URI
of the diagnostic output is returned
:param project_id: Required. The ID of the Google Cloud Platform project that the cluster belongs to.
:type project_id: str
@@ -376,15 +376,17 @@ def diagnose_cluster(
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_cluster_client(location=region)
result = client.diagnose_cluster(
operation = client.diagnose_cluster(
project_id=project_id,
region=region,
cluster_name=cluster_name,
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result
operation.result()
gcs_uri = str(operation.operation.response.value)
return gcs_uri

@GoogleBaseHook.fallback_to_default_project_id
def get_cluster(
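For illustration, a minimal sketch of calling the changed hook method: diagnose_cluster() now blocks on the long-running operation internally and returns the GCS URI of the diagnostic output as a string instead of the raw operation object. The connection id, project, region, and cluster name below are hypothetical placeholders, not part of the commit.

from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

# Hypothetical connection and cluster coordinates, for illustration only.
hook = DataprocHook(gcp_conn_id="google_cloud_default")
gcs_uri = hook.diagnose_cluster(
    project_id="my-project",
    region="europe-west1",
    cluster_name="my-cluster",
)
# The hook waits for the operation to complete and returns a plain string.
print(gcs_uri)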
167 changes: 133 additions & 34 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -31,8 +31,8 @@
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

from google.api_core.exceptions import AlreadyExists
from google.api_core.retry import Retry
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry, exponential_sleep_generator
from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
Cluster, Duration, FieldMask,
)
@@ -418,9 +418,14 @@ def make(self):
class DataprocCreateClusterOperator(BaseOperator):
"""
Create a new cluster on Google Cloud Dataproc. The operator will wait until the
creation is successful or an error occurs in the creation process.
creation is successful or an error occurs in the creation process. If the cluster
already exists and ``use_if_exists`` is True then the operator will:
The parameters allow to configure the cluster. Please refer to
- if cluster state is ERROR then delete it (when ``delete_on_error`` is True) and raise an error
- if cluster state is CREATING wait for it and then check for ERROR state
- if cluster state is DELETING wait for the deletion and then create a new cluster
Please refer to
https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.clusters
@@ -436,6 +441,11 @@ class DataprocCreateClusterOperator(BaseOperator):
:type project_id: str
:param region: leave as 'global', might become relevant in the future. (templated)
:type region: str
:param delete_on_error: If true, the cluster will be deleted if created with ERROR state. Default
value is true.
:type delete_on_error: bool
:param use_if_exists: If true, use an existing cluster instead of failing when one already exists
:type use_if_exists: bool
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -455,16 +465,21 @@ class DataprocCreateClusterOperator(BaseOperator):
template_fields = ('project_id', 'region', 'cluster')

@apply_defaults
def __init__(self, *,
region: str = 'global',
project_id: Optional[str] = None,
cluster: Optional[Dict] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
gcp_conn_id: str = "google_cloud_default",
**kwargs) -> None:
def __init__( # pylint: disable=too-many-arguments
self,
*,
region: str = 'global',
project_id: Optional[str] = None,
cluster: Optional[Dict] = None,
request_id: Optional[str] = None,
delete_on_error: bool = True,
use_if_exists: bool = True,
retry: Optional[Retry] = None,
timeout: float = 1 * 60 * 60,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
gcp_conn_id: str = "google_cloud_default",
**kwargs
) -> None:
# TODO: remove one day
if cluster is None:
warnings.warn(
@@ -492,40 +507,124 @@ def __init__(self, *,
super().__init__(**kwargs)

self.cluster = cluster
self.cluster_name = cluster.get('cluster_name')
try:
self.cluster_name = cluster['cluster_name']
except KeyError:
raise AirflowException("`config` misses `cluster_name` key")
self.project_id = project_id
self.region = region
self.request_id = request_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.delete_on_error = delete_on_error
self.use_if_exists = use_if_exists

def _create_cluster(self, hook):
operation = hook.create_cluster(
project_id=self.project_id,
region=self.region,
cluster=self.cluster,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
cluster = operation.result()
self.log.info("Cluster created.")
return cluster

def _delete_cluster(self, hook):
self.log.info("Deleting the cluster")
hook.delete_cluster(
region=self.region,
cluster_name=self.cluster_name,
project_id=self.project_id,
)

def _get_cluster(self, hook: DataprocHook):
return hook.get_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
if cluster.status.state != cluster.status.ERROR:
return
self.log.info("Cluster is in ERROR state")
gcs_uri = hook.diagnose_cluster(
region=self.region,
cluster_name=self.cluster_name,
project_id=self.project_id,
)
self.log.info(
'Diagnostic information for cluster %s available at: %s',
self.cluster_name, gcs_uri
)
if self.delete_on_error:
self._delete_cluster(hook)
raise AirflowException("Cluster was created but was in ERROR state.")
raise AirflowException("Cluster was created but is in ERROR state")

def _wait_for_cluster_in_deleting_state(self, hook: DataprocHook) -> None:
time_left = self.timeout
for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
if time_left < 0:
raise AirflowException(
f"Cluster {self.cluster_name} is still DELETING state, aborting"
)
time.sleep(time_to_sleep)
time_left = time_left - time_to_sleep
try:
self._get_cluster(hook)
except NotFound:
break

def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
time_left = self.timeout
cluster = self._get_cluster(hook)
for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
if cluster.status.state != cluster.status.CREATING:
break
if time_left < 0:
raise AirflowException(
f"Cluster {self.cluster_name} is still CREATING state, aborting"
)
time.sleep(time_to_sleep)
time_left = time_left - time_to_sleep
cluster = self._get_cluster(hook)
return cluster
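
Both wait helpers above follow the same polling pattern: draw sleep intervals from google-api-core's exponential_sleep_generator and keep a countdown so the total wait never exceeds the operator's timeout. A standalone sketch of that pattern follows; the predicate-based helper and its timeout default are illustrative, not part of the commit.

import time

from google.api_core.retry import exponential_sleep_generator

def poll_until(predicate, timeout: float = 1 * 60 * 60):
    """Illustrative helper: re-check ``predicate`` with growing delays."""
    time_left = timeout
    # Delays grow from `initial` toward `maximum` (exponential backoff).
    for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
        if predicate():
            return
        if time_left < 0:
            raise TimeoutError("Timed out while polling")
        time.sleep(time_to_sleep)
        time_left -= time_to_sleep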

def execute(self, context):
self.log.info('Creating cluster: %s', self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
try:
operation = hook.create_cluster(
project_id=self.project_id,
region=self.region,
cluster=self.cluster,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
cluster = operation.result()
self.log.info("Cluster created.")
# First try to create a new cluster
cluster = self._create_cluster(hook)
except AlreadyExists:
cluster = hook.get_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
if not self.use_if_exists:
raise
self.log.info("Cluster already exists.")
cluster = self._get_cluster(hook)

# Check whether the cluster is in ERROR state and handle it if so
self._handle_error_state(hook, cluster)
if cluster.status.state == cluster.status.CREATING:
# Wait for cluster to be created
cluster = self._wait_for_cluster_in_creating_state(hook)
self._handle_error_state(hook, cluster)
elif cluster.status.state == cluster.status.DELETING:
# Wait for cluster to be deleted
self._wait_for_cluster_in_deleting_state(hook)
# Create new cluster
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)

return MessageToDict(cluster)


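To show the two new flags in context, a hedged usage sketch; the DAG id, project, region, and cluster dict are placeholders, not part of the commit.

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator
from airflow.utils.dates import days_ago

# Placeholder cluster definition; `cluster_name` is required by the operator.
CLUSTER = {
    "project_id": "my-project",
    "cluster_name": "my-cluster",
    "config": {},
}

with DAG("example_dataproc_create", start_date=days_ago(1), schedule_interval=None) as dag:
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id="my-project",
        region="europe-west1",
        cluster=CLUSTER,
        use_if_exists=True,    # reuse the cluster if it already exists
        delete_on_error=True,  # delete and fail if it comes up in ERROR state
    )

With these defaults the task either returns the existing healthy cluster, waits out a CREATING or DELETING state, or deletes an ERROR cluster and fails.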
2 changes: 1 addition & 1 deletion setup.py
@@ -267,7 +267,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'google-cloud-bigtable>=1.0.0',
'google-cloud-container>=0.1.1,<2.0',
'google-cloud-datacatalog>=0.5.0,<0.8',
'google-cloud-dataproc>=0.5.0',
'google-cloud-dataproc>=1.0.1',
'google-cloud-dlp>=0.11.0',
'google-cloud-kms>=1.2.1,<2.0.0',
'google-cloud-language>=1.1.1',
1 change: 1 addition & 0 deletions tests/providers/google/cloud/hooks/test_dataproc.py
@@ -148,6 +148,7 @@ def test_diagnose_cluster(self, mock_client):
retry=None,
timeout=None,
)
mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with()

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
def test_get_cluster(self, mock_client):
95 changes: 94 additions & 1 deletion tests/providers/google/cloud/operators/test_dataproc.py
@@ -21,9 +21,10 @@
from typing import Any
from unittest import mock

from google.api_core.exceptions import AlreadyExists
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry

from airflow import AirflowException
from airflow.providers.google.cloud.operators.dataproc import (
ClusterGenerator, DataprocCreateClusterOperator, DataprocDeleteClusterOperator,
DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator,
@@ -225,6 +226,7 @@ def test_execute(self, mock_hook):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_if_cluster_exists(self, mock_hook):
mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
mock_hook.return_value.get_cluster.return_value.status.state = 0
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
Expand Down Expand Up @@ -256,6 +258,97 @@ def test_execute_if_cluster_exists(self, mock_hook):
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_if_cluster_exists_do_not_use(self, mock_hook):
mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
mock_hook.return_value.get_cluster.return_value.status.state = 0
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
cluster=CLUSTER,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
request_id=REQUEST_ID,
use_if_exists=False
)
with self.assertRaises(AlreadyExists):
op.execute(context={})

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_if_cluster_exists_in_error_state(self, mock_hook):
mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
cluster_status = mock_hook.return_value.get_cluster.return_value.status
cluster_status.state = 0
cluster_status.ERROR = 0

op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
cluster=CLUSTER,
delete_on_error=True,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
request_id=REQUEST_ID,
)
with self.assertRaises(AirflowException):
op.execute(context={})

mock_hook.return_value.diagnose_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
)
mock_hook.return_value.delete_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
)

@mock.patch(DATAPROC_PATH.format("exponential_sleep_generator"))
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._create_cluster"))
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_if_cluster_exists_in_deleting_state(
self, mock_hook, mock_get_cluster, mock_create_cluster, mock_generator
):
cluster = mock.MagicMock()
cluster.status.state = 0
cluster.status.DELETING = 0

cluster2 = mock.MagicMock()
cluster2.status.state = 0
cluster2.status.ERROR = 0

mock_create_cluster.side_effect = [AlreadyExists("test"), cluster2]
mock_generator.return_value = [0]
mock_get_cluster.side_effect = [cluster, NotFound("test")]

op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
cluster=CLUSTER,
delete_on_error=True,
gcp_conn_id=GCP_CONN_ID,
)
with self.assertRaises(AirflowException):
op.execute(context={})

calls = [mock.call(mock_hook.return_value), mock.call(mock_hook.return_value)]
mock_get_cluster.assert_has_calls(calls)
mock_create_cluster.assert_has_calls(calls)
mock_hook.return_value.diagnose_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
)


class TestDataprocClusterScaleOperator(unittest.TestCase):
def test_deprecation_warning(self):
