Skip to content

Commit

Permalink
Fix GCSSynchronizeBucketsOperator timeout error (#37237)
Browse files Browse the repository at this point in the history
Update comment to be more clear
  • Loading branch information
kevgeo committed Feb 16, 2024
1 parent 5ad1e78 commit 123b656
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 25 deletions.
20 changes: 17 additions & 3 deletions airflow/providers/google/cloud/hooks/gcs.py
Expand Up @@ -1213,15 +1213,19 @@ def sync(
:return: none
"""
client = self.get_conn()

# Create bucket object
source_bucket_obj = client.bucket(source_bucket)
destination_bucket_obj = client.bucket(destination_bucket)

# Normalize parameters when they are passed
source_object = normalize_directory_path(source_object)
destination_object = normalize_directory_path(destination_object)

# Calculate the number of characters that remove from the name, because they contain information
# about the parent's path
source_object_prefix_len = len(source_object) if source_object else 0

# Prepare synchronization plan
to_copy_blobs, to_delete_blobs, to_rewrite_blobs = self._prepare_sync_plan(
source_bucket=source_bucket_obj,
Expand All @@ -1246,13 +1250,14 @@ def sync(
dst_object = self._calculate_sync_destination_path(
blob, destination_object, source_object_prefix_len
)
self.copy(
self.rewrite(
source_bucket=source_bucket_obj.name,
source_object=blob.name,
destination_bucket=destination_bucket_obj.name,
destination_object=dst_object,
)
self.log.info("Blobs copied.")

# Delete redundant files
if not to_delete_blobs:
self.log.info("Skipped blobs deleting.")
Expand Down Expand Up @@ -1297,37 +1302,46 @@ def _prepare_sync_plan(
destination_object: str | None,
recursive: bool,
) -> tuple[set[storage.Blob], set[storage.Blob], set[storage.Blob]]:
# Calculate the number of characters that remove from the name, because they contain information
# Calculate the number of characters that are removed from the name, because they contain information
# about the parent's path
source_object_prefix_len = len(source_object) if source_object else 0
destination_object_prefix_len = len(destination_object) if destination_object else 0
delimiter = "/" if not recursive else None

# Fetch blobs list
source_blobs = list(source_bucket.list_blobs(prefix=source_object, delimiter=delimiter))
destination_blobs = list(
destination_bucket.list_blobs(prefix=destination_object, delimiter=delimiter)
)

# Create indexes that allow you to identify blobs based on their name
source_names_index = {a.name[source_object_prefix_len:]: a for a in source_blobs}
destination_names_index = {a.name[destination_object_prefix_len:]: a for a in destination_blobs}

# Create sets with names without parent object name
source_names = set(source_names_index.keys())
# Discards empty string from source set that creates an empty subdirectory in
# destination bucket with source subdirectory name
source_names.discard("")
destination_names = set(destination_names_index.keys())

# Determine objects to copy and delete
to_copy = source_names - destination_names
to_delete = destination_names - source_names
to_copy_blobs: set[storage.Blob] = {source_names_index[a] for a in to_copy}
to_delete_blobs: set[storage.Blob] = {destination_names_index[a] for a in to_delete}

# Find names that are in both buckets
names_to_check = source_names.intersection(destination_names)
to_rewrite_blobs: set[storage.Blob] = set()
# Compare objects based on crc32
for current_name in names_to_check:
source_blob = source_names_index[current_name]
destination_blob = destination_names_index[current_name]
# if the objects are different, save it
# If the objects are different, save it
if source_blob.crc32c != destination_blob.crc32c:
to_rewrite_blobs.add(source_blob)

return to_copy_blobs, to_delete_blobs, to_rewrite_blobs


Expand Down
32 changes: 10 additions & 22 deletions tests/providers/google/cloud/hooks/test_gcs.py
Expand Up @@ -1081,7 +1081,6 @@ def setup_method(self):
def test_should_do_nothing_when_buckets_is_empty(
self, mock_get_conn, mock_delete, mock_rewrite, mock_copy
):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = []
destination_bucket = self._create_bucket(name="DEST_BUCKET")
Expand All @@ -1104,7 +1103,6 @@ def test_should_do_nothing_when_buckets_is_empty(
def test_should_append_slash_to_object_if_missing(
self, mock_get_conn, mock_delete, mock_rewrite, mock_copy
):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = []
destination_bucket = self._create_bucket(name="DEST_BUCKET")
Expand All @@ -1124,7 +1122,6 @@ def test_should_append_slash_to_object_if_missing(
@mock.patch(GCS_STRING.format("GCSHook.delete"))
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_should_copy_files(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("FILE_A", "C1"),
Expand All @@ -1135,31 +1132,30 @@ def test_should_copy_files(self, mock_get_conn, mock_delete, mock_rewrite, mock_
mock_get_conn.return_value.bucket.side_effect = [source_bucket, destination_bucket]
self.gcs_hook.sync(source_bucket="SOURCE_BUCKET", destination_bucket="DEST_BUCKET")
mock_delete.assert_not_called()
mock_rewrite.assert_not_called()
mock_copy.assert_has_calls(
mock_rewrite.assert_has_calls(
[
mock.call(
destination_bucket="DEST_BUCKET",
destination_object="FILE_A",
source_bucket="SOURCE_BUCKET",
source_object="FILE_A",
destination_bucket="DEST_BUCKET",
destination_object="FILE_A",
),
mock.call(
destination_bucket="DEST_BUCKET",
destination_object="FILE_B",
source_bucket="SOURCE_BUCKET",
source_object="FILE_B",
destination_bucket="DEST_BUCKET",
destination_object="FILE_B",
),
],
any_order=True,
)
mock_copy.assert_not_called()

@mock.patch(GCS_STRING.format("GCSHook.copy"))
@mock.patch(GCS_STRING.format("GCSHook.rewrite"))
@mock.patch(GCS_STRING.format("GCSHook.delete"))
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_should_copy_files_non_recursive(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("FILE_A", "C1"),
Expand All @@ -1177,7 +1173,6 @@ def test_should_copy_files_non_recursive(self, mock_get_conn, mock_delete, mock_
@mock.patch(GCS_STRING.format("GCSHook.delete"))
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_should_copy_files_to_subdirectory(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("FILE_A", "C1"),
Expand All @@ -1190,8 +1185,7 @@ def test_should_copy_files_to_subdirectory(self, mock_get_conn, mock_delete, moc
source_bucket="SOURCE_BUCKET", destination_bucket="DEST_BUCKET", destination_object="DEST_OBJ/"
)
mock_delete.assert_not_called()
mock_rewrite.assert_not_called()
mock_copy.assert_has_calls(
mock_rewrite.assert_has_calls(
[
mock.call(
source_bucket="SOURCE_BUCKET",
Expand All @@ -1208,13 +1202,13 @@ def test_should_copy_files_to_subdirectory(self, mock_get_conn, mock_delete, moc
],
any_order=True,
)
mock_copy.assert_not_called()

@mock.patch(GCS_STRING.format("GCSHook.copy"))
@mock.patch(GCS_STRING.format("GCSHook.rewrite"))
@mock.patch(GCS_STRING.format("GCSHook.delete"))
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_should_copy_files_from_subdirectory(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("SRC_OBJ/FILE_A", "C1"),
Expand All @@ -1227,8 +1221,7 @@ def test_should_copy_files_from_subdirectory(self, mock_get_conn, mock_delete, m
source_bucket="SOURCE_BUCKET", destination_bucket="DEST_BUCKET", source_object="SRC_OBJ/"
)
mock_delete.assert_not_called()
mock_rewrite.assert_not_called()
mock_copy.assert_has_calls(
mock_rewrite.assert_has_calls(
[
mock.call(
source_bucket="SOURCE_BUCKET",
Expand All @@ -1245,13 +1238,13 @@ def test_should_copy_files_from_subdirectory(self, mock_get_conn, mock_delete, m
],
any_order=True,
)
mock_copy.assert_not_called()

@mock.patch(GCS_STRING.format("GCSHook.copy"))
@mock.patch(GCS_STRING.format("GCSHook.rewrite"))
@mock.patch(GCS_STRING.format("GCSHook.delete"))
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_should_overwrite_files(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("FILE_A", "C1"),
Expand Down Expand Up @@ -1293,7 +1286,6 @@ def test_should_overwrite_files(self, mock_get_conn, mock_delete, mock_rewrite,
def test_should_overwrite_files_to_subdirectory(
self, mock_get_conn, mock_delete, mock_rewrite, mock_copy
):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("FILE_A", "C1"),
Expand Down Expand Up @@ -1338,7 +1330,6 @@ def test_should_overwrite_files_to_subdirectory(
def test_should_overwrite_files_from_subdirectory(
self, mock_get_conn, mock_delete, mock_rewrite, mock_copy
):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("SRC_OBJ/FILE_A", "C1"),
Expand Down Expand Up @@ -1381,7 +1372,6 @@ def test_should_overwrite_files_from_subdirectory(
@mock.patch(GCS_STRING.format("GCSHook.delete"))
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_should_delete_extra_files(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = []
destination_bucket = self._create_bucket(name="DEST_BUCKET")
Expand All @@ -1407,7 +1397,6 @@ def test_should_delete_extra_files(self, mock_get_conn, mock_delete, mock_rewrit
def test_should_not_delete_extra_files_when_delete_extra_files_is_disabled(
self, mock_get_conn, mock_delete, mock_rewrite, mock_copy
):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = []
destination_bucket = self._create_bucket(name="DEST_BUCKET")
Expand All @@ -1430,7 +1419,6 @@ def test_should_not_delete_extra_files_when_delete_extra_files_is_disabled(
def test_should_not_overwrite_when_overwrite_is_disabled(
self, mock_get_conn, mock_delete, mock_rewrite, mock_copy
):
# mock_get_conn.return_value =
source_bucket = self._create_bucket(name="SOURCE_BUCKET")
source_bucket.list_blobs.return_value = [
self._create_blob("SRC_OBJ/FILE_A", "C1", source_bucket),
Expand Down

0 comments on commit 123b656

Please sign in to comment.