Skip to content

Commit

Permalink
Cloud Storage assets & StorageLink update (#23865)
Browse files Browse the repository at this point in the history
Co-authored-by: Wojciech Januszek <[email protected]>
  • Loading branch information
wojsamjan and Wojciech Januszek committed Jun 6, 2022
1 parent 048b617 commit 80c1ce7
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 9 deletions.
Expand Up @@ -711,7 +711,7 @@ def execute(self, context: "Context"):

DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_EXPORT_LINK)
uri = self._get_uri_from_destination(MetadataExport.to_dict(metadata_export)["destination_gcs_uri"])
StorageLink.persist(context=context, task_instance=self, uri=uri)
StorageLink.persist(context=context, task_instance=self, uri=uri, project_id=self.project_id)
return MetadataExport.to_dict(metadata_export)

def _get_uri_from_destination(self, destination_uri: str):
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/operators/datastore.py
Expand Up @@ -140,6 +140,7 @@ def execute(self, context: 'Context') -> dict:
context=context,
task_instance=self,
uri=f"{self.bucket}/{result['response']['outputUrl'].split('/')[3]}",
project_id=self.project_id or ds_hook.project_id,
)
return result

Expand Down
57 changes: 57 additions & 0 deletions airflow/providers/google/cloud/operators/gcs.py
Expand Up @@ -35,6 +35,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.common.links.storage import FileDetailsLink, StorageLink
from airflow.utils import timezone


Expand Down Expand Up @@ -107,6 +108,7 @@ class GCSCreateBucketOperator(BaseOperator):
'impersonation_chain',
)
ui_color = '#f0eee4'
operator_extra_links = (StorageLink(),)

def __init__(
self,
Expand Down Expand Up @@ -139,6 +141,12 @@ def execute(self, context: "Context") -> None:
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
StorageLink.persist(
context=context,
task_instance=self,
uri=self.bucket_name,
project_id=self.project_id or hook.project_id,
)
try:
hook.create_bucket(
bucket_name=self.bucket_name,
Expand Down Expand Up @@ -200,6 +208,8 @@ class GCSListObjectsOperator(BaseOperator):

ui_color = '#f0eee4'

operator_extra_links = (StorageLink(),)

def __init__(
self,
*,
Expand Down Expand Up @@ -234,6 +244,13 @@ def execute(self, context: "Context") -> list:
self.prefix,
)

StorageLink.persist(
context=context,
task_instance=self,
uri=self.bucket,
project_id=hook.project_id,
)

return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)


Expand Down Expand Up @@ -346,6 +363,7 @@ class GCSBucketCreateAclEntryOperator(BaseOperator):
'impersonation_chain',
)
# [END gcs_bucket_create_acl_template_fields]
operator_extra_links = (StorageLink(),)

def __init__(
self,
Expand All @@ -371,6 +389,12 @@ def execute(self, context: "Context") -> None:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
StorageLink.persist(
context=context,
task_instance=self,
uri=self.bucket,
project_id=hook.project_id,
)
hook.insert_bucket_acl(
bucket_name=self.bucket, entity=self.entity, role=self.role, user_project=self.user_project
)
Expand Down Expand Up @@ -418,6 +442,7 @@ class GCSObjectCreateAclEntryOperator(BaseOperator):
'impersonation_chain',
)
# [END gcs_object_create_acl_template_fields]
operator_extra_links = (FileDetailsLink(),)

def __init__(
self,
Expand Down Expand Up @@ -447,6 +472,12 @@ def execute(self, context: "Context") -> None:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
FileDetailsLink.persist(
context=context,
task_instance=self,
uri=f"{self.bucket}/{self.object_name}",
project_id=hook.project_id,
)
hook.insert_object_acl(
bucket_name=self.bucket,
object_name=self.object_name,
Expand Down Expand Up @@ -498,6 +529,7 @@ class GCSFileTransformOperator(BaseOperator):
'transform_script',
'impersonation_chain',
)
operator_extra_links = (FileDetailsLink(),)

def __init__(
self,
Expand Down Expand Up @@ -549,6 +581,12 @@ def execute(self, context: "Context") -> None:
self.log.info("Transformation succeeded. Output temporarily located at %s", destination_file.name)

self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object)
FileDetailsLink.persist(
context=context,
task_instance=self,
uri=f"{self.destination_bucket}/{self.destination_object}",
project_id=hook.project_id,
)
hook.upload(
bucket_name=self.destination_bucket,
object_name=self.destination_object,
Expand Down Expand Up @@ -628,6 +666,7 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
'source_impersonation_chain',
'destination_impersonation_chain',
)
operator_extra_links = (StorageLink(),)

@staticmethod
def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]:
Expand Down Expand Up @@ -718,6 +757,12 @@ def execute(self, context: "Context") -> List[str]:
gcp_conn_id=self.destination_gcp_conn_id,
impersonation_chain=self.destination_impersonation_chain,
)
StorageLink.persist(
context=context,
task_instance=self,
uri=self.destination_bucket,
project_id=destination_hook.project_id,
)

