25
25
import uuid
26
26
import warnings
27
27
from datetime import datetime
28
- from typing import Any , Dict , Iterable , List , Optional , Sequence , Set , SupportsAbs , Union
28
+ from typing import Any , Dict , Iterable , List , Optional , Sequence , Set , SupportsAbs , Union , cast
29
29
30
30
import attr
31
31
from google .api_core .exceptions import Conflict
@@ -91,7 +91,7 @@ def get_link(self, operator: BaseOperator, dttm: datetime):
91
91
92
92
93
93
class _BigQueryDbHookMixin :
94
- def get_db_hook (self ) -> BigQueryHook :
94
+ def get_db_hook (self : 'BigQueryCheckOperator' ) -> BigQueryHook : # type:ignore[misc]
95
95
"""Get BigQuery DB Hook"""
96
96
return BigQueryHook (
97
97
gcp_conn_id = self .gcp_conn_id ,
@@ -948,7 +948,8 @@ def execute(self, context) -> None:
948
948
delegate_to = self .delegate_to ,
949
949
impersonation_chain = self .impersonation_chain ,
950
950
)
951
- schema_fields = json .loads (gcs_hook .download (gcs_bucket , gcs_object ).decode ("utf-8" ))
951
+ schema_fields_string = gcs_hook .download_as_byte_array (gcs_bucket , gcs_object ).decode ("utf-8" )
952
+ schema_fields = json .loads (schema_fields_string )
952
953
else :
953
954
schema_fields = self .schema_fields
954
955
@@ -1090,8 +1091,8 @@ def __init__(
1090
1091
self ,
1091
1092
* ,
1092
1093
bucket : Optional [str ] = None ,
1093
- source_objects : Optional [List ] = None ,
1094
- destination_project_dataset_table : str = None ,
1094
+ source_objects : Optional [List [ str ] ] = None ,
1095
+ destination_project_dataset_table : Optional [ str ] = None ,
1095
1096
table_resource : Optional [Dict [str , Any ]] = None ,
1096
1097
schema_fields : Optional [List ] = None ,
1097
1098
schema_object : Optional [str ] = None ,
@@ -1115,11 +1116,6 @@ def __init__(
1115
1116
) -> None :
1116
1117
super ().__init__ (** kwargs )
1117
1118
1118
- # GCS config
1119
- self .bucket = bucket
1120
- self .source_objects = source_objects
1121
- self .schema_object = schema_object
1122
-
1123
1119
# BQ config
1124
1120
kwargs_passed = any (
1125
1121
[
@@ -1158,22 +1154,30 @@ def __init__(
1158
1154
skip_leading_rows = 0
1159
1155
if not field_delimiter :
1160
1156
field_delimiter = ","
1157
+ if not destination_project_dataset_table :
1158
+ raise ValueError (
1159
+ "`destination_project_dataset_table` is required when not using `table_resource`."
1160
+ )
1161
+ self .bucket = bucket
1162
+ self .source_objects = source_objects
1163
+ self .schema_object = schema_object
1164
+ self .destination_project_dataset_table = destination_project_dataset_table
1165
+ self .schema_fields = schema_fields
1166
+ self .source_format = source_format
1167
+ self .compression = compression
1168
+ self .skip_leading_rows = skip_leading_rows
1169
+ self .field_delimiter = field_delimiter
1170
+ self .table_resource = None
1171
+ else :
1172
+ self .table_resource = table_resource
1161
1173
1162
1174
if table_resource and kwargs_passed :
1163
1175
raise ValueError ("You provided both `table_resource` and exclusive keywords arguments." )
1164
1176
1165
- self .table_resource = table_resource
1166
- self .destination_project_dataset_table = destination_project_dataset_table
1167
- self .schema_fields = schema_fields
1168
- self .source_format = source_format
1169
- self .compression = compression
1170
- self .skip_leading_rows = skip_leading_rows
1171
- self .field_delimiter = field_delimiter
1172
1177
self .max_bad_records = max_bad_records
1173
1178
self .quote_character = quote_character
1174
1179
self .allow_quoted_newlines = allow_quoted_newlines
1175
1180
self .allow_jagged_rows = allow_jagged_rows
1176
-
1177
1181
self .bigquery_conn_id = bigquery_conn_id
1178
1182
self .google_cloud_storage_conn_id = google_cloud_storage_conn_id
1179
1183
self .delegate_to = delegate_to
@@ -1203,7 +1207,10 @@ def execute(self, context) -> None:
1203
1207
delegate_to = self .delegate_to ,
1204
1208
impersonation_chain = self .impersonation_chain ,
1205
1209
)
1206
- schema_fields = json .loads (gcs_hook .download (self .bucket , self .schema_object ).decode ("utf-8" ))
1210
+ schema_fields_bytes_or_string = gcs_hook .download (self .bucket , self .schema_object )
1211
+ if hasattr (schema_fields_bytes_or_string , 'decode' ):
1212
+ schema_fields_bytes_or_string = cast (bytes , schema_fields_bytes_or_string ).decode ("utf-8" )
1213
+ schema_fields = json .loads (schema_fields_bytes_or_string )
1207
1214
else :
1208
1215
schema_fields = self .schema_fields
1209
1216
0 commit comments