Skip to content

Commit

Permalink
Support google-cloud-tasks>=2.0.0 (#13347)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj committed Jan 14, 2021
1 parent 61b1ea3 commit ef8617e
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 138 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/google/ADDITIONAL_INFO.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ Details are covered in the UPDATING.md files for each library, but there are som
| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
| [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |

| [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |

### The field names use the snake_case convention

Expand Down
118 changes: 62 additions & 56 deletions airflow/providers/google/cloud/hooks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
which allows you to connect to Google Cloud Tasks service,
performing actions to queues or tasks.
"""

from typing import Dict, List, Optional, Sequence, Tuple, Union

from google.api_core.retry import Retry
from google.cloud.tasks_v2 import CloudTasksClient, enums
from google.cloud.tasks_v2.types import FieldMask, Queue, Task
from google.cloud.tasks_v2 import CloudTasksClient
from google.cloud.tasks_v2.types import Queue, Task
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
Expand Down Expand Up @@ -120,20 +122,19 @@ def create_queue(
client = self.get_conn()

if queue_name:
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
if isinstance(task_queue, Queue):
task_queue.name = full_queue_name
elif isinstance(task_queue, dict):
task_queue['name'] = full_queue_name
else:
raise AirflowException('Unable to set queue_name.')
full_location_path = CloudTasksClient.location_path(project_id, location)
full_location_path = f"projects/{project_id}/locations/{location}"
return client.create_queue(
parent=full_location_path,
queue=task_queue,
request={'parent': full_location_path, 'queue': task_queue},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -167,7 +168,7 @@ def update_queue(
:param update_mask: A mast used to specify which fields of the queue are being updated.
If empty, then all fields will be updated.
If a dict is provided, it must be of the same form as the protobuf message.
:type update_mask: dict or google.cloud.tasks_v2.types.FieldMask
:type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -182,19 +183,18 @@ def update_queue(
client = self.get_conn()

if queue_name and location:
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
if isinstance(task_queue, Queue):
task_queue.name = full_queue_name
elif isinstance(task_queue, dict):
task_queue['name'] = full_queue_name
else:
raise AirflowException('Unable to set queue_name.')
return client.update_queue(
queue=task_queue,
update_mask=update_mask,
request={'queue': task_queue, 'update_mask': update_mask},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -230,8 +230,10 @@ def get_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.get_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.get_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def list_queues(
Expand Down Expand Up @@ -270,14 +272,12 @@ def list_queues(
"""
client = self.get_conn()

full_location_path = CloudTasksClient.location_path(project_id, location)
full_location_path = f"projects/{project_id}/locations/{location}"
queues = client.list_queues(
parent=full_location_path,
filter_=results_filter,
page_size=page_size,
request={'parent': full_location_path, 'filter': results_filter, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return list(queues)

Expand Down Expand Up @@ -313,8 +313,10 @@ def delete_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
client.delete_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
client.delete_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def purge_queue(
Expand Down Expand Up @@ -349,8 +351,10 @@ def purge_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.purge_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.purge_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def pause_queue(
Expand Down Expand Up @@ -385,8 +389,10 @@ def pause_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.pause_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.pause_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def resume_queue(
Expand Down Expand Up @@ -421,8 +427,10 @@ def resume_queue(
"""
client = self.get_conn()

full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
return client.resume_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.resume_queue(
request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def create_task(
Expand All @@ -432,7 +440,7 @@ def create_task(
task: Union[Dict, Task],
project_id: str,
task_name: Optional[str] = None,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -455,7 +463,7 @@ def create_task(
:type task_name: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -470,21 +478,21 @@ def create_task(
client = self.get_conn()

if task_name:
full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
full_task_name = (
f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
)
if isinstance(task, Task):
task.name = full_task_name
elif isinstance(task, dict):
task['name'] = full_task_name
else:
raise AirflowException('Unable to set task_name.')
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
return client.create_task(
parent=full_queue_name,
task=task,
response_view=response_view,
request={'parent': full_queue_name, 'task': task, 'response_view': response_view},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand All @@ -494,7 +502,7 @@ def get_task(
queue_name: str,
task_name: str,
project_id: str,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -513,7 +521,7 @@ def get_task(
:type project_id: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -527,13 +535,12 @@ def get_task(
"""
client = self.get_conn()

full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
return client.get_task(
name=full_task_name,
response_view=response_view,
request={'name': full_task_name, 'response_view': response_view},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

@GoogleBaseHook.fallback_to_default_project_id
Expand All @@ -542,7 +549,7 @@ def list_tasks(
location: str,
queue_name: str,
project_id: str,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
page_size: Optional[int] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
Expand All @@ -560,7 +567,7 @@ def list_tasks(
:type project_id: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param page_size: (Optional) The maximum number of resources contained in the
underlying API response.
:type page_size: int
Expand All @@ -576,14 +583,12 @@ def list_tasks(
:rtype: list[google.cloud.tasks_v2.types.Task]
"""
client = self.get_conn()
full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
tasks = client.list_tasks(
parent=full_queue_name,
response_view=response_view,
page_size=page_size,
request={'parent': full_queue_name, 'response_view': response_view, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return list(tasks)

Expand Down Expand Up @@ -622,8 +627,10 @@ def delete_task(
"""
client = self.get_conn()

full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
client.delete_task(name=full_task_name, retry=retry, timeout=timeout, metadata=metadata)
full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
client.delete_task(
request={'name': full_task_name}, retry=retry, timeout=timeout, metadata=metadata or ()
)

@GoogleBaseHook.fallback_to_default_project_id
def run_task(
Expand All @@ -632,7 +639,7 @@ def run_task(
queue_name: str,
task_name: str,
project_id: str,
response_view: Optional[enums.Task.View] = None,
response_view: Optional = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -651,7 +658,7 @@ def run_task(
:type project_id: str
:param response_view: (Optional) This field specifies which subset of the Task will
be returned.
:type response_view: google.cloud.tasks_v2.enums.Task.View
:type response_view: google.cloud.tasks_v2.Task.View
:param retry: (Optional) A retry object used to retry requests.
If None is specified, requests will not be retried.
:type retry: google.api_core.retry.Retry
Expand All @@ -665,11 +672,10 @@ def run_task(
"""
client = self.get_conn()

full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
return client.run_task(
name=full_task_name,
response_view=response_view,
request={'name': full_task_name, 'response_view': response_view},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)

0 comments on commit ef8617e

Please sign in to comment.