Skip to content

Commit

Permalink
Add regional support to dataproc workflow template operators (#12907)
Browse files Browse the repository at this point in the history
Workflow templates in GCP can be regional or global. For a regional
template, the GCP API endpoint (RPC URL) must be in the same region
as the template.

For global templates, 'global' must be passed as the region.
It is not used for the endpoint address, but it is needed as part
of the template path.

closes: #12804
  • Loading branch information
otourzan committed Dec 21, 2020
1 parent 97eee35 commit f95b1c9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
16 changes: 9 additions & 7 deletions airflow/providers/google/cloud/hooks/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import warnings
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from cached_property import cached_property
from google.api_core.exceptions import ServerError
from google.api_core.retry import Retry
from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module
Expand Down Expand Up @@ -218,11 +217,14 @@ def get_cluster_client(self, location: Optional[str] = None) -> ClusterControlle
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
)

@cached_property
def get_template_client(self) -> WorkflowTemplateServiceClient:
def get_template_client(self, location: Optional[str] = None) -> WorkflowTemplateServiceClient:
"""Returns WorkflowTemplateServiceClient."""
client_options = None
if location and location != 'global':
client_options = {'api_endpoint': f'{location}-dataproc.googleapis.com:443'}

return WorkflowTemplateServiceClient(
credentials=self._get_credentials(), client_info=self.client_info
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
)

def get_job_client(self, location: Optional[str] = None) -> JobControllerClient:
Expand Down Expand Up @@ -591,7 +593,7 @@ def create_workflow_template(
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_template_client
client = self.get_template_client(location)
parent = client.region_path(project_id, location)
return client.create_workflow_template(
parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata
Expand Down Expand Up @@ -641,7 +643,7 @@ def instantiate_workflow_template(
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_template_client
client = self.get_template_client(location)
name = client.workflow_template_path(project_id, location, template_name)
operation = client.instantiate_workflow_template(
name=name,
Expand Down Expand Up @@ -688,7 +690,7 @@ def instantiate_inline_workflow_template(
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_template_client
client = self.get_template_client(location)
parent = client.region_path(project_id, location)
operation = client.instantiate_inline_workflow_template(
parent=parent,
Expand Down
39 changes: 27 additions & 12 deletions tests/providers/google/cloud/hooks/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,23 @@ def test_get_cluster_client(self, mock_client, mock_client_info, mock_get_creden
@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
@mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
def test_get_template_client(self, mock_client, mock_client_info, mock_get_credentials):
_ = self.hook.get_template_client
def test_get_template_client_global(self, mock_client, mock_client_info, mock_get_credentials):
_ = self.hook.get_template_client()
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value, client_info=mock_client_info.return_value
credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
client_options=None,
)

@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
@mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
def test_get_template_client_region(self, mock_client, mock_client_info, mock_get_credentials):
_ = self.hook.get_template_client(location='region1')
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
client_options={'api_endpoint': 'region1-dataproc.googleapis.com:443'},
)

@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
Expand Down Expand Up @@ -193,34 +206,36 @@ def test_update_cluster(self, mock_client):
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_create_workflow_template(self, mock_client):
template = {"test": "test"}
mock_client.region_path.return_value = PARENT
mock_client.return_value.region_path.return_value = PARENT
self.hook.create_workflow_template(location=GCP_LOCATION, template=template, project_id=GCP_PROJECT)
mock_client.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
mock_client.create_workflow_template.assert_called_once_with(
mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
mock_client.return_value.create_workflow_template.assert_called_once_with(
parent=PARENT, template=template, retry=None, timeout=None, metadata=None
)

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_instantiate_workflow_template(self, mock_client):
template_name = "template_name"
mock_client.workflow_template_path.return_value = NAME
mock_client.return_value.workflow_template_path.return_value = NAME
self.hook.instantiate_workflow_template(
location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
)
mock_client.workflow_template_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION, template_name)
mock_client.instantiate_workflow_template.assert_called_once_with(
mock_client.return_value.workflow_template_path.assert_called_once_with(
GCP_PROJECT, GCP_LOCATION, template_name
)
mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
name=NAME, version=None, parameters=None, request_id=None, retry=None, timeout=None, metadata=None
)

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_instantiate_inline_workflow_template(self, mock_client):
template = {"test": "test"}
mock_client.region_path.return_value = PARENT
mock_client.return_value.region_path.return_value = PARENT
self.hook.instantiate_inline_workflow_template(
location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
)
mock_client.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
mock_client.instantiate_inline_workflow_template.assert_called_once_with(
mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
parent=PARENT, template=template, request_id=None, retry=None, timeout=None, metadata=None
)

Expand Down

0 comments on commit f95b1c9

Please sign in to comment.