Skip to content

Commit

Permalink
Allow for uploading metadata with GCS Hook Upload (#22058)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricker committed Mar 7, 2022
1 parent c1faaf3 commit a11d831
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def upload(
chunk_size: Optional[int] = None,
timeout: Optional[int] = DEFAULT_TIMEOUT,
num_max_attempts: int = 1,
metadata: Optional[dict] = None,
) -> None:
"""
Uploads a local file or file data as string or bytes to Google Cloud Storage.
Expand All @@ -461,6 +462,7 @@ def upload(
:param chunk_size: Blob chunk size.
:param timeout: Request timeout in seconds.
:param num_max_attempts: Number of attempts to try to upload the file.
:param metadata: The metadata to be uploaded with the file.
"""

def _call_with_retry(f: Callable[[], None]) -> None:
Expand Down Expand Up @@ -493,6 +495,10 @@ def _call_with_retry(f: Callable[[], None]) -> None:
client = self.get_conn()
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

if metadata:
blob.metadata = metadata

if filename and data:
raise ValueError(
"'filename' and 'data' parameter provided. Please "
Expand Down
10 changes: 8 additions & 2 deletions tests/providers/google/cloud/hooks/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,15 +789,21 @@ def tearDown(self):
def test_upload_file(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
metadata = {'key1': 'val1', 'key2': 'key2'}

upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_filename
bucket_mock = mock_service.return_value.bucket
blob_object = bucket_mock.return_value.blob

self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name)
upload_method = blob_object.return_value.upload_from_filename

self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name, metadata=metadata)

upload_method.assert_called_once_with(
filename=self.testfile.name, content_type='application/octet-stream', timeout=60
)

self.assertEqual(metadata, blob_object.return_value.metadata)

@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_upload_file_gzip(self, mock_service):
test_bucket = 'test_bucket'
Expand Down

0 comments on commit a11d831

Please sign in to comment.