Adding Support for Google Cloud's Data Pipelines Run Operator (#32846)
---------

Co-authored-by: shaniyaclement <[email protected]>
Co-authored-by: Brenda Pham <[email protected]>
Co-authored-by: Shaniya Clement <[email protected]>
4 people committed Aug 21, 2023
1 parent 46fa5a2 commit c8de9a5
Showing 6 changed files with 234 additions and 1 deletion.
31 changes: 31 additions & 0 deletions airflow/providers/google/cloud/hooks/datapipeline.py
@@ -85,6 +85,37 @@ def create_data_pipeline(
        response = request.execute(num_retries=self.num_retries)
        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def run_data_pipeline(
        self,
        data_pipeline_name: str,
        project_id: str,
        location: str = DEFAULT_DATAPIPELINE_LOCATION,
    ) -> dict:
        """
        Run a Data Pipelines instance using the Data Pipelines API.

        :param data_pipeline_name: The display name of the pipeline. For example, in
            ``projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID`` it is the ``PIPELINE_ID``.
        :param project_id: The ID of the GCP project that owns the job.
        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
        :return: The created Job in JSON representation.
        """
        parent = self.build_parent_name(project_id, location)
        service = self.get_conn()
        request = (
            service.projects()
            .locations()
            .pipelines()
            .run(
                name=f"{parent}/pipelines/{data_pipeline_name}",
                body={},
            )
        )
        response = request.execute(num_retries=self.num_retries)
        return response

    @staticmethod
    def build_parent_name(project_id: str, location: str):
        return f"projects/{project_id}/locations/{location}"
52 changes: 52 additions & 0 deletions airflow/providers/google/cloud/operators/datapipeline.py
@@ -100,3 +100,55 @@ def execute(self, context: Context):
            raise AirflowException(self.data_pipeline.get("error").get("message"))

        return self.data_pipeline


class RunDataPipelineOperator(GoogleCloudBaseOperator):
    """
    Run a Data Pipelines instance using the Data Pipelines API.

    :param data_pipeline_name: The display name of the pipeline. For example, in
        ``projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID`` it is the ``PIPELINE_ID``.
    :param project_id: The ID of the GCP project that owns the job.
    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
    :param gcp_conn_id: The connection ID to connect to the Google Cloud Platform.

    Returns the created Job in JSON representation.
    """

    def __init__(
        self,
        data_pipeline_name: str,
        project_id: str | None = None,
        location: str = DEFAULT_DATAPIPELINE_LOCATION,
        gcp_conn_id: str = "google_cloud_default",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.data_pipeline_name = data_pipeline_name
        self.project_id = project_id
        self.location = location
        self.gcp_conn_id = gcp_conn_id

    def execute(self, context: Context):
        self.data_pipeline_hook = DataPipelineHook(gcp_conn_id=self.gcp_conn_id)

        if self.data_pipeline_name is None:
            raise AirflowException("Data Pipeline name not given; cannot run unspecified pipeline.")
        if self.project_id is None:
            raise AirflowException("Data Pipeline Project ID not given; cannot run pipeline.")
        if self.location is None:
            raise AirflowException("Data Pipeline location not given; cannot run pipeline.")

        self.response = self.data_pipeline_hook.run_data_pipeline(
            data_pipeline_name=self.data_pipeline_name,
            project_id=self.project_id,
            location=self.location,
        )

        if self.response:
            if "error" in self.response:
                raise AirflowException(self.response.get("error").get("message"))

        return self.response
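
For orientation, a minimal sketch of wiring this operator into a DAG (the DAG ID and resource names are placeholders; the system test near the end of this commit shows the committed usage):

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.datapipeline import RunDataPipelineOperator

    with models.DAG(
        "example_run_data_pipeline",  # hypothetical DAG id
        start_date=datetime(2023, 1, 1),
        schedule="@once",
        catchup=False,
    ) as dag:
        run_data_pipeline = RunDataPipelineOperator(
            task_id="run_data_pipeline",
            data_pipeline_name="my-pipeline",
            project_id="my-gcp-project",
            # location and gcp_conn_id fall back to their defaults
        )
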
@@ -55,6 +55,35 @@ Here is an example of how you can create a Data Pipelines instance by running the
    :start-after: [START howto_operator_create_data_pipeline]
    :end-before: [END howto_operator_create_data_pipeline]

Running a Data Pipeline
^^^^^^^^^^^^^^^^^^^^^^^

To run a Data Pipelines instance, use :class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator`.
The operator accesses Google Cloud's Data Pipelines API and calls upon the
`run method <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/run>`__
to run the given pipeline.

:class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator` accepts four parameters:

- ``data_pipeline_name``: the name of the Data Pipelines instance
- ``project_id``: the ID of the GCP project that owns the job
- ``location``: the location of the Data Pipelines instance
- ``gcp_conn_id``: the connection ID to connect to the Google Cloud Platform

Only the Data Pipeline name and Project ID are required; the Location and GCP Connection ID have default values.
The Project ID and Location are used to build the parent name, which identifies where the given Data Pipeline is located (see the sketch below).
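
As a sketch of what the operator assembles internally (mirroring the hook's ``build_parent_name`` helper shown earlier in this commit; the concrete values are placeholders):

.. code-block:: python

    parent = f"projects/{project_id}/locations/{location}"
    pipeline_name = f"{parent}/pipelines/{data_pipeline_name}"
    # e.g. "projects/my-gcp-project/locations/us-central1/pipelines/my-pipeline"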

You can run a Data Pipelines instance by passing the above parameters to the RunDataPipelineOperator:

.. exampleinclude:: /../../tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_run_data_pipeline]
    :end-before: [END howto_operator_run_data_pipeline]

Once called, the RunDataPipelineOperator will return the Google Cloud `Dataflow Job <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/Job>`__
created by running the given pipeline.
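
Because this return value comes from ``execute``, Airflow pushes it to XCom by default, so a downstream task can read the Job. A rough sketch (the task IDs are placeholders; per the run API response exercised in the tests, the Job sits under the ``job`` key):

.. code-block:: python

    from airflow.decorators import task


    @task
    def report_job_id(ti=None):
        response = ti.xcom_pull(task_ids="run_data_pipeline")
        return response["job"]["id"]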

For further information regarding the API usage, see
`Data Pipelines API REST Resource <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines#Pipeline>`__
in the Google Cloud documentation.
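
For a sense of the raw call that the hook wraps, here is a rough ``googleapiclient`` sketch. The discovery service name and version (``"datapipelines"``, ``"v1"``) and the use of Application Default Credentials are assumptions for illustration, not details taken from this commit:

.. code-block:: python

    from googleapiclient.discovery import build

    # Build a client for the Data Pipelines API and invoke projects.locations.pipelines.run
    # with an empty request body; the resource names below are placeholders.
    service = build("datapipelines", "v1", cache_discovery=False)
    name = "projects/my-gcp-project/locations/us-central1/pipelines/my-pipeline"
    response = service.projects().locations().pipelines().run(name=name, body={}).execute()
    print(response.get("job", {}).get("id"))
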
23 changes: 23 additions & 0 deletions tests/providers/google/cloud/hooks/test_datapipeline.py
@@ -108,3 +108,26 @@ def test_create_data_pipeline(self, mock_connection):
            body=TEST_BODY,
        )
        assert result == {"name": TEST_PARENT}

    @mock.patch("airflow.providers.google.cloud.hooks.datapipeline.DataPipelineHook.get_conn")
    def test_run_data_pipeline(self, mock_connection):
        """
        Test that run_data_pipeline is called with correct parameters and
        calls the Google Data Pipelines API.
        """
        mock_request = (
            mock_connection.return_value.projects.return_value.locations.return_value.pipelines.return_value.run
        )
        mock_request.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}

        result = self.datapipeline_hook.run_data_pipeline(
            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
            project_id=TEST_PROJECTID,
            location=TEST_LOCATION,
        )

        mock_request.assert_called_once_with(
            name=TEST_NAME,
            body={},
        )
        assert result == {"job": {"id": TEST_JOB_ID}}
90 changes: 90 additions & 0 deletions tests/providers/google/cloud/operators/test_datapipeline.py
Expand Up @@ -24,6 +24,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.datapipeline import (
CreateDataPipelineOperator,
RunDataPipelineOperator,
)

TASK_ID = "test-datapipeline-operators"
@@ -136,3 +137,92 @@ def test_response_invalid(self):
        }
        with pytest.raises(AirflowException):
            CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())


class TestRunDataPipelineOperator:
    @pytest.fixture
    def run_operator(self):
        """
        Create a RunDataPipelineOperator instance with test data.
        """
        return RunDataPipelineOperator(
            task_id=TASK_ID,
            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
            project_id=TEST_PROJECTID,
            location=TEST_LOCATION,
            gcp_conn_id=TEST_GCP_CONN_ID,
        )

    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
    def test_execute(self, data_pipeline_hook_mock, run_operator):
        """
        Test Run Operator execute with correct parameters.
        """
        run_operator.execute(mock.MagicMock())
        data_pipeline_hook_mock.assert_called_once_with(
            gcp_conn_id=TEST_GCP_CONN_ID,
        )

        data_pipeline_hook_mock.return_value.run_data_pipeline.assert_called_once_with(
            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
            project_id=TEST_PROJECTID,
            location=TEST_LOCATION,
        )

    def test_invalid_data_pipeline_name(self):
        """
        Test that AirflowException is raised if Run Operator is not given a data pipeline name.
        """
        init_kwargs = {
            "task_id": TASK_ID,
            "data_pipeline_name": None,
            "project_id": TEST_PROJECTID,
            "location": TEST_LOCATION,
            "gcp_conn_id": TEST_GCP_CONN_ID,
        }
        with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    def test_invalid_project_id(self):
        """
        Test that AirflowException is raised if Run Operator is not given a project ID.
        """
        init_kwargs = {
            "task_id": TASK_ID,
            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
            "project_id": None,
            "location": TEST_LOCATION,
            "gcp_conn_id": TEST_GCP_CONN_ID,
        }
        with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    def test_invalid_location(self):
        """
        Test that AirflowException is raised if Run Operator is not given a location.
        """
        init_kwargs = {
            "task_id": TASK_ID,
            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
            "project_id": TEST_PROJECTID,
            "location": None,
            "gcp_conn_id": TEST_GCP_CONN_ID,
        }
        with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
    def test_invalid_response(self, data_pipeline_hook_mock):
        """
        Test that AirflowException is raised if the Run Operator receives an error response.
        """
        init_kwargs = {
            "task_id": TASK_ID,
            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
            "project_id": TEST_PROJECTID,
            "location": TEST_LOCATION,
            "gcp_conn_id": TEST_GCP_CONN_ID,
        }
        # Have the mocked hook return an error payload so execute() raises.
        data_pipeline_hook_mock.return_value.run_data_pipeline.return_value = {
            "error": {"message": "example error"}
        }
        with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
@@ -28,6 +28,7 @@
from airflow import models
from airflow.providers.google.cloud.operators.datapipeline import (
    CreateDataPipelineOperator,
    RunDataPipelineOperator,
)
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
@@ -38,7 +39,7 @@
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
GCP_LOCATION = os.environ.get("location", "us-central1")

PIPELINE_NAME = "defualt-pipeline-name"
PIPELINE_NAME = os.environ.get("DATA_PIPELINE_NAME", "defualt-pipeline-name")
PIPELINE_TYPE = "PIPELINE_TYPE_BATCH"

BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
@@ -117,6 +118,13 @@
    # when "teardown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

    # [START howto_operator_run_data_pipeline]
    run_data_pipeline = RunDataPipelineOperator(
        task_id="run_data_pipeline",
        data_pipeline_name=PIPELINE_NAME,
        project_id=GCP_PROJECT_ID,
    )
    # [END howto_operator_run_data_pipeline]

from tests.system.utils import get_test_run # noqa: E402