# Fetch list of files.
blobs_to_transform = source_hook.list_by_timespan(
Expand Down Expand Up @@ -904,6 +949,7 @@ class GCSSynchronizeBucketsOperator(BaseOperator):
'delegate_to',
'impersonation_chain',
)
operator_extra_links = (StorageLink(),)

def __init__(
self,
Expand Down Expand Up @@ -938,6 +984,12 @@ def execute(self, context: "Context") -> None:
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
StorageLink.persist(
context=context,
task_instance=self,
uri=self._get_uri(self.destination_bucket, self.destination_object),
project_id=hook.project_id,
)
hook.sync(
source_bucket=self.source_bucket,
destination_bucket=self.destination_bucket,
Expand All @@ -947,3 +999,8 @@ def execute(self, context: "Context") -> None:
delete_extra_files=self.delete_extra_files,
allow_overwrite=self.allow_overwrite,
)

def _get_uri(self, gcs_bucket: str, gcs_object: Optional[str]) -> str:
if gcs_object and gcs_object[-1] == "/":
gcs_object = gcs_object[:-1]
return f"{gcs_bucket}/{gcs_object}" if gcs_object else gcs_bucket
4 changes: 2 additions & 2 deletions airflow/providers/google/common/links/storage.py
Expand Up @@ -36,11 +36,11 @@ class StorageLink(BaseGoogleLink):
format_str = GCS_STORAGE_LINK

@staticmethod
def persist(context: "Context", task_instance, uri: str):
def persist(context: "Context", task_instance, uri: str, project_id: Optional[str]):
task_instance.xcom_push(
context=context,
key=StorageLink.key,
value={"uri": uri, "project_id": task_instance.project_id},
value={"uri": uri, "project_id": project_id},
)


Expand Down
14 changes: 8 additions & 6 deletions tests/providers/google/cloud/operators/test_gcs.py
Expand Up @@ -57,7 +57,7 @@ def test_execute(self, mock_hook):
project_id=TEST_PROJECT,
)

operator.execute(None)
operator.execute(context=mock.MagicMock())
mock_hook.return_value.create_bucket.assert_called_once_with(
bucket_name=TEST_BUCKET,
storage_class="MULTI_REGIONAL",
Expand All @@ -78,7 +78,7 @@ def test_bucket_create_acl(self, mock_hook):
user_project="test-user-project",
task_id="id",
)
operator.execute(None)
operator.execute(context=mock.MagicMock())
mock_hook.return_value.insert_bucket_acl.assert_called_once_with(
bucket_name="test-bucket",
entity="test-entity",
Expand All @@ -97,7 +97,7 @@ def test_object_create_acl(self, mock_hook):
user_project="test-user-project",
task_id="id",
)
operator.execute(None)
operator.execute(context=mock.MagicMock())
mock_hook.return_value.insert_object_acl.assert_called_once_with(
bucket_name="test-bucket",
object_name="test-object",
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_execute(self, mock_hook):
task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
)

files = operator.execute(None)
files = operator.execute(context=mock.MagicMock())
mock_hook.return_value.list.assert_called_once_with(
bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
)
Expand Down Expand Up @@ -197,7 +197,7 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile):
destination_bucket=destination_bucket,
transform_script=transform_script,
)
op.execute(None)
op.execute(context=mock.MagicMock())

mock_hook.return_value.download.assert_called_once_with(
bucket_name=source_bucket, object_name=source_object, filename=source
Expand Down Expand Up @@ -273,9 +273,11 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir):
timespan_end = timespan_start + timedelta(hours=1)
mock_dag = mock.Mock()
mock_dag.following_schedule = lambda x: x + timedelta(hours=1)
mock_ti = mock.Mock()
context = dict(
execution_date=timespan_start,
dag=mock_dag,
ti=mock_ti,
)

mock_tempdir.return_value.__enter__.side_effect = [source, destination]
Expand Down Expand Up @@ -397,7 +399,7 @@ def test_execute(self, mock_hook):
delegate_to="DELEGATE_TO",
impersonation_chain=IMPERSONATION_CHAIN,
)
task.execute({})
task.execute(context=mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id='GCP_CONN_ID',
delegate_to='DELEGATE_TO',
Expand Down

0 comments on commit 80c1ce7

Please sign in to comment.