Skip to content

Commit

Permalink
openlineage, gcs: add openlineage methods for GcsToGcsOperator (#31350)
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski committed Jul 27, 2023
1 parent 5aa62de commit b733667
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 13 deletions.
29 changes: 29 additions & 0 deletions airflow/providers/google/cloud/transfers/gcs_to_gcs.py
Expand Up @@ -233,6 +233,8 @@ def __init__(
self.source_object_required = source_object_required
self.exact_match = exact_match
self.match_glob = match_glob
self.resolved_source_objects: set[str] = set()
self.resolved_target_objects: set[str] = set()

def execute(self, context: Context):

Expand Down Expand Up @@ -540,7 +542,34 @@ def _copy_single_object(self, hook, source_object, destination_object):
destination_object,
)

self.resolved_source_objects.add(source_object)
if not destination_object:
self.resolved_target_objects.add(source_object)
else:
self.resolved_target_objects.add(destination_object)

hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object)

if self.move_object:
hook.delete(self.source_bucket, source_object)

def get_openlineage_events_on_complete(self, task_instance):
    """
    Build OpenLineage input/output datasets from objects resolved during ``execute``.

    Implemented as the on-complete hook because ``execute`` normalizes
    ``source_object``/``source_objects`` and the destination object as it runs,
    recording each resolved name; reading those recorded sets here avoids
    re-doing that normalization.
    """
    # Imported lazily so the operator works without the OpenLineage provider installed.
    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    source_namespace = f"gs://{self.source_bucket}"
    target_namespace = f"gs://{self.destination_bucket}"

    # Sorted for deterministic event ordering.
    input_datasets = [
        Dataset(namespace=source_namespace, name=obj)
        for obj in sorted(self.resolved_source_objects)
    ]
    output_datasets = [
        Dataset(namespace=target_namespace, name=obj)
        for obj in sorted(self.resolved_target_objects)
    ]
    return OperatorLineage(inputs=input_datasets, outputs=output_datasets)
18 changes: 10 additions & 8 deletions airflow/providers/openlineage/extractors/base.py
Expand Up @@ -83,6 +83,7 @@ def get_operator_classnames(cls) -> list[str]:
return []

def extract(self) -> OperatorLineage | None:
# OpenLineage methods are optional - if there's no method, return None
try:
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
except AttributeError:
Expand All @@ -100,19 +101,20 @@ def extract_on_complete(self, task_instance) -> OperatorLineage | None:

def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
    """
    Invoke an operator-provided OpenLineage method and defensively re-wrap its result.

    Returns ``None`` (after logging) if the method raises, so a broken provider
    implementation never fails the task itself.
    """
    try:
        facets: OperatorLineage = get_facets_method(*args)
    except ImportError:
        self.log.exception(
            "OpenLineage provider method failed to import OpenLineage integration. "
            "This should not happen."
        )
        return None
    except Exception:
        self.log.exception("OpenLineage provider method failed to extract data from provider. ")
        return None
    # "rewrite" OperatorLineage to safeguard against different version of the same class
    # that was existing in openlineage-airflow package outside of Airflow repo
    return OperatorLineage(
        inputs=facets.inputs,
        outputs=facets.outputs,
        run_facets=facets.run_facets,
        job_facets=facets.job_facets,
    )
