Modify transfer operators to handle more data (#22495)
* Modify transfer operators to handle more data

This addresses an issue where large data imports could fill all
available disk space and cause the task to fail.

Previously, all data was written to local disk before any of it was
uploaded to GCS. Now each data chunk is uploaded to GCS as soon as it is
written, and its local file is freed immediately.
mwallace582 committed Apr 4, 2022
1 parent 46cf931 commit 99b0211
Showing 3 changed files with 104 additions and 76 deletions.
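Both operators now follow the same consume-as-you-produce loop in execute(): flush, upload, and close the schema file first (if one is requested), then handle each data chunk the same way as soon as the generator yields it. A minimal sketch of that loop, with write_chunks and upload_to_gcs as hypothetical stand-ins for the operators' _write_local_data_files and _upload_to_gcs methods:

```python
import logging

log = logging.getLogger(__name__)


def execute(write_chunks, upload_to_gcs, schema_file=None):
    """Stream each chunk to GCS so at most one chunk sits on local disk at a time."""
    if schema_file is not None:
        schema_file['file_handle'].flush()
        upload_to_gcs(schema_file)
        schema_file['file_handle'].close()

    for counter, file_to_upload in enumerate(write_chunks()):
        file_to_upload['file_handle'].flush()      # make sure buffered rows hit the file
        log.info('Uploading chunk file #%d to GCS.', counter)
        upload_to_gcs(file_to_upload)
        file_to_upload['file_handle'].close()      # frees the chunk's disk space immediately
```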
71 changes: 46 additions & 25 deletions airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
@@ -169,21 +169,30 @@ def execute(self, context: 'Context'):

cursor = hook.get_conn().execute(self.cql, **query_extra)

files_to_upload = self._write_local_data_files(cursor)

# If a schema is set, create a BQ schema JSON file.
if self.schema_filename:
files_to_upload.update(self._write_local_schema_file(cursor))
self.log.info('Writing local schema file')
schema_file = self._write_local_schema_file(cursor)

# Flush file before uploading
schema_file['file_handle'].flush()

self.log.info('Uploading schema file to GCS.')
self._upload_to_gcs(schema_file)
schema_file['file_handle'].close()

# Flush all files before uploading
for file_handle in files_to_upload.values():
file_handle.flush()
counter = 0
self.log.info('Writing local data files')
for file_to_upload in self._write_local_data_files(cursor):
# Flush file before uploading
file_to_upload['file_handle'].flush()

self._upload_to_gcs(files_to_upload)
self.log.info('Uploading chunk file #%d to GCS.', counter)
self._upload_to_gcs(file_to_upload)

# Close all temp file handles.
for file_handle in files_to_upload.values():
file_handle.close()
self.log.info('Removing local file')
file_to_upload['file_handle'].close()
counter += 1

# Close all sessions and connection associated with this Cassandra cluster
hook.shutdown_cluster()
@@ -197,8 +206,12 @@ def _write_local_data_files(self, cursor):
contain the data for the GCS objects.
"""
file_no = 0

tmp_file_handle = NamedTemporaryFile(delete=True)
tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
file_to_upload = {
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
}
for row in cursor:
row_dict = self.generate_data_dict(row._fields, row)
content = json.dumps(row_dict).encode('utf-8')
@@ -209,10 +222,14 @@

if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
file_no += 1
tmp_file_handle = NamedTemporaryFile(delete=True)
tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle

return tmp_file_handles
yield file_to_upload
tmp_file_handle = NamedTemporaryFile(delete=True)
file_to_upload = {
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
}
yield file_to_upload

def _write_local_schema_file(self, cursor):
"""
@@ -231,22 +248,26 @@ def _write_local_schema_file(self, cursor):
json_serialized_schema = json.dumps(schema).encode('utf-8')

tmp_schema_file_handle.write(json_serialized_schema)
return {self.schema_filename: tmp_schema_file_handle}

def _upload_to_gcs(self, files_to_upload: Dict[str, Any]):
schema_file_to_upload = {
'file_name': self.schema_filename,
'file_handle': tmp_schema_file_handle,
}
return schema_file_to_upload

def _upload_to_gcs(self, file_to_upload):
"""Upload a file (data split or schema .json file) to Google Cloud Storage."""
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
for obj, tmp_file_handle in files_to_upload.items():
hook.upload(
bucket_name=self.bucket,
object_name=obj,
filename=tmp_file_handle.name,
mime_type='application/json',
gzip=self.gzip,
)
hook.upload(
bucket_name=self.bucket,
object_name=file_to_upload.get('file_name'),
filename=file_to_upload.get('file_handle').name,
mime_type='application/json',
gzip=self.gzip,
)

@classmethod
def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]:
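In both files the producer side of this pattern is a generator that rotates to a fresh NamedTemporaryFile once the current chunk reaches approx_max_file_size_bytes, as in the _write_local_data_files hunk above. A self-contained sketch of that rotation, assuming newline-delimited JSON and an illustrative 1 KiB chunk size (write_local_data_files, CHUNK_BYTES, and the filename template are simplifications, not the operator's exact code):

```python
import json
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, Iterator

CHUNK_BYTES = 1024  # illustrative; the operator uses approx_max_file_size_bytes


def write_local_data_files(
    rows: Iterable[Dict[str, Any]],
    filename_template: str = 'export_{}.json',
) -> Iterator[Dict[str, Any]]:
    """Yield {'file_name', 'file_handle'} dicts, one finished chunk at a time."""
    file_no = 0
    handle = NamedTemporaryFile(delete=True)
    for row in rows:
        handle.write(json.dumps(row).encode('utf-8') + b'\n')
        if handle.tell() >= CHUNK_BYTES:
            yield {'file_name': filename_template.format(file_no), 'file_handle': handle}
            file_no += 1
            handle = NamedTemporaryFile(delete=True)  # start the next chunk on a fresh temp file
    # final (possibly partial) chunk
    yield {'file_name': filename_template.format(file_no), 'file_handle': handle}
```

Because the temp files are created with delete=True, the close() the caller performs right after each upload removes the file from disk, which is what keeps local disk usage bounded to roughly one chunk at a time.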
87 changes: 44 additions & 43 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -139,24 +139,30 @@ def execute(self, context: 'Context'):
self.log.info("Executing query")
cursor = self.query()

self.log.info("Writing local data files")
files_to_upload = self._write_local_data_files(cursor)
# If a schema is set, create a BQ schema JSON file.
if self.schema_filename:
self.log.info("Writing local schema file")
files_to_upload.append(self._write_local_schema_file(cursor))
self.log.info('Writing local schema file')
schema_file = self._write_local_schema_file(cursor)

# Flush all files before uploading
for tmp_file in files_to_upload:
tmp_file['file_handle'].flush()
# Flush file before uploading
schema_file['file_handle'].flush()

self.log.info("Uploading %d files to GCS.", len(files_to_upload))
self._upload_to_gcs(files_to_upload)
self.log.info('Uploading schema file to GCS.')
self._upload_to_gcs(schema_file)
schema_file['file_handle'].close()

self.log.info("Removing local files")
# Close all temp file handles.
for tmp_file in files_to_upload:
tmp_file['file_handle'].close()
counter = 0
self.log.info('Writing local data files')
for file_to_upload in self._write_local_data_files(cursor):
# Flush file before uploading
file_to_upload['file_handle'].flush()

self.log.info('Uploading chunk file #%d to GCS.', counter)
self._upload_to_gcs(file_to_upload)

self.log.info('Removing local file')
file_to_upload['file_handle'].close()
counter += 1

def convert_types(self, schema, col_type_dict, row) -> list:
"""Convert values from DBAPI to output-friendly formats."""
@@ -181,14 +187,11 @@ def _write_local_data_files(self, cursor):
file_mime_type = 'application/octet-stream'
else:
file_mime_type = 'application/json'
files_to_upload = [
{
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
'file_mime_type': file_mime_type,
}
]
self.log.info("Current file count: %d", len(files_to_upload))
file_to_upload = {
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
'file_mime_type': file_mime_type,
}

if self.export_format == 'csv':
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
@@ -225,20 +228,22 @@
if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
file_no += 1

if self.export_format == 'parquet':
parquet_writer.close()
yield file_to_upload
tmp_file_handle = NamedTemporaryFile(delete=True)
files_to_upload.append(
{
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
'file_mime_type': file_mime_type,
}
)
self.log.info("Current file count: %d", len(files_to_upload))
file_to_upload = {
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
'file_mime_type': file_mime_type,
}
if self.export_format == 'csv':
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == 'parquet':
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
return files_to_upload
if self.export_format == 'parquet':
parquet_writer.close()
yield file_to_upload

def _configure_csv_file(self, file_handle, schema):
"""Configure a csv writer with the file_handle and write schema
@@ -338,21 +343,17 @@ def _write_local_schema_file(self, cursor):
}
return schema_file_to_upload

def _upload_to_gcs(self, files_to_upload):
"""
Upload all of the file splits (and optionally the schema .json file) to
Google Cloud Storage.
"""
def _upload_to_gcs(self, file_to_upload):
"""Upload a file (data split or schema .json file) to Google Cloud Storage."""
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
for tmp_file in files_to_upload:
hook.upload(
self.bucket,
tmp_file.get('file_name'),
tmp_file.get('file_handle').name,
mime_type=tmp_file.get('file_mime_type'),
gzip=self.gzip if tmp_file.get('file_name') != self.schema_filename else False,
)
hook.upload(
self.bucket,
file_to_upload.get('file_name'),
file_to_upload.get('file_handle').name,
mime_type=file_to_upload.get('file_mime_type'),
gzip=self.gzip if file_to_upload.get('file_name') != self.schema_filename else False,
)
22 changes: 14 additions & 8 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -29,7 +29,7 @@

SQL = "SELECT * FROM test_table"
BUCKET = "TEST-BUCKET-1"
FILENAME = "test_results.csv"
FILENAME = "test_results_{}.csv"
TASK_ID = "TEST_TASK_ID"
SCHEMA = [
{"name": "column_a", "type": "3"},
@@ -137,9 +137,13 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
]
)
mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
csv_call = mock.call(BUCKET, FILENAME, TMP_FILE_NAME, mime_type='text/csv', gzip=True)
csv_calls = []
for i in range(0, 3):
csv_calls.append(
mock.call(BUCKET, FILENAME.format(i), TMP_FILE_NAME, mime_type='text/csv', gzip=True)
)
json_call = mock.call(BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
upload_calls = [csv_call, csv_call, csv_call, json_call]
upload_calls = [json_call, csv_calls[0], csv_calls[1], csv_calls[2]]
mock_upload.assert_has_calls(upload_calls)
mock_close.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])

@@ -169,7 +173,9 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
]
)
mock_flush.assert_called_once()
mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
mock_upload.assert_called_once_with(
BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type=APP_JSON, gzip=False
)
mock_close.assert_called_once()

mock_query.reset_mock()
@@ -189,7 +195,7 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
mock_query.assert_called_once()
mock_flush.assert_called_once()
mock_upload.assert_called_once_with(
BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
)
mock_close.assert_called_once()

@@ -233,7 +239,7 @@ def test__write_local_data_files_csv(self):
cursor.description = CURSOR_DESCRIPTION

files = op._write_local_data_files(cursor)
file = files[0]['file_handle']
file = next(files)['file_handle']
file.flush()
df = pd.read_csv(file.name)
assert df.equals(OUTPUT_DF)
@@ -255,7 +261,7 @@ def test__write_local_data_files_json(self):
cursor.description = CURSOR_DESCRIPTION

files = op._write_local_data_files(cursor)
file = files[0]['file_handle']
file = next(files)['file_handle']
file.flush()
df = pd.read_json(file.name, orient='records', lines=True)
assert df.equals(OUTPUT_DF)
@@ -277,7 +283,7 @@ def test__write_local_data_files_parquet(self):
cursor.description = CURSOR_DESCRIPTION

files = op._write_local_data_files(cursor)
file = files[0]['file_handle']
file = next(files)['file_handle']
file.flush()
df = pd.read_parquet(file.name)
assert df.equals(OUTPUT_DF)

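Since _write_local_data_files now returns a generator, the tests pull chunks with next(files) instead of indexing files[0], and test_exec expects the schema upload first followed by one upload per numbered chunk. A rough sketch of testing a generator-returning method in the same spirit (the names and harness here are illustrative, not the Airflow test code):

```python
from unittest import mock


def test_chunks_uploaded_in_order():
    # write_local_data_files / upload_to_gcs stand in for the operator's methods.
    chunks = [
        {'file_name': f'test_results_{i}.csv', 'file_handle': mock.MagicMock()} for i in range(3)
    ]
    write_local_data_files = mock.MagicMock(return_value=iter(chunks))
    upload_to_gcs = mock.MagicMock()

    files = write_local_data_files()
    first = next(files)  # an iterator is not subscriptable, so files[0] would raise TypeError
    assert first['file_name'] == 'test_results_0.csv'

    # Consume the rest the way execute() does: flush, upload, then close each handle.
    for chunk in files:
        chunk['file_handle'].flush()
        upload_to_gcs(chunk['file_name'])
        chunk['file_handle'].close()

    upload_to_gcs.assert_has_calls(
        [mock.call('test_results_1.csv'), mock.call('test_results_2.csv')]
    )
```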