Skip to content

Commit

Permalink
Fixes to dataproc operators and hook (#14086)
Browse files Browse the repository at this point in the history
Two quick fixes to Dataproc operators and hooks.

Add more templated fields to the DataprocClusterDeleteOperator
as per #13454. There were a few other fields which could easily be templated so I added them as well.

Don't use the global-dataproc.googleapis.com:443 URL when creating dataproc clients.
This was partially done in #12907 but the other two client creation methods were not updated. Using the global-dataproc URL results in 404s when trying to create clusters in the global region.

We don't need to specify the default endpoint as it is used by default in the dataproc client library:

https://github.com/googleapis/python-dataproc/blob/6f27109faf03dd13f25294e57960f0d9e1a9fa27/google/cloud/dataproc_v1beta2/services/cluster_controller/client.py#L117
  • Loading branch information
SamWheating committed Feb 10, 2021
1 parent 9036ce2 commit 1da6972
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/hooks/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ class DataprocHook(GoogleBaseHook):

def get_cluster_client(self, location: Optional[str] = None) -> ClusterControllerClient:
"""Returns ClusterControllerClient."""
client_options = {'api_endpoint': f'{location}-dataproc.googleapis.com:443'} if location else None
client_options = None
if location and location != 'global':
client_options = {'api_endpoint': f'{location}-dataproc.googleapis.com:443'}

return ClusterControllerClient(
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
Expand All @@ -227,7 +229,9 @@ def get_template_client(self, location: Optional[str] = None) -> WorkflowTemplat

def get_job_client(self, location: Optional[str] = None) -> JobControllerClient:
"""Returns JobControllerClient."""
client_options = {'api_endpoint': f'{location}-dataproc.googleapis.com:443'} if location else None
client_options = None
if location and location != 'global':
client_options = {'api_endpoint': f'{location}-dataproc.googleapis.com:443'}

return JobControllerClient(
credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,11 @@ class DataprocDeleteClusterOperator(BaseOperator):
"""
Deletes a cluster in a project.
:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
:param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle the request.
:param region: Required. The Cloud Dataproc region in which to handle the request (templated).
:type region: str
:param cluster_name: Required. The cluster name.
:param cluster_name: Required. The cluster name (templated).
:type cluster_name: str
:param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
if cluster with specified UUID does not exist.
Expand Down Expand Up @@ -801,7 +801,7 @@ class DataprocDeleteClusterOperator(BaseOperator):
:type impersonation_chain: Union[str, Sequence[str]]
"""

template_fields = ('impersonation_chain',)
template_fields = ('project_id', 'region', 'cluster_name', 'impersonation_chain')

@apply_defaults
def __init__(
Expand Down
26 changes: 24 additions & 2 deletions tests/providers/google/cloud/hooks/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,18 @@ def test_get_cluster_client(self, mock_client, mock_client_info, mock_get_creden
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
client_options={"api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"},
client_options=None,
)

@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
@mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
def test_get_cluster_client_region(self, mock_client, mock_client_info, mock_get_credentials):
self.hook.get_cluster_client(location='region1')
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
client_options={'api_endpoint': 'region1-dataproc.googleapis.com:443'},
)

@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
Expand Down Expand Up @@ -97,7 +108,18 @@ def test_get_job_client(self, mock_client, mock_client_info, mock_get_credential
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
client_options={"api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"},
client_options=None,
)

@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
@mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("JobControllerClient"))
def test_get_job_client_region(self, mock_client, mock_client_info, mock_get_credentials):
self.hook.get_job_client(location='region1')
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
client_options={'api_endpoint': 'region1-dataproc.googleapis.com:443'},
)

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
Expand Down

0 comments on commit 1da6972

Please sign in to comment.