[AIRFLOW-5610] Add ability to specify multiple objects to copy in GCSToGCSOperator (#7728)
ephraimbuddy committed Mar 18, 2020
1 parent 49998ed commit 60fdbf6
Showing 2 changed files with 269 additions and 44 deletions.
140 changes: 111 additions & 29 deletions airflow/providers/google/cloud/operators/gcs_to_gcs.py
@@ -47,6 +47,9 @@ class GCSToGCSOperator(BaseOperator):
end of the object name. Appending a wildcard to the bucket name is
unsupported.
:type source_object: str
:param source_objects: A list of source names of the objects to copy in the Google Cloud
    Storage bucket. (templated)
:type source_objects: List[str]
:param destination_bucket: The destination Google Cloud Storage bucket
where the object should be. If the destination_bucket is None, it defaults
to source_bucket. (templated)
@@ -61,11 +64,16 @@ class GCSToGCSOperator(BaseOperator):
file ``foo/baz`` will be copied to ``blah/baz``; to retain the prefix write
the destination_object as e.g. ``blah/foo``, in which case the copied file
will be named ``blah/foo/baz``.
    The same rule applies to each of the paths listed in ``source_objects``.
:type destination_object: str
:param move_object: When move object is True, the object is moved instead
of copied to the new location. This is the equivalent of a mv command
as opposed to a cp command.
:type move_object: bool
:param delimiter: This is used to restrict the result to only the 'files' in a given 'folder'.
    If source_objects = ['foo/bah/'] and delimiter = '.avro', then only the 'files' in the
    folder 'foo/bah/' whose names end with '.avro' will be copied to the destination object
    (see the ``copy_files`` example below).
:type delimiter: str
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform.
:type gcp_conn_id: str
:param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud
@@ -89,7 +97,7 @@ class GCSToGCSOperator(BaseOperator):
copy_single_file = GCSToGCSOperator(
task_id='copy_single_file',
source_bucket='data',
source_object='sales/sales-2017/january.avro',
source_objects=['sales/sales-2017/january.avro'],
destination_bucket='data_backup',
destination_object='copied_sales/2017/january-backup.avro',
gcp_conn_id=google_cloud_conn_id
@@ -99,6 +107,18 @@ class GCSToGCSOperator(BaseOperator):
folder (i.e. with names starting with that prefix) in ``data`` bucket to the
``copied_sales/2017`` folder in the ``data_backup`` bucket. ::
copy_files = GCSToGCSOperator(
task_id='copy_files',
source_bucket='data',
source_objects=['sales/sales-2017'],
destination_bucket='data_backup',
destination_object='copied_sales/2017/',
delimiter='.avro',
gcp_conn_id=google_cloud_conn_id
)
Or ::
copy_files = GCSToGCSOperator(
task_id='copy_files',
source_bucket='data',
@@ -122,17 +142,33 @@ class GCSToGCSOperator(BaseOperator):
gcp_conn_id=google_cloud_conn_id
)
The following Operator would move all the Avro files from ``sales/sales-2019``
and ``sales/sales-2020`` folders in ``data`` bucket to the same folders in the
``data_backup`` bucket, deleting the original files in the process. ::
move_files = GCSToGCSOperator(
task_id='move_files',
source_bucket='data',
source_objects=['sales/sales-2019/*.avro', 'sales/sales-2020'],
destination_bucket='data_backup',
delimiter='.avro',
move_object=True,
gcp_conn_id=google_cloud_conn_id
)
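The following Operator would copy every object in the ``data`` bucket to the
``data_backup`` bucket; a single empty string in ``source_objects`` matches
all objects in the bucket (the task and bucket names here are placeholders,
like those in the examples above). ::
copy_everything = GCSToGCSOperator(
    task_id='copy_everything',
    source_bucket='data',
    source_objects=[''],
    destination_bucket='data_backup',
    gcp_conn_id=google_cloud_conn_id
)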
"""
template_fields = ('source_bucket', 'source_object', 'destination_bucket',
'destination_object',)
template_fields = ('source_bucket', 'source_object', 'source_objects', 'destination_bucket',
'destination_object', 'delimiter')
ui_color = '#f0eee4'

@apply_defaults
def __init__(self,
def __init__(self, # pylint: disable=too-many-arguments
source_bucket,
source_object,
source_object=None,
source_objects=None,
destination_bucket=None,
destination_object=None,
delimiter=None,
move_object=False,
gcp_conn_id='google_cloud_default',
google_cloud_storage_conn_id=None,
@@ -141,7 +177,6 @@ def __init__(self,
*args,
**kwargs):
super().__init__(*args, **kwargs)

if google_cloud_storage_conn_id:
warnings.warn(
"The google_cloud_storage_conn_id parameter has been deprecated. You should pass "
@@ -150,8 +185,10 @@ def __init__(self,

self.source_bucket = source_bucket
self.source_object = source_object
self.source_objects = source_objects
self.destination_bucket = destination_bucket
self.destination_object = destination_object
self.delimiter = delimiter
self.move_object = move_object
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
@@ -163,36 +200,81 @@ def execute(self, context):
google_cloud_storage_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to
)
if self.source_objects and self.source_object:
error_msg = "You can either set source_object parameter or source_objects " \
"parameter but not both. Found source_object={} and" \
" source_objects={}".format(self.source_object, self.source_objects)
raise AirflowException(error_msg)

if not self.source_object and not self.source_objects:
error_msg = "You must set source_object parameter or source_objects parameter. None set"
raise AirflowException(error_msg)

if self.source_objects and not all(isinstance(item, str) for item in self.source_objects):
raise AirflowException('At least one of the items in `source_objects` is not a string')

# If source_object is set, fold it into source_objects as a one-element list
if self.source_object:
self.source_objects = [self.source_object]

if self.destination_bucket is None:
self.log.warning(
'destination_bucket is None. Defaulting it to source_bucket (%s)',
self.source_bucket)
self.destination_bucket = self.source_bucket

if WILDCARD in self.source_object:
total_wildcards = self.source_object.count(WILDCARD)
if total_wildcards > 1:
error_msg = "Only one wildcard '*' is allowed in source_object parameter. " \
"Found {} in {}.".format(total_wildcards, self.source_object)

raise AirflowException(error_msg)

prefix, delimiter = self.source_object.split(WILDCARD, 1)
objects = hook.list(self.source_bucket, prefix=prefix, delimiter=delimiter)

for source_object in objects:
if self.destination_object is None:
destination_object = source_object
else:
destination_object = source_object.replace(prefix,
self.destination_object, 1)

self._copy_single_object(hook=hook, source_object=source_object,
destination_object=destination_object)
else:
self._copy_single_object(hook=hook, source_object=self.source_object,
destination_object=self.destination_object)
# An empty source_objects list means all files should be copied
if len(self.source_objects) == 0:
self.source_objects = ['']
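# With prefix '' hook.list matches every object in the bucket
# (subject to the delimiter, if one was set)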
# Raise an exception if the empty string '' appears more than once in source_objects; this avoids copying the same objects twice
if self.source_objects.count('') > 1:
raise AirflowException("You can't have two empty strings inside source_object")

# Iterate over the source_objects and do the copy
for prefix in self.source_objects:
# Check if prefix contains wildcard
if WILDCARD in prefix:
self._copy_source_with_wildcard(hook=hook, prefix=prefix)
# Now search with prefix using provided delimiter if any
else:
self._copy_source_without_wildcard(hook=hook, prefix=prefix)

def _copy_source_without_wildcard(self, hook, prefix):
objects = hook.list(self.source_bucket, prefix=prefix, delimiter=self.delimiter)

# If objects is empty and we have prefix, let's check if prefix is a blob
# and copy directly
if len(objects) == 0 and prefix:
if hook.exists(self.source_bucket, prefix):
self._copy_single_object(hook=hook, source_object=prefix,
destination_object=self.destination_object)
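# When objects is empty, the loop below is a no-op, so the direct
# copy above (if any) is the only transfer performed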
for source_obj in objects:
if self.destination_object is None:
destination_object = source_obj
else:
destination_object = self.destination_object
self._copy_single_object(hook=hook, source_object=source_obj,
destination_object=destination_object)

def _copy_source_with_wildcard(self, hook, prefix):
total_wildcards = prefix.count(WILDCARD)
if total_wildcards > 1:
error_msg = "Only one wildcard '*' is allowed in source_object parameter. " \
"Found {} in {}.".format(total_wildcards, prefix)

raise AirflowException(error_msg)
self.log.info('Delimiter ignored because wildcard is in prefix')
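# e.g. 'sales/sales-2019/*.avro' splits into prefix_='sales/sales-2019/'
# and delimiter='.avro'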
prefix_, delimiter = prefix.split(WILDCARD, 1)
objects = hook.list(self.source_bucket, prefix=prefix_, delimiter=delimiter)
for source_object in objects:
if self.destination_object is None:
destination_object = source_object
else:
destination_object = source_object.replace(prefix_,
self.destination_object, 1)

self._copy_single_object(hook=hook, source_object=source_object,
destination_object=destination_object)

def _copy_single_object(self, hook, source_object, destination_object):
if self.last_modified_time is not None:

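A minimal usage sketch of the new ``source_objects`` parameter in a DAG; the
DAG id, task id, schedule, and bucket names are illustrative, and the import
path follows the file changed in this commit. ::

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.gcs_to_gcs import GCSToGCSOperator

    dag = DAG(
        dag_id='gcs_backup_example',      # illustrative name
        start_date=datetime(2020, 3, 18),
        schedule_interval=None,
    )

    backup_sales = GCSToGCSOperator(
        task_id='backup_sales',
        source_bucket='data',             # placeholder bucket
        source_objects=['sales/sales-2019/', 'sales/sales-2020/'],
        destination_bucket='data_backup', # placeholder bucket
        delimiter='.avro',
        move_object=False,
        dag=dag,
    )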