12 changes: 7 additions & 5 deletions dev/breeze/tests/test_selective_checks.py
Expand Up @@ -539,7 +539,7 @@ def test_expected_output_full_tests_needed(
{
"affected-providers-list-as-string": "amazon apache.beam apache.cassandra cncf.kubernetes "
"common.sql facebook google hashicorp microsoft.azure microsoft.mssql "
"mysql oracle postgres presto salesforce sftp ssh trino",
"mysql openlineage oracle postgres presto salesforce sftp ssh trino",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"needs-helm-tests": "false",
Expand All @@ -564,8 +564,8 @@ def test_expected_output_full_tests_needed(
{
"affected-providers-list-as-string": "amazon apache.beam apache.cassandra "
"cncf.kubernetes common.sql facebook google "
"hashicorp microsoft.azure microsoft.mssql mysql oracle postgres presto "
"salesforce sftp ssh trino",
"hashicorp microsoft.azure microsoft.mssql mysql openlineage oracle postgres "
"presto salesforce sftp ssh trino",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"image-build": "true",
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_expected_output_pull_request_v2_3(
"affected-providers-list-as-string": "amazon apache.beam apache.cassandra "
"cncf.kubernetes common.sql "
"facebook google hashicorp microsoft.azure microsoft.mssql mysql "
"oracle postgres presto salesforce sftp ssh trino",
"openlineage oracle postgres presto salesforce sftp ssh trino",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"image-build": "true",
Expand All @@ -685,6 +685,7 @@ def test_expected_output_pull_request_v2_3(
"--package-filter apache-airflow-providers-microsoft-azure "
"--package-filter apache-airflow-providers-microsoft-mssql "
"--package-filter apache-airflow-providers-mysql "
"--package-filter apache-airflow-providers-openlineage "
"--package-filter apache-airflow-providers-oracle "
"--package-filter apache-airflow-providers-postgres "
"--package-filter apache-airflow-providers-presto "
Expand All @@ -697,7 +698,7 @@ def test_expected_output_pull_request_v2_3(
"skip-provider-tests": "false",
"parallel-test-types-list-as-string": "Providers[amazon] Always CLI "
"Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.sql,facebook,"
"hashicorp,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,"
"hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto,"
"salesforce,sftp,ssh,trino] Providers[google]",
},
id="CLI tests and Google-related provider tests should run if cli/chart files changed",
Expand Down Expand Up @@ -965,6 +966,7 @@ def test_upgrade_to_newer_dependencies(files: tuple[str, ...], expected_outputs:
"--package-filter apache-airflow-providers-microsoft-azure "
"--package-filter apache-airflow-providers-microsoft-mssql "
"--package-filter apache-airflow-providers-mysql "
"--package-filter apache-airflow-providers-openlineage "
"--package-filter apache-airflow-providers-oracle "
"--package-filter apache-airflow-providers-postgres "
"--package-filter apache-airflow-providers-presto "
Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Expand Up @@ -467,6 +467,7 @@
"microsoft.azure",
"microsoft.mssql",
"mysql",
"openlineage",
"oracle",
"postgres",
"presto",
Expand Down
73 changes: 73 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
Expand Up @@ -21,6 +21,7 @@
from unittest import mock

import pytest
from openlineage.client.run import Dataset

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator
Expand Down Expand Up @@ -827,3 +828,75 @@ def test_copy_files_into_a_folder(
for src, dst in zip(expected_source_objects, expected_destination_objects)
]
mock_hook.return_value.rewrite.assert_has_calls(mock_calls)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_simple_reports_openlineage(self, mock_hook):
    """Copying a single object yields exactly one input and one output dataset."""
    op = GCSToGCSOperator(
        task_id=TASK_ID,
        source_bucket=TEST_BUCKET,
        source_object=SOURCE_OBJECTS_SINGLE_FILE[0],
        destination_bucket=DESTINATION_BUCKET,
    )

    op.execute(None)

    lineage = op.get_openlineage_events_on_complete(None)
    # With no destination_object the source name is reused on the target side.
    expected_input = Dataset(namespace=f"gs://{TEST_BUCKET}", name=SOURCE_OBJECTS_SINGLE_FILE[0])
    expected_output = Dataset(
        namespace=f"gs://{DESTINATION_BUCKET}", name=SOURCE_OBJECTS_SINGLE_FILE[0]
    )
    assert lineage.inputs == [expected_input]
    assert lineage.outputs == [expected_output]

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_multiple_reports_openlineage(self, mock_hook):
    """Copying several objects into one destination yields N inputs and a single output."""
    op = GCSToGCSOperator(
        task_id=TASK_ID,
        source_bucket=TEST_BUCKET,
        source_objects=SOURCE_OBJECTS_LIST,
        destination_bucket=DESTINATION_BUCKET,
        destination_object=DESTINATION_OBJECT,
    )

    op.execute(None)

    lineage = op.get_openlineage_events_on_complete(None)
    expected_inputs = [
        Dataset(namespace=f"gs://{TEST_BUCKET}", name=source_name)
        for source_name in SOURCE_OBJECTS_LIST[:3]
    ]
    assert len(lineage.inputs) == 3
    assert lineage.inputs == expected_inputs
    # All copies share the single configured destination object.
    assert lineage.outputs == [
        Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name=DESTINATION_OBJECT)
    ]

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_wildcard_reports_openlineage(self, mock_hook):
    """A wildcard source expands via hook.list(); each match maps to its own output."""
    listed_objects = [
        "test_object1.txt",
        "test_object2.txt",
    ]
    mock_hook.return_value.list.return_value = listed_objects

    op = GCSToGCSOperator(
        task_id=TASK_ID,
        source_bucket=TEST_BUCKET,
        source_object=SOURCE_OBJECT_WILDCARD_SUFFIX,
        destination_bucket=DESTINATION_BUCKET,
        destination_object=DESTINATION_OBJECT,
    )

    op.execute(None)

    lineage = op.get_openlineage_events_on_complete(None)
    assert len(lineage.inputs) == 2
    assert len(lineage.outputs) == 2
    assert lineage.inputs == [
        Dataset(namespace=f"gs://{TEST_BUCKET}", name=name) for name in listed_objects
    ]
    # Wildcard prefix is swapped for the destination_object prefix per match.
    assert lineage.outputs == [
        Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name="foo/bar/1.txt"),
        Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name="foo/bar/2.txt"),
    ]

0 comments on commit b733667

Please sign in to comment.