Fix MyPy for Google Bigquery (#20329)
Part of #19891
potiuk committed Dec 21, 2021
1 parent 479d9be commit c6dbb3f
Showing 5 changed files with 68 additions and 27 deletions.
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -1419,15 +1419,15 @@ def _build_new_schema(
         ) -> List[Dict[str, Any]]:

             # Turn schema_field_updates into a dict keyed on field names
-            schema_fields_updates = {field["name"]: field for field in deepcopy(schema_fields_updates)}
+            schema_fields_updates_dict = {field["name"]: field for field in deepcopy(schema_fields_updates)}

             # Create a new dict for storing the new schema, initiated based on the current_schema
             # as of Python 3.6, dicts retain order.
             new_schema = {field["name"]: field for field in deepcopy(current_schema)}

             # Each item in schema_fields_updates contains a potential patch
             # to a schema field, iterate over them
-            for field_name, patched_value in schema_fields_updates.items():
+            for field_name, patched_value in schema_fields_updates_dict.items():
                 # If this field already exists, update it
                 if field_name in new_schema:
                     # If this field is of type RECORD and has a fields key we need to patch it recursively
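The rename above is a routine MyPy fix: re-using `schema_fields_updates` for the keyed dict would give one name two different types inside the same function. A minimal sketch of the pattern, with illustrative names that are not from the Airflow codebase:

from typing import Any, Dict, List


def index_fields(fields: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    # Re-assigning `fields` here would change its type from a list to a dict
    # and trigger an "incompatible types in assignment" error under MyPy,
    # so the keyed view gets its own name instead.
    fields_by_name = {field["name"]: field for field in fields}
    return fields_by_name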
@@ -1822,7 +1822,7 @@ def run_load(
             var_name='destination_project_dataset_table',
         )

-        configuration = {
+        configuration: Dict[str, Any] = {
             'load': {
                 'autodetect': autodetect,
                 'createDisposition': create_disposition,
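The `Dict[str, Any]` annotation on `configuration` keeps MyPy from inferring a narrower value type from the dict literal and then rejecting later writes into it. A small sketch of the difference, using illustrative keys rather than the real BigQuery job configuration:

from typing import Any, Dict

# Without an annotation MyPy infers Dict[str, Dict[str, bool]] from the literal,
# so adding a string-valued key later is flagged as an incompatible assignment.
narrow = {"load": {"autodetect": True}}
narrow["load"]["sourceFormat"] = "CSV"  # MyPy error under the inferred type

# Annotating the variable up front keeps the nested values open-ended.
configuration: Dict[str, Any] = {"load": {"autodetect": True}}
configuration["load"]["sourceFormat"] = "CSV"  # accepted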
39 changes: 39 additions & 0 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -338,6 +338,45 @@ def download(
                 time.sleep(timeout_seconds)
                 continue

+    def download_as_byte_array(
+        self,
+        bucket_name: str,
+        object_name: str,
+        chunk_size: Optional[int] = None,
+        timeout: Optional[int] = DEFAULT_TIMEOUT,
+        num_max_attempts: Optional[int] = 1,
+    ) -> bytes:
+        """
+        Downloads a file from Google Cloud Storage.
+
+        When no filename is supplied, the operator loads the file into memory and returns its
+        content. When a filename is supplied, it writes the file to the specified location and
+        returns the location. For file sizes that exceed the available memory it is recommended
+        to write to a file.
+
+        :param bucket_name: The bucket to fetch from.
+        :type bucket_name: str
+        :param object_name: The object to fetch.
+        :type object_name: str
+        :param chunk_size: Blob chunk size.
+        :type chunk_size: int
+        :param timeout: Request timeout in seconds.
+        :type timeout: int
+        :param num_max_attempts: Number of attempts to download the file.
+        :type num_max_attempts: int
+        """
+        # We do not pass filename, so will never receive string as response
+        return cast(
+            bytes,
+            self.download(
+                bucket_name=bucket_name,
+                object_name=object_name,
+                chunk_size=chunk_size,
+                timeout=timeout,
+                num_max_attempts=num_max_attempts,
+            ),
+        )
+
     @_fallback_object_url_to_object_name_and_bucket_name()
     @contextmanager
     def provide_file(
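`download` returns either the object bytes or a filename string depending on its arguments, so its return type is a union; the new wrapper fixes the no-filename case and uses `typing.cast` to tell MyPy the result is always `bytes`. A self-contained sketch of that pattern, with invented function names rather than the hook's API:

from typing import Optional, Union, cast


def fetch(object_name: str, filename: Optional[str] = None) -> Union[str, bytes]:
    """Return the object contents, or write them to `filename` and return that path."""
    data = b"example payload"
    if filename is None:
        return data
    with open(filename, "wb") as handle:
        handle.write(data)
    return filename


def fetch_bytes(object_name: str) -> bytes:
    # No filename is passed, so the union collapses to bytes at runtime;
    # cast() records that fact for MyPy without adding any runtime conversion.
    return cast(bytes, fetch(object_name))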
45 changes: 26 additions & 19 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -25,7 +25,7 @@
 import uuid
 import warnings
 from datetime import datetime
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union, cast

 import attr
 from google.api_core.exceptions import Conflict
@@ -91,7 +91,7 @@ def get_link(self, operator: BaseOperator, dttm: datetime):


 class _BigQueryDbHookMixin:
-    def get_db_hook(self) -> BigQueryHook:
+    def get_db_hook(self: 'BigQueryCheckOperator') -> BigQueryHook:  # type:ignore[misc]
         """Get BigQuery DB Hook"""
         return BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
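Annotating `self` as `'BigQueryCheckOperator'` gives MyPy a concrete class to resolve `self.gcp_conn_id` and the other attributes against, since the bare mixin does not define them; the `type:ignore[misc]` silences the complaint that the declared self-type is narrower than the mixin itself. A hedged, self-contained sketch of the idea, with invented class names:

class HookProviderMixin:
    # Narrowing the self-type lets MyPy see `conn_id`, which only the concrete
    # operator defines; the ignore covers the "self-type is not a supertype"
    # warning that the narrower annotation produces.
    def describe_connection(self: "ConcreteOperator") -> str:  # type: ignore[misc]
        return f"using connection {self.conn_id}"


class ConcreteOperator(HookProviderMixin):
    def __init__(self, conn_id: str) -> None:
        self.conn_id = conn_id


print(ConcreteOperator("google_cloud_default").describe_connection())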
@@ -948,7 +948,8 @@ def execute(self, context) -> None:
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8"))
+            schema_fields_string = gcs_hook.download_as_byte_array(gcs_bucket, gcs_object).decode("utf-8")
+            schema_fields = json.loads(schema_fields_string)
         else:
             schema_fields = self.schema_fields

@@ -1090,8 +1091,8 @@ def __init__(
         self,
         *,
         bucket: Optional[str] = None,
-        source_objects: Optional[List] = None,
-        destination_project_dataset_table: str = None,
+        source_objects: Optional[List[str]] = None,
+        destination_project_dataset_table: Optional[str] = None,
         table_resource: Optional[Dict[str, Any]] = None,
         schema_fields: Optional[List] = None,
         schema_object: Optional[str] = None,
@@ -1115,11 +1116,6 @@ def __init__(
     ) -> None:
         super().__init__(**kwargs)

-        # GCS config
-        self.bucket = bucket
-        self.source_objects = source_objects
-        self.schema_object = schema_object
-
         # BQ config
         kwargs_passed = any(
             [
@@ -1158,22 +1154,30 @@
                 skip_leading_rows = 0
             if not field_delimiter:
                 field_delimiter = ","
+            if not destination_project_dataset_table:
+                raise ValueError(
+                    "`destination_project_dataset_table` is required when not using `table_resource`."
+                )
+            self.bucket = bucket
+            self.source_objects = source_objects
+            self.schema_object = schema_object
+            self.destination_project_dataset_table = destination_project_dataset_table
+            self.schema_fields = schema_fields
+            self.source_format = source_format
+            self.compression = compression
+            self.skip_leading_rows = skip_leading_rows
+            self.field_delimiter = field_delimiter
+            self.table_resource = None
+        else:
+            self.table_resource = table_resource

         if table_resource and kwargs_passed:
             raise ValueError("You provided both `table_resource` and exclusive keywords arguments.")

-        self.table_resource = table_resource
-        self.destination_project_dataset_table = destination_project_dataset_table
-        self.schema_fields = schema_fields
-        self.source_format = source_format
-        self.compression = compression
-        self.skip_leading_rows = skip_leading_rows
-        self.field_delimiter = field_delimiter
         self.max_bad_records = max_bad_records
         self.quote_character = quote_character
         self.allow_quoted_newlines = allow_quoted_newlines
         self.allow_jagged_rows = allow_jagged_rows

         self.bigquery_conn_id = bigquery_conn_id
         self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
         self.delegate_to = delegate_to
@@ -1203,7 +1207,10 @@ def execute(self, context) -> None:
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
+            schema_fields_bytes_or_string = gcs_hook.download(self.bucket, self.schema_object)
+            if hasattr(schema_fields_bytes_or_string, 'decode'):
+                schema_fields_bytes_or_string = cast(bytes, schema_fields_bytes_or_string).decode("utf-8")
+            schema_fields = json.loads(schema_fields_bytes_or_string)
         else:
             schema_fields = self.schema_fields

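Here the operator keeps calling `download`, whose result may be typed `Union[str, bytes]`, and narrows it at runtime with `hasattr(..., 'decode')` plus a `cast`. An `isinstance` check achieves the same narrowing and is what MyPy understands natively; a short sketch under the same bytes-or-str assumption:

import json
from typing import Any, List, Union


def parse_schema(payload: Union[str, bytes]) -> List[Any]:
    # isinstance() narrows the union for MyPy, so .decode() is only reached
    # when the payload really is bytes.
    if isinstance(payload, bytes):
        payload = payload.decode("utf-8")
    return json.loads(payload)


print(parse_schema(b'[{"name": "id", "type": "INTEGER"}]'))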
@@ -119,10 +119,6 @@
             "csvOptions": {"skipLeadingRows": 1},
         },
     },
-    source_objects=[
-        f"{EXPORT_PREFIX}/all_namespaces/kind_{EXPORT_COLLECTION_ID}"
-        f"/all_namespaces_kind_{EXPORT_COLLECTION_ID}.export_metadata"
-    ],
 )
 # [END howto_operator_create_external_table_multiple_types]

1 change: 0 additions & 1 deletion setup.cfg
@@ -252,7 +252,6 @@ no_implicit_optional = False
 [mypy-google.cloud.oslogin_v1.services.os_login_service.*]
 no_implicit_optional = False

-
 [isort]
 line_length=110
 combine_as_imports = true
