
[Airflow 13779] use provided parameters in the wait_for_pipeline_state hook (#17137)

I removed the wait_for_pipeline_state call from the start_pipeline hook. That call introduced a bug in this operator: when a pipeline takes more than 300 seconds to start, it still has a starting status, and we get an error because the pipeline is not in the expected state after 300 seconds. The error occurred even when the success_states and pipeline_timeout parameters were passed, so when both parameters are provided the logic should use them rather than the defaults. The defaults are 300 seconds and SUCCESS_STATES + [PipelineStates.RUNNING] because those were the values hardcoded in the wait_for_pipeline_state call I removed from the hook; whether to replace the 300 seconds with the default value from the __init__ method is, I think, still an open question.
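
For illustration, this is how a DAG task might pass both parameters after this change so that they are actually honored; the task id, pipeline name, instance name, and location below are hypothetical placeholders:

    from airflow.providers.google.cloud.hooks.datafusion import PipelineStates
    from airflow.providers.google.cloud.operators.datafusion import (
        CloudDataFusionStartPipelineOperator,
    )

    # Both parameters are passed explicitly, so the operator should use them
    # instead of the 300-second default and
    # SUCCESS_STATES + [PipelineStates.RUNNING].
    start_pipeline = CloudDataFusionStartPipelineOperator(
        task_id="start_pipeline",
        pipeline_name="example_pipeline",  # placeholder
        instance_name="example-instance",  # placeholder
        location="europe-west1",           # placeholder
        success_states=[PipelineStates.RUNNING],
        pipeline_timeout=600,  # allow a pipeline that needs more than 300s to start
    )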
lwyszomi committed Aug 20, 2021
1 parent e21b54a commit d04aa13
Showing 4 changed files with 31 additions and 32 deletions.
airflow/providers/google/cloud/hooks/datafusion.py (10 changes: 1 addition & 9 deletions)
@@ -464,15 +464,7 @@ def start_pipeline(
             raise AirflowException(f"Starting a pipeline failed with code {response.status}")
 
         response_json = json.loads(response.data)
-        pipeline_id = response_json[0]["runId"]
-        self.wait_for_pipeline_state(
-            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
-            pipeline_name=pipeline_name,
-            pipeline_id=pipeline_id,
-            namespace=namespace,
-            instance_url=instance_url,
-        )
-        return pipeline_id
+        return response_json[0]["runId"]
 
     def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = "default") -> None:
         """
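Since start_pipeline no longer waits, a caller that needs to block until the pipeline reaches a given state must now call wait_for_pipeline_state itself. A minimal sketch under that assumption, with a hypothetical instance URL and pipeline name:

    from airflow.providers.google.cloud.hooks.datafusion import (
        SUCCESS_STATES,
        DataFusionHook,
        PipelineStates,
    )

    hook = DataFusionHook(gcp_conn_id="google_cloud_default")
    instance_url = "https://example-instance.datafusion.example.com/api"  # hypothetical

    # start_pipeline now returns the run ID immediately, without waiting.
    pipeline_id = hook.start_pipeline(
        pipeline_name="example_pipeline",  # hypothetical
        instance_url=instance_url,
        namespace="default",
    )

    # Waiting is the caller's responsibility, with caller-chosen states and timeout.
    hook.wait_for_pipeline_state(
        success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
        pipeline_name="example_pipeline",
        pipeline_id=pipeline_id,
        namespace="default",
        instance_url=instance_url,
        timeout=600,
    )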
airflow/providers/google/cloud/operators/datafusion.py (29 changes: 16 additions & 13 deletions)
@@ -23,7 +23,7 @@
 from googleapiclient.errors import HttpError
 
 from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
+from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates
 
 
 class CloudDataFusionRestartInstanceOperator(BaseOperator):
@@ -808,9 +808,7 @@ def __init__(
     ) -> None:
         super().__init__(**kwargs)
         self.pipeline_name = pipeline_name
-        self.success_states = success_states
         self.runtime_args = runtime_args
-        self.pipeline_timeout = pipeline_timeout
         self.namespace = namespace
         self.instance_name = instance_name
         self.location = location
@@ -820,6 +818,13 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
 
+        if success_states:
+            self.success_states = success_states
+            self.pipeline_timeout = pipeline_timeout
+        else:
+            self.success_states = SUCCESS_STATES + [PipelineStates.RUNNING]
+            self.pipeline_timeout = 5 * 60
+
     def execute(self, context: dict) -> None:
         hook = DataFusionHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -840,17 +845,15 @@ def execute(self, context: dict) -> None:
             namespace=self.namespace,
             runtime_args=self.runtime_args,
         )
-
-        hook.wait_for_pipeline_state(
-            success_states=self.success_states,
-            pipeline_id=pipeline_id,
-            pipeline_name=self.pipeline_name,
-            namespace=self.namespace,
-            instance_url=api_url,
-            timeout=self.pipeline_timeout,
-        )
         self.log.info("Pipeline started")
+        if self.success_states:
+            hook.wait_for_pipeline_state(
+                success_states=self.success_states,
+                pipeline_id=pipeline_id,
+                pipeline_name=self.pipeline_name,
+                namespace=self.namespace,
+                instance_url=api_url,
+                timeout=self.pipeline_timeout,
+            )
 
 
 class CloudDataFusionStopPipelineOperator(BaseOperator):
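For contrast with the explicit example above, constructing the operator without success_states exercises the fallback branch added to __init__; the names are again placeholders:

    # No success_states given: __init__ falls back to
    # SUCCESS_STATES + [PipelineStates.RUNNING] with a 5 * 60 second timeout,
    # so execute() still waits, exactly as the updated test below asserts.
    start_pipeline_default = CloudDataFusionStartPipelineOperator(
        task_id="start_pipeline_default",
        pipeline_name="example_pipeline",  # placeholder
        instance_name="example-instance",  # placeholder
        location="europe-west1",           # placeholder
    )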
tests/providers/google/cloud/hooks/test_datafusion.py (12 changes: 2 additions & 10 deletions)
@@ -20,7 +20,7 @@
 
 import pytest
 
-from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates
+from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
 
 API_VERSION = "v1beta1"
@@ -180,8 +180,7 @@ def test_list_pipelines(self, mock_request, hook):
         assert result == data
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
-    @mock.patch(HOOK_STR.format("DataFusionHook.wait_for_pipeline_state"))
-    def test_start_pipeline(self, mock_wait_for_pipeline_state, mock_request, hook):
+    def test_start_pipeline(self, mock_request, hook):
         run_id = 1234
         mock_request.return_value = mock.MagicMock(status=200, data=f'[{{"runId":{run_id}}}]')
 
@@ -197,13 +196,6 @@ def test_start_pipeline(self, mock_request, hook):
         mock_request.assert_called_once_with(
             url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
         )
-        mock_wait_for_pipeline_state.assert_called_once_with(
-            instance_url=INSTANCE_URL,
-            namespace="default",
-            pipeline_name=PIPELINE_NAME,
-            pipeline_id=run_id,
-            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
-        )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_stop_pipeline(self, mock_request, hook):
tests/providers/google/cloud/operators/test_datafusion.py (12 changes: 12 additions & 0 deletions)
@@ -18,6 +18,7 @@
 from unittest import mock
 
 from airflow import DAG
+from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, PipelineStates
 from airflow.providers.google.cloud.operators.datafusion import (
     CloudDataFusionCreateInstanceOperator,
     CloudDataFusionCreatePipelineOperator,
@@ -194,7 +195,9 @@ def test_execute(self, mock_hook):
 class TestCloudDataFusionStartPipelineOperator:
     @mock.patch(HOOK_STR)
     def test_execute(self, mock_hook):
+        PIPELINE_ID = "test_pipeline_id"
         mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL}
+        mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
 
         op = CloudDataFusionStartPipelineOperator(
             task_id="test_task",
@@ -219,6 +222,15 @@ def test_execute(self, mock_hook):
             runtime_args=RUNTIME_ARGS,
         )
 
+        mock_hook.return_value.wait_for_pipeline_state.assert_called_once_with(
+            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
+            pipeline_id=PIPELINE_ID,
+            pipeline_name=PIPELINE_NAME,
+            namespace=NAMESPACE,
+            instance_url=INSTANCE_URL,
+            timeout=300,
+        )
+
 
 class TestCloudDataFusionStopPipelineOperator:
     @mock.patch(HOOK_STR)
