Skip to content

Commit

Permalink
Refactor GoogleDriveToGCSOperator to use common methods (#14276)
Browse files Browse the repository at this point in the history
Refactor GoogleDriveToGCSOperator to use common methods implemented
in hooks used by this operator.
  • Loading branch information
Scuall1992 committed Feb 21, 2021
1 parent 82cb041 commit a7e4266
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 86 deletions.
60 changes: 27 additions & 33 deletions airflow/providers/google/cloud/transfers/gdrive_to_gcs.py
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from io import BytesIO
import warnings
from typing import Optional, Sequence, Union

from airflow.models import BaseOperator
Expand All @@ -32,11 +32,15 @@ class GoogleDriveToGCSOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:GoogleDriveToGCSOperator`
:param destination_bucket: The destination Google cloud storage bucket where the
:param bucket_name: The destination Google cloud storage bucket where the
file should be written to
:type destination_bucket: str
:param destination_object: The Google Cloud Storage object name for the object created by the operator.
:type bucket_name: str
:param object_name: The Google Cloud Storage object name for the object created by the operator.
For example: ``path/to/my/file/file.txt``.
:type object_name: str
:param destination_bucket: Same as bucket_name, but for backward compatibly
:type destination_bucket: str
:param destination_object: Same as object_name, but for backward compatibly
:type destination_object: str
:param folder_id: The folder id of the folder in which the Google Drive file resides
:type folder_id: str
Expand All @@ -62,6 +66,8 @@ class GoogleDriveToGCSOperator(BaseOperator):
"""

template_fields = [
"bucket_name",
"object_name",
"destination_bucket",
"destination_object",
"folder_id",
Expand All @@ -74,8 +80,10 @@ class GoogleDriveToGCSOperator(BaseOperator):
def __init__(
self,
*,
destination_bucket: str,
destination_object: str,
bucket_name: Optional[str] = None,
object_name: Optional[str] = None,
destination_bucket: Optional[str] = None, # deprecated
destination_object: Optional[str] = None, # deprecated
file_name: str,
folder_id: str,
drive_id: Optional[str] = None,
Expand All @@ -85,38 +93,18 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
self.destination_bucket = destination_bucket
self.destination_object = destination_object
self.bucket_name = destination_bucket or bucket_name
if destination_bucket:
warnings.warn("`destination_bucket` is deprecated please use `bucket_name`", DeprecationWarning)
self.object_name = destination_object or object_name
if destination_object:
warnings.warn("`destination_object` is deprecated please use `object_name`", DeprecationWarning)
self.folder_id = folder_id
self.drive_id = drive_id
self.file_name = file_name
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.file_metadata = None

def _set_file_metadata(self, gdrive_hook):
if not self.file_metadata:
self.file_metadata = gdrive_hook.get_file_id(
folder_id=self.folder_id, file_name=self.file_name, drive_id=self.drive_id
)
return self.file_metadata

def _upload_data(self, gcs_hook: GCSHook, gdrive_hook: GoogleDriveHook) -> str:
file_handle = BytesIO()
self._set_file_metadata(gdrive_hook=gdrive_hook)
file_id = self.file_metadata["id"]
mime_type = self.file_metadata["mime_type"]
request = gdrive_hook.get_media_request(file_id=file_id)
gdrive_hook.download_content_from_request(
file_handle=file_handle, request=request, chunk_size=104857600
)
gcs_hook.upload(
bucket_name=self.destination_bucket,
object_name=self.destination_object,
data=file_handle.getvalue(),
mime_type=mime_type,
)

def execute(self, context):
gdrive_hook = GoogleDriveHook(
Expand All @@ -129,4 +117,10 @@ def execute(self, context):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
self._upload_data(gdrive_hook=gdrive_hook, gcs_hook=gcs_hook)
file_metadata = gdrive_hook.get_file_id(
folder_id=self.folder_id, file_name=self.file_name, drive_id=self.drive_id
)
with gcs_hook.provide_file_and_upload(
bucket_name=self.bucket_name, object_name=self.object_name
) as file:
gdrive_hook.download_file(file_id=file_metadata["id"], file_handle=file)
60 changes: 12 additions & 48 deletions tests/providers/google/cloud/transfers/test_gdrive_to_gcs.py
Expand Up @@ -30,48 +30,9 @@


class TestGoogleDriveToGCSOperator:
@mock.patch("airflow.providers.google.cloud.transfers.gdrive_to_gcs.BytesIO")
@mock.patch("airflow.providers.google.cloud.transfers.gdrive_to_gcs.GoogleDriveHook")
def test_upload_data(self, mock_gdrive_hook, mock_file_handle):
mock_gdrive_hook.return_value.get_media_request.return_value = mock.MagicMock()

file_id = mock_gdrive_hook.get_file_id.return_value["id"]
mime_type = mock_gdrive_hook.get_file_id.return_value["mime_type"]

mock_gcs_hook = mock.Mock()

op = GoogleDriveToGCSOperator(
task_id="test_task",
folder_id=FOLDER_ID,
file_name=FILE_NAME,
drive_id=DRIVE_ID,
destination_bucket=BUCKET,
destination_object=OBJECT,
)

op._upload_data(
gcs_hook=mock_gcs_hook,
gdrive_hook=mock_gdrive_hook,
)
# Test writing to file
mock_gdrive_hook.get_media_request.assert_called_once_with(file_id=file_id)
mock_gdrive_hook.download_content_from_request.assert_called_once_with(
file_handle=mock_file_handle(),
request=mock_gdrive_hook.get_media_request.return_value,
chunk_size=104857600,
)

# Test upload
mock_gcs_hook.upload.assert_called_once_with(
bucket_name=BUCKET, object_name=OBJECT, data=mock_file_handle().getvalue(), mime_type=mime_type
)

@mock.patch("airflow.providers.google.cloud.transfers.gdrive_to_gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.gdrive_to_gcs.GoogleDriveHook")
@mock.patch(
"airflow.providers.google.cloud.transfers.gdrive_to_gcs.GoogleDriveToGCSOperator._upload_data"
)
def test_execute(self, mock_upload_data, mock_gdrive_hook, mock_gcs_hook):
def test_execute(self, mock_gdrive_hook, mock_gcs_hook):
context = {}
op = GoogleDriveToGCSOperator(
task_id="test_task",
Expand All @@ -83,15 +44,18 @@ def test_execute(self, mock_upload_data, mock_gdrive_hook, mock_gcs_hook):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
meta = {"id": "123xyz"}
mock_gdrive_hook.return_value.get_file_id.return_value = meta

op.execute(context)
mock_gdrive_hook.return_value.get_file_id.assert_called_once_with(
folder_id=FOLDER_ID, file_name=FILE_NAME, drive_id=DRIVE_ID
)

mock_gdrive_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
delegate_to=None,
impersonation_chain=IMPERSONATION_CHAIN,
mock_gdrive_hook.return_value.download_file.assert_called_once_with(
file_id=meta["id"], file_handle=mock.ANY
)
mock_gcs_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
delegate_to=None,
impersonation_chain=IMPERSONATION_CHAIN,

mock_gcs_hook.return_value.provide_file_and_upload.assert_called_once_with(
bucket_name=BUCKET, object_name=OBJECT
)
Expand Up @@ -35,13 +35,12 @@ def test_execute(self, hook_mock):
file_name=FILE_NAME,
output_file=temp_file.name,
)
meta = {"id": "123xyz"}
hook_mock.return_value.get_file_id.return_value = meta

op.execute(context=None)
hook_mock.assert_called_once_with(delegate_to=None, impersonation_chain=None)

hook_mock.return_value.get_file_id.assert_called_once_with(
folder_id=FOLDER_ID, file_name=FILE_NAME, drive_id=None
)

hook_mock.return_value.download_file.assert_called_once_with(
file_id=mock.ANY, file_handle=mock.ANY
file_id=meta["id"], file_handle=mock.ANY
)

0 comments on commit a7e4266

Please sign in to comment.