Skip to content

Commit

Permalink
[AIRFLOW-6926] Fix Google Tasks operators return types and idempotency (
Browse files Browse the repository at this point in the history
  • Loading branch information
turbaszek committed Mar 3, 2020
1 parent 0d1e308 commit 8230ccc
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
timestamp = timestamp_pb2.Timestamp()
timestamp.FromDatetime(datetime.now() + timedelta(hours=12)) # pylint: disable=no-member

LOCATION = "asia-east2"
LOCATION = "europe-west1"
QUEUE_ID = "cloud-tasks-queue"
TASK_NAME = "task-to-run"

Expand Down
62 changes: 43 additions & 19 deletions airflow/providers/google/cloud/operators/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
"""
from typing import Dict, Optional, Sequence, Tuple, Union

from google.api_core.exceptions import AlreadyExists
from google.api_core.retry import Retry
from google.cloud.tasks_v2 import enums
from google.cloud.tasks_v2.types import FieldMask, Queue, Task
from google.protobuf.json_format import MessageToDict

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook
Expand Down Expand Up @@ -98,15 +100,27 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.create_queue(
location=self.location,
task_queue=self.task_queue,
project_id=self.project_id,
queue_name=self.queue_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
try:
queue = hook.create_queue(
location=self.location,
task_queue=self.task_queue,
project_id=self.project_id,
queue_name=self.queue_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
except AlreadyExists:
queue = hook.get_queue(
location=self.location,
project_id=self.project_id,
queue_name=self.queue_name,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

return MessageToDict(queue)


class CloudTasksQueueUpdateOperator(BaseOperator):
Expand Down Expand Up @@ -181,7 +195,7 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.update_queue(
queue = hook.update_queue(
task_queue=self.task_queue,
project_id=self.project_id,
location=self.location,
Expand All @@ -191,6 +205,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(queue)


class CloudTasksQueueGetOperator(BaseOperator):
Expand Down Expand Up @@ -244,14 +259,15 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.get_queue(
queue = hook.get_queue(
location=self.location,
queue_name=self.queue_name,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(queue)


class CloudTasksQueuesListOperator(BaseOperator):
Expand Down Expand Up @@ -311,7 +327,7 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.list_queues(
queues = hook.list_queues(
location=self.location,
project_id=self.project_id,
results_filter=self.results_filter,
Expand All @@ -320,6 +336,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return [MessageToDict(q) for q in queues]


class CloudTasksQueueDeleteOperator(BaseOperator):
Expand Down Expand Up @@ -433,14 +450,15 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.purge_queue(
queue = hook.purge_queue(
location=self.location,
queue_name=self.queue_name,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(queue)


class CloudTasksQueuePauseOperator(BaseOperator):
Expand Down Expand Up @@ -494,14 +512,15 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.pause_queue(
queues = hook.pause_queue(
location=self.location,
queue_name=self.queue_name,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
return [MessageToDict(q) for q in queues]


class CloudTasksQueueResumeOperator(BaseOperator):
Expand Down Expand Up @@ -555,14 +574,15 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.resume_queue(
queue = hook.resume_queue(
location=self.location,
queue_name=self.queue_name,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(queue)


class CloudTasksTaskCreateOperator(BaseOperator):
Expand Down Expand Up @@ -638,7 +658,7 @@ def __init__( # pylint: disable=too-many-arguments

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.create_task(
task = hook.create_task(
location=self.location,
queue_name=self.queue_name,
task=self.task,
Expand All @@ -649,6 +669,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(task)


class CloudTasksTaskGetOperator(BaseOperator):
Expand Down Expand Up @@ -717,7 +738,7 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.get_task(
task = hook.get_task(
location=self.location,
queue_name=self.queue_name,
task_name=self.task_name,
Expand All @@ -727,6 +748,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(task)


class CloudTasksTasksListOperator(BaseOperator):
Expand Down Expand Up @@ -790,7 +812,7 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.list_tasks(
tasks = hook.list_tasks(
location=self.location,
queue_name=self.queue_name,
project_id=self.project_id,
Expand All @@ -800,6 +822,7 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return [MessageToDict(t) for t in tasks]


class CloudTasksTaskDeleteOperator(BaseOperator):
Expand Down Expand Up @@ -939,7 +962,7 @@ def __init__(

def execute(self, context):
hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
return hook.run_task(
task = hook.run_task(
location=self.location,
queue_name=self.queue_name,
task_name=self.task_name,
Expand All @@ -949,3 +972,4 @@ def execute(self, context):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(task)
26 changes: 13 additions & 13 deletions tests/providers/google/cloud/operators/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
class TestCloudTasksQueueCreate(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_create_queue(self, mock_hook):
mock_hook.return_value.create_queue.return_value = {}
mock_hook.return_value.create_queue.return_value = mock.MagicMock()
operator = CloudTasksQueueCreateOperator(
location=LOCATION, task_queue=Queue(), task_id="id"
)
Expand All @@ -64,7 +64,7 @@ def test_create_queue(self, mock_hook):
class TestCloudTasksQueueUpdate(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_update_queue(self, mock_hook):
mock_hook.return_value.update_queue.return_value = {}
mock_hook.return_value.update_queue.return_value = mock.MagicMock()
operator = CloudTasksQueueUpdateOperator(
task_queue=Queue(name=FULL_QUEUE_PATH), task_id="id"
)
Expand All @@ -85,7 +85,7 @@ def test_update_queue(self, mock_hook):
class TestCloudTasksQueueGet(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_get_queue(self, mock_hook):
mock_hook.return_value.get_queue.return_value = {}
mock_hook.return_value.get_queue.return_value = mock.MagicMock()
operator = CloudTasksQueueGetOperator(
location=LOCATION, queue_name=QUEUE_ID, task_id="id"
)
Expand All @@ -104,7 +104,7 @@ def test_get_queue(self, mock_hook):
class TestCloudTasksQueuesList(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_list_queues(self, mock_hook):
mock_hook.return_value.list_queues.return_value = {}
mock_hook.return_value.list_queues.return_value = mock.MagicMock()
operator = CloudTasksQueuesListOperator(location=LOCATION, task_id="id")
operator.execute(context=None)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID)
Expand All @@ -122,7 +122,7 @@ def test_list_queues(self, mock_hook):
class TestCloudTasksQueueDelete(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_delete_queue(self, mock_hook):
mock_hook.return_value.delete_queue.return_value = {}
mock_hook.return_value.delete_queue.return_value = mock.MagicMock()
operator = CloudTasksQueueDeleteOperator(
location=LOCATION, queue_name=QUEUE_ID, task_id="id"
)
Expand All @@ -141,7 +141,7 @@ def test_delete_queue(self, mock_hook):
class TestCloudTasksQueuePurge(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_delete_queue(self, mock_hook):
mock_hook.return_value.purge_queue.return_value = {}
mock_hook.return_value.purge_queue.return_value = mock.MagicMock()
operator = CloudTasksQueuePurgeOperator(
location=LOCATION, queue_name=QUEUE_ID, task_id="id"
)
Expand All @@ -160,7 +160,7 @@ def test_delete_queue(self, mock_hook):
class TestCloudTasksQueuePause(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_pause_queue(self, mock_hook):
mock_hook.return_value.pause_queue.return_value = {}
mock_hook.return_value.pause_queue.return_value = mock.MagicMock()
operator = CloudTasksQueuePauseOperator(
location=LOCATION, queue_name=QUEUE_ID, task_id="id"
)
Expand All @@ -179,7 +179,7 @@ def test_pause_queue(self, mock_hook):
class TestCloudTasksQueueResume(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_resume_queue(self, mock_hook):
mock_hook.return_value.resume_queue.return_value = {}
mock_hook.return_value.resume_queue.return_value = mock.MagicMock()
operator = CloudTasksQueueResumeOperator(
location=LOCATION, queue_name=QUEUE_ID, task_id="id"
)
Expand All @@ -198,7 +198,7 @@ def test_resume_queue(self, mock_hook):
class TestCloudTasksTaskCreate(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_create_task(self, mock_hook):
mock_hook.return_value.create_task.return_value = {}
mock_hook.return_value.create_task.return_value = mock.MagicMock()
operator = CloudTasksTaskCreateOperator(
location=LOCATION, queue_name=QUEUE_ID, task=Task(), task_id="id"
)
Expand All @@ -220,7 +220,7 @@ def test_create_task(self, mock_hook):
class TestCloudTasksTaskGet(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_get_task(self, mock_hook):
mock_hook.return_value.get_task.return_value = {}
mock_hook.return_value.get_task.return_value = mock.MagicMock()
operator = CloudTasksTaskGetOperator(
location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id"
)
Expand All @@ -241,7 +241,7 @@ def test_get_task(self, mock_hook):
class TestCloudTasksTasksList(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_list_tasks(self, mock_hook):
mock_hook.return_value.list_tasks.return_value = {}
mock_hook.return_value.list_tasks.return_value = mock.MagicMock()
operator = CloudTasksTasksListOperator(
location=LOCATION, queue_name=QUEUE_ID, task_id="id"
)
Expand All @@ -262,7 +262,7 @@ def test_list_tasks(self, mock_hook):
class TestCloudTasksTaskDelete(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_delete_task(self, mock_hook):
mock_hook.return_value.delete_task.return_value = {}
mock_hook.return_value.delete_task.return_value = mock.MagicMock()
operator = CloudTasksTaskDeleteOperator(
location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id"
)
Expand All @@ -282,7 +282,7 @@ def test_delete_task(self, mock_hook):
class TestCloudTasksTaskRun(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
def test_run_task(self, mock_hook):
mock_hook.return_value.run_task.return_value = {}
mock_hook.return_value.run_task.return_value = mock.MagicMock()
operator = CloudTasksTaskRunOperator(
location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id"
)
Expand Down

0 comments on commit 8230ccc

Please sign in to comment.