Commit c6dbb3f

Fix MyPy for Google Bigquery (#20329)
Part of #19891
1 parent 479d9be commit c6dbb3f

5 files changed: +68 additions, −27 deletions

airflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 3 additions & 3 deletions

@@ -1419,15 +1419,15 @@ def _build_new_schema(
         ) -> List[Dict[str, Any]]:
 
             # Turn schema_field_updates into a dict keyed on field names
-            schema_fields_updates = {field["name"]: field for field in deepcopy(schema_fields_updates)}
+            schema_fields_updates_dict = {field["name"]: field for field in deepcopy(schema_fields_updates)}
 
             # Create a new dict for storing the new schema, initiated based on the current_schema
             # as of Python 3.6, dicts retain order.
             new_schema = {field["name"]: field for field in deepcopy(current_schema)}
 
             # Each item in schema_fields_updates contains a potential patch
             # to a schema field, iterate over them
-            for field_name, patched_value in schema_fields_updates.items():
+            for field_name, patched_value in schema_fields_updates_dict.items():
                 # If this field already exists, update it
                 if field_name in new_schema:
                     # If this field is of type RECORD and has a fields key we need to patch it recursively

@@ -1822,7 +1822,7 @@ def run_load(
             var_name='destination_project_dataset_table',
         )
 
-        configuration = {
+        configuration: Dict[str, Any] = {
             'load': {
                 'autodetect': autodetect,
                 'createDisposition': create_disposition,
airflow/providers/google/cloud/hooks/gcs.py

Lines changed: 39 additions & 0 deletions

@@ -338,6 +338,45 @@ def download(
                 time.sleep(timeout_seconds)
                 continue
 
+    def download_as_byte_array(
+        self,
+        bucket_name: str,
+        object_name: str,
+        chunk_size: Optional[int] = None,
+        timeout: Optional[int] = DEFAULT_TIMEOUT,
+        num_max_attempts: Optional[int] = 1,
+    ) -> bytes:
+        """
+        Downloads a file from Google Cloud Storage.
+
+        When no filename is supplied, the operator loads the file into memory and returns its
+        content. When a filename is supplied, it writes the file to the specified location and
+        returns the location. For file sizes that exceed the available memory it is recommended
+        to write to a file.
+
+        :param bucket_name: The bucket to fetch from.
+        :type bucket_name: str
+        :param object_name: The object to fetch.
+        :type object_name: str
+        :param chunk_size: Blob chunk size.
+        :type chunk_size: int
+        :param timeout: Request timeout in seconds.
+        :type timeout: int
+        :param num_max_attempts: Number of attempts to download the file.
+        :type num_max_attempts: int
+        """
+        # We do not pass filename, so will never receive string as response
+        return cast(
+            bytes,
+            self.download(
+                bucket_name=bucket_name,
+                object_name=object_name,
+                chunk_size=chunk_size,
+                timeout=timeout,
+                num_max_attempts=num_max_attempts,
+            ),
+        )
+
     @_fallback_object_url_to_object_name_and_bucket_name()
     @contextmanager
     def provide_file(
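
A brief usage sketch of the new helper; the connection id, bucket, and object names below are made up for illustration. Because download_as_byte_array always returns bytes, the .decode() call type-checks without the bytes-or-str ambiguity of download() itself:

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook(gcp_conn_id="google_cloud_default")  # hypothetical connection id
    schema_json = hook.download_as_byte_array(
        bucket_name="example-bucket", object_name="schemas/my_table.json"
    ).decode("utf-8")
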

airflow/providers/google/cloud/operators/bigquery.py

Lines changed: 26 additions & 19 deletions

@@ -25,7 +25,7 @@
 import uuid
 import warnings
 from datetime import datetime
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union, cast
 
 import attr
 from google.api_core.exceptions import Conflict

@@ -91,7 +91,7 @@ def get_link(self, operator: BaseOperator, dttm: datetime):
 
 
 class _BigQueryDbHookMixin:
-    def get_db_hook(self) -> BigQueryHook:
+    def get_db_hook(self: 'BigQueryCheckOperator') -> BigQueryHook:  # type:ignore[misc]
         """Get BigQuery DB Hook"""
         return BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
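
Annotating self with a concrete operator type is what lets mypy resolve attributes such as gcp_conn_id that the mixin itself never defines; the type:ignore[misc] silences mypy's complaint that, roughly, the declared self type is not a supertype of the mixin class. A minimal standalone sketch of the same trick (the class names here are illustrative, not the operator hierarchy from this commit):

    class _ConnIdMixin:
        def get_conn_id(self: "MyOperator") -> str:  # type: ignore[misc]
            # mypy checks this access against MyOperator, which does define conn_id
            return self.conn_id


    class MyOperator(_ConnIdMixin):
        def __init__(self, conn_id: str) -> None:
            self.conn_id = conn_id


    print(MyOperator("my_conn").get_conn_id())  # prints: my_conn
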

@@ -948,7 +948,8 @@ def execute(self, context) -> None:
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8"))
+            schema_fields_string = gcs_hook.download_as_byte_array(gcs_bucket, gcs_object).decode("utf-8")
+            schema_fields = json.loads(schema_fields_string)
         else:
             schema_fields = self.schema_fields
 

@@ -1090,8 +1091,8 @@ def __init__(
         self,
         *,
         bucket: Optional[str] = None,
-        source_objects: Optional[List] = None,
-        destination_project_dataset_table: str = None,
+        source_objects: Optional[List[str]] = None,
+        destination_project_dataset_table: Optional[str] = None,
         table_resource: Optional[Dict[str, Any]] = None,
         schema_fields: Optional[List] = None,
         schema_object: Optional[str] = None,
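
The signature change above also removes an implicit-Optional annotation: destination_project_dataset_table: str = None declares str while defaulting to None, which mypy flags under no_implicit_optional (a setting this repo's setup.cfg only relaxes for a few third-party modules, as the setup.cfg section below shows). In isolation the pattern looks roughly like this sketch, with a hypothetical function name:

    from typing import Optional

    # Rejected when no_implicit_optional is in effect:
    # def load(destination_project_dataset_table: str = None) -> None: ...

    # Explicit Optional makes the None default well-typed:
    def load(destination_project_dataset_table: Optional[str] = None) -> None:
        if destination_project_dataset_table is None:
            raise ValueError("`destination_project_dataset_table` is required.")
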

@@ -1115,11 +1116,6 @@ def __init__(
     ) -> None:
         super().__init__(**kwargs)
 
-        # GCS config
-        self.bucket = bucket
-        self.source_objects = source_objects
-        self.schema_object = schema_object
-
         # BQ config
         kwargs_passed = any(
             [

@@ -1158,22 +1154,30 @@ def __init__(
                 skip_leading_rows = 0
             if not field_delimiter:
                 field_delimiter = ","
+            if not destination_project_dataset_table:
+                raise ValueError(
+                    "`destination_project_dataset_table` is required when not using `table_resource`."
+                )
+            self.bucket = bucket
+            self.source_objects = source_objects
+            self.schema_object = schema_object
+            self.destination_project_dataset_table = destination_project_dataset_table
+            self.schema_fields = schema_fields
+            self.source_format = source_format
+            self.compression = compression
+            self.skip_leading_rows = skip_leading_rows
+            self.field_delimiter = field_delimiter
+            self.table_resource = None
+        else:
+            self.table_resource = table_resource
 
         if table_resource and kwargs_passed:
             raise ValueError("You provided both `table_resource` and exclusive keywords arguments.")
 
-        self.table_resource = table_resource
-        self.destination_project_dataset_table = destination_project_dataset_table
-        self.schema_fields = schema_fields
-        self.source_format = source_format
-        self.compression = compression
-        self.skip_leading_rows = skip_leading_rows
-        self.field_delimiter = field_delimiter
         self.max_bad_records = max_bad_records
         self.quote_character = quote_character
         self.allow_quoted_newlines = allow_quoted_newlines
         self.allow_jagged_rows = allow_jagged_rows
-
         self.bigquery_conn_id = bigquery_conn_id
         self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
         self.delegate_to = delegate_to

@@ -1203,7 +1207,10 @@ def execute(self, context) -> None:
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
+            schema_fields_bytes_or_string = gcs_hook.download(self.bucket, self.schema_object)
+            if hasattr(schema_fields_bytes_or_string, 'decode'):
+                schema_fields_bytes_or_string = cast(bytes, schema_fields_bytes_or_string).decode("utf-8")
+            schema_fields = json.loads(schema_fields_bytes_or_string)
         else:
             schema_fields = self.schema_fields
 
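In this last hunk the generic download() call is kept, and since its result can be str or bytes depending on whether a filename is passed, the code narrows the value before json.loads. The same narrowing pattern in isolation, with a hypothetical helper name; an isinstance check would narrow for mypy without the explicit cast:

    import json
    from typing import Union, cast

    def parse_schema(payload: Union[str, bytes]) -> list:
        # Mirror of the hasattr/cast narrowing used in the hunk above
        if hasattr(payload, "decode"):
            payload = cast(bytes, payload).decode("utf-8")
        return json.loads(payload)

    print(parse_schema(b'[{"name": "id", "type": "INTEGER"}]'))
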
airflow/providers/google/firebase/example_dags/example_firestore.py

Lines changed: 0 additions & 4 deletions

@@ -119,10 +119,6 @@
             "csvOptions": {"skipLeadingRows": 1},
         },
     },
-    source_objects=[
-        f"{EXPORT_PREFIX}/all_namespaces/kind_{EXPORT_COLLECTION_ID}"
-        f"/all_namespaces_kind_{EXPORT_COLLECTION_ID}.export_metadata"
-    ],
 )
 # [END howto_operator_create_external_table_multiple_types]
 
setup.cfg

Lines changed: 0 additions & 1 deletion

@@ -252,7 +252,6 @@ no_implicit_optional = False
 [mypy-google.cloud.oslogin_v1.services.os_login_service.*]
 no_implicit_optional = False
 
-
 [isort]
 line_length=110
 combine_as_imports = true