
Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
turbaszek committed Jan 18, 2021
1 parent f74da50 commit 309788e
Showing 8 changed files with 157 additions and 144 deletions.
2 changes: 2 additions & 0 deletions airflow/providers/google/ADDITIONAL_INFO.md
@@ -32,11 +32,13 @@ Details are covered in the UPDATING.md files for each library, but there are some
| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
| [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
| [``google-cloud-kms``](https://pypi.org/project/google-cloud-kms/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
| [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
| [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |

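To check which major version of the client library an environment actually resolves to after the upgrade, a quick Python probe works (this snippet is an editorial sketch, not part of the provider package):

```python
import importlib.metadata  # Python 3.8+; use the importlib_metadata backport on older interpreters

# Prints the installed client version; expect a 2.x release after the upgrade.
print(importlib.metadata.version("google-cloud-dataproc"))
```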

### The field names use the snake_case convention

If your DAG uses an object from the above-mentioned libraries passed via XCom, you need to update the naming convention of the fields that are read. Previously the fields used the camelCase convention; now the snake_case convention is used.
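As a hedged illustration (the task id and field name below are placeholders chosen for the example, not taken from this commit), a callable that pulls a Dataproc job dict from XCom changes roughly like this:

```python
# Sketch only: assumes an upstream task with task_id="submit_job" pushed a
# Dataproc job serialized to a dict via XCom.
def read_driver_output(**context):
    job = context["ti"].xcom_pull(task_ids="submit_job")
    # google-cloud-dataproc 1.x (MessageToDict) produced camelCase keys:
    #     job["driverOutputResourceUri"]
    # google-cloud-dataproc 2.x serializes with snake_case keys:
    return job["driver_output_resource_uri"]
```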
104 changes: 46 additions & 58 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -26,18 +26,16 @@
from google.api_core.exceptions import ServerError
from google.api_core.retry import Retry
from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module
ClusterControllerClient,
JobControllerClient,
WorkflowTemplateServiceClient,
)
from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
Cluster,
Duration,
FieldMask,
ClusterControllerClient,
Job,
JobControllerClient,
JobStatus,
WorkflowTemplate,
WorkflowTemplateServiceClient,
)
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -291,10 +289,12 @@ def create_cluster(

client = self.get_cluster_client(location=region)
result = client.create_cluster(
project_id=project_id,
region=region,
cluster=cluster,
request_id=request_id,
request={
'project_id': project_id,
'region': region,
'cluster': cluster,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -340,11 +340,13 @@ def delete_cluster(
"""
client = self.get_cluster_client(location=region)
result = client.delete_cluster(
project_id=project_id,
region=region,
cluster_name=cluster_name,
cluster_uuid=cluster_uuid,
request_id=request_id,
request={
'project_id': project_id,
'region': region,
'cluster_name': cluster_name,
'cluster_uuid': cluster_uuid,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -382,9 +384,7 @@ def diagnose_cluster(
"""
client = self.get_cluster_client(location=region)
operation = client.diagnose_cluster(
project_id=project_id,
region=region,
cluster_name=cluster_name,
request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -423,9 +423,7 @@ def get_cluster(
"""
client = self.get_cluster_client(location=region)
result = client.get_cluster(
project_id=project_id,
region=region,
cluster_name=cluster_name,
request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -467,10 +465,7 @@ def list_clusters(
"""
client = self.get_cluster_client(location=region)
result = client.list_clusters(
project_id=project_id,
region=region,
filter_=filter_,
page_size=page_size,
request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -551,13 +546,15 @@ def update_cluster( # pylint: disable=too-many-arguments
"""
client = self.get_cluster_client(location=location)
operation = client.update_cluster(
project_id=project_id,
region=location,
cluster_name=cluster_name,
cluster=cluster,
update_mask=update_mask,
graceful_decommission_timeout=graceful_decommission_timeout,
request_id=request_id,
request={
'project_id': project_id,
'region': location,
'cluster_name': cluster_name,
'cluster': cluster,
'update_mask': update_mask,
'graceful_decommission_timeout': graceful_decommission_timeout,
'request_id': request_id,
},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -593,10 +590,11 @@ def create_workflow_template(
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
metadata = metadata or ()
client = self.get_template_client(location)
parent = client.region_path(project_id, location)
parent = f'projects/{project_id}/regions/{location}'
return client.create_workflow_template(
parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata
request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata
)

@GoogleBaseHook.fallback_to_default_project_id
@@ -643,13 +641,11 @@ def instantiate_workflow_template(
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
metadata = metadata or ()
client = self.get_template_client(location)
name = client.workflow_template_path(project_id, location, template_name)
name = f'projects/{project_id}/regions/{location}/workflowTemplates/{template_name}'
operation = client.instantiate_workflow_template(
name=name,
version=version,
parameters=parameters,
request_id=request_id,
request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -690,12 +686,11 @@ def instantiate_inline_workflow_template(
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
metadata = metadata or ()
client = self.get_template_client(location)
parent = client.region_path(project_id, location)
parent = f'projects/{project_id}/regions/{location}'
operation = client.instantiate_inline_workflow_template(
parent=parent,
template=template,
request_id=request_id,
request={'parent': parent, 'template': template, 'request_id': request_id},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -722,19 +717,19 @@ def wait_for_job(
"""
state = None
start = time.monotonic()
while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED):
while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
if timeout and start + timeout < time.monotonic():
raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
time.sleep(wait_time)
try:
job = self.get_job(location=location, job_id=job_id, project_id=project_id)
job = self.get_job(project_id=project_id, location=location, job_id=job_id)
state = job.status.state
except ServerError as err:
self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)

if state == JobStatus.ERROR:
if state == JobStatus.State.ERROR:
raise AirflowException(f'Job failed:\n{job}')
if state == JobStatus.CANCELLED:
if state == JobStatus.State.CANCELLED:
raise AirflowException(f'Job was cancelled:\n{job}')

@GoogleBaseHook.fallback_to_default_project_id
@@ -767,9 +762,7 @@ def get_job(
"""
client = self.get_job_client(location=location)
job = client.get_job(
project_id=project_id,
region=location,
job_id=job_id,
request={'project_id': project_id, 'region': location, 'job_id': job_id},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -812,10 +805,7 @@ def submit_job(
"""
client = self.get_job_client(location=location)
return client.submit_job(
project_id=project_id,
region=location,
job=job,
request_id=request_id,
request={'project_id': project_id, 'region': location, 'job': job, 'request_id': request_id},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -884,9 +874,7 @@ def cancel_job(
client = self.get_job_client(location=location)

job = client.cancel_job(
project_id=project_id,
region=location,
job_id=job_id,
request={'project_id': project_id, 'region': location, 'job_id': job_id},
retry=retry,
timeout=timeout,
metadata=metadata,
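The hook changes above all follow the same 1.x-to-2.x client pattern: flattened keyword arguments become a single `request` mapping, and resource paths are built as plain strings instead of client helper methods such as `region_path()`. A minimal sketch of the call-style difference (the project, region, and cluster values are placeholders):

```python
from google.cloud.dataproc_v1beta2 import ClusterControllerClient

client = ClusterControllerClient(
    client_options={"api_endpoint": "us-central1-dataproc.googleapis.com:443"}
)

# google-cloud-dataproc 1.x style (flattened keyword arguments):
#     client.get_cluster(project_id="my-project", region="us-central1", cluster_name="my-cluster")

# google-cloud-dataproc 2.x style (single request dict, as in the refactored hook methods above):
cluster = client.get_cluster(
    request={"project_id": "my-project", "region": "us-central1", "cluster_name": "my-cluster"}
)

# Resource paths are now formatted directly, e.g. f"projects/{project_id}/regions/{region}".
```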
30 changes: 13 additions & 17 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -17,7 +17,6 @@
# under the License.
#
"""This module contains Google Dataproc operators."""
# pylint: disable=C0302

import inspect
import ntpath
@@ -31,12 +30,9 @@

from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry, exponential_sleep_generator
from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
Cluster,
Duration,
FieldMask,
)
from google.protobuf.json_format import MessageToDict
from google.cloud.dataproc_v1beta2 import Cluster # pylint: disable=no-name-in-module
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -562,7 +558,7 @@ def _get_cluster(self, hook: DataprocHook) -> Cluster:
)

def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
if cluster.status.state != cluster.status.ERROR:
if cluster.status.state != cluster.status.State.ERROR:
return
self.log.info("Cluster is in ERROR state")
gcs_uri = hook.diagnose_cluster(
@@ -590,7 +586,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
time_left = self.timeout
cluster = self._get_cluster(hook)
for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
if cluster.status.state != cluster.status.CREATING:
if cluster.status.state != cluster.status.State.CREATING:
break
if time_left < 0:
raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting")
@@ -613,18 +609,18 @@ def execute(self, context) -> dict:

# Check if cluster is not in ERROR state
self._handle_error_state(hook, cluster)
if cluster.status.state == cluster.status.CREATING:
if cluster.status.state == cluster.status.State.CREATING:
# Wait for cluster to be created
cluster = self._wait_for_cluster_in_creating_state(hook)
self._handle_error_state(hook, cluster)
elif cluster.status.state == cluster.status.DELETING:
elif cluster.status.state == cluster.status.State.DELETING:
# Wait for cluster to be deleted
self._wait_for_cluster_in_deleting_state(hook)
# Create new cluster
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)

return MessageToDict(cluster)
return Cluster.to_dict(cluster)


class DataprocScaleClusterOperator(BaseOperator):
@@ -1855,7 +1851,7 @@ class DataprocSubmitJobOperator(BaseOperator):
:type wait_timeout: int
"""

template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
template_fields_renderers = {"job": "json"}

@apply_defaults
@@ -1941,14 +1937,14 @@ class DataprocUpdateClusterOperator(BaseOperator):
example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be
specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the
new value. If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.cloud.dataproc_v1beta2.types.FieldMask`
:type update_mask: Union[Dict, google.cloud.dataproc_v1beta2.types.FieldMask]
:class:`~google.protobuf.field_mask_pb2.FieldMask`
:type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout
specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and
potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum
allowed timeout is 1 day.
:type graceful_decommission_timeout: Union[Dict, google.cloud.dataproc_v1beta2.types.Duration]
:type graceful_decommission_timeout: Union[Dict, google.protobuf.duration_pb2.Duration]
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -1974,7 +1970,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
:type impersonation_chain: Union[str, Sequence[str]]
"""

template_fields = ('impersonation_chain',)
template_fields = ('impersonation_chain', 'cluster_name')

@apply_defaults
def __init__( # pylint: disable=too-many-arguments
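The operator changes reflect two proto-plus differences in google-cloud-dataproc 2.x: messages serialize through the message class itself instead of `google.protobuf.json_format.MessageToDict`, and status enums gain a nested `State` level. A hedged sketch (the cluster data is a placeholder):

```python
from google.cloud.dataproc_v1beta2 import Cluster, ClusterStatus

cluster = Cluster(cluster_name="example-cluster")  # placeholder message

# 1.x: MessageToDict(cluster) returned camelCase keys.
# 2.x: the proto-plus class serializes itself, with snake_case keys.
cluster_dict = Cluster.to_dict(cluster)

# Enum values moved under a nested State class, e.g. ClusterStatus.ERROR -> ClusterStatus.State.ERROR.
if cluster.status.state == ClusterStatus.State.ERROR:
    raise RuntimeError(f"Cluster is in ERROR state: {cluster_dict}")
```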
12 changes: 8 additions & 4 deletions airflow/providers/google/cloud/sensors/dataproc.py
@@ -65,14 +65,18 @@ def poke(self, context: dict) -> bool:
job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
state = job.status.state

if state == JobStatus.ERROR:
if state == JobStatus.State.ERROR:
raise AirflowException(f'Job failed:\n{job}')
elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
elif state in {
JobStatus.State.CANCELLED,
JobStatus.State.CANCEL_PENDING,
JobStatus.State.CANCEL_STARTED,
}:
raise AirflowException(f'Job was cancelled:\n{job}')
elif JobStatus.DONE == state:
elif JobStatus.State.DONE == state:
self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
return True
elif JobStatus.ATTEMPT_FAILURE == state:
elif JobStatus.State.ATTEMPT_FAILURE == state:
self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
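User code that inspects job state needs the same extra nesting level; a minimal sketch (the helper function is an assumption for illustration, not part of the provider):

```python
from google.cloud.dataproc_v1beta2 import JobStatus

def job_is_done(job) -> bool:
    # 1.x compared against JobStatus.DONE; 2.x nests the enum under JobStatus.State.
    return job.status.state == JobStatus.State.DONE
```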
2 changes: 1 addition & 1 deletion setup.py
@@ -286,7 +286,7 @@ def get_sphinx_theme_version() -> str:
'google-cloud-bigtable>=1.0.0,<2.0.0',
'google-cloud-container>=0.1.1,<2.0.0',
'google-cloud-datacatalog>=3.0.0,<4.0.0',
'google-cloud-dataproc>=1.0.1,<2.0.0',
'google-cloud-dataproc>=2.2.0,<3.0.0',
'google-cloud-dlp>=0.11.0,<2.0.0',
'google-cloud-kms>=2.0.0,<3.0.0',
'google-cloud-language>=1.1.1,<2.0.0',
