 from google.cloud.environment_vars import BIGTABLE_EMULATOR
 from google.cloud.bigtable.data import BigtableDataClientAsync
 from google.cloud.bigtable.data._cross_sync import CrossSync
+from helpers import sql_encoding_helpers

 if not CrossSync.is_async:
     from client_handler_data_async import error_safe
@@ -32,6 +33,7 @@ def error_safe(func):
     Catch and pass errors back to the grpc_server_process
     Also check if client is closed before processing requests
     """
+
     async def wrapper(self, *args, **kwargs):
         try:
             if self.closed:
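For context, the decorator touched here wraps every handler method. A minimal sketch of the pattern, reconstructed from the context lines above (the `closed` flag and `encode_exception` appear in this file; the rest is assumed, not the verbatim implementation):

```python
import functools

def error_safe(func):
    """
    Catch and pass errors back to the grpc_server_process

    Also check if client is closed before processing requests
    """
    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        try:
            if self.closed:
                raise RuntimeError("client is closed")
            return await func(self, *args, **kwargs)
        except Exception as e:
            # encode_exception is the module-level helper edited in the next hunk
            return encode_exception(e)
    return wrapper
```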
@@ -50,6 +52,7 @@ def encode_exception(exc):
     Encode an exception or chain of exceptions to pass back to grpc_handler
     """
     from google.api_core.exceptions import GoogleAPICallError
+
     error_msg = f"{type(exc).__name__}: {exc}"
     result = {"error": error_msg}
     if exc.__cause__:
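Only the head of `encode_exception` is visible in this hunk. As a hedged sketch of the whole pattern: the `__cause__` walk is in the context lines, while the status-code handling is an assumption (`GoogleAPICallError.grpc_status_code` is real google-api-core API):

```python
from google.api_core.exceptions import GoogleAPICallError

def encode_exception(exc):
    """Encode an exception or chain of exceptions to pass back to grpc_handler"""
    result = {"error": f"{type(exc).__name__}: {exc}"}
    if exc.__cause__:
        # recurse so the full exception chain survives the trip over the wire
        result["cause"] = encode_exception(exc.__cause__)
    if isinstance(exc, GoogleAPICallError) and exc.grpc_status_code is not None:
        # grpc.StatusCode.value is a (code, name) tuple
        result["code"] = exc.grpc_status_code.value[0]
    return result
```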
@@ -113,7 +116,9 @@ async def ReadRows(self, request, **kwargs):
         table_id = request.pop("table_name").split("/")[-1]
         app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
         result_list = CrossSync.rm_aio(await table.read_rows(request, **kwargs))
         # pack results back into protobuf-parsable format
         serialized_response = [row._to_dict() for row in result_list]
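This reflowed timeout line (repeated in the hunks below) keeps the same fallback chain: an explicit `operation_timeout` in `kwargs` wins, then the handler-wide `per_operation_timeout`, then a 20-second default. The trailing `or 20` also maps an explicit `None` to the default, which plain `dict.get` alone would not. A standalone illustration (the function name is ours, for demonstration only):

```python
def resolve_timeout(kwargs: dict, per_operation_timeout=None, default=20):
    # dict.get only falls back when the key is missing, so the `or` is
    # what turns a stored None (or 0) into the default
    return kwargs.get("operation_timeout", per_operation_timeout) or default

assert resolve_timeout({"operation_timeout": 5}) == 5
assert resolve_timeout({}, per_operation_timeout=7) == 7
assert resolve_timeout({}) == 20
assert resolve_timeout({"operation_timeout": None}) == 20
```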
@@ -124,7 +129,9 @@ async def ReadRow(self, row_key, **kwargs):
         table_id = kwargs.pop("table_name").split("/")[-1]
         app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
         result_row = CrossSync.rm_aio(await table.read_row(row_key, **kwargs))
         # pack results back into protobuf-parsable format
         if result_row:
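Every client call in these handlers is wrapped in `CrossSync.rm_aio`. Judging from how it is used, it is an identity marker at runtime that the library's async-to-sync code generator keys on; roughly (an assumption, not the real implementation):

```python
from typing import TypeVar

T = TypeVar("T")

def rm_aio(value: T) -> T:
    # runtime no-op: returns the argument unchanged. The call site is a
    # marker telling the cross-sync generator to strip the wrapped `await`
    # when it emits the synchronous variant of this file.
    return value
```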
@@ -135,10 +142,13 @@ async def ReadRow(self, row_key, **kwargs):
     @error_safe
     async def MutateRow(self, request, **kwargs):
         from google.cloud.bigtable.data.mutations import Mutation
+
         table_id = request["table_name"].split("/")[-1]
         app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
         row_key = request["row_key"]
         mutations = [Mutation._from_dict(d) for d in request["mutations"]]
         CrossSync.rm_aio(await table.mutate_row(row_key, mutations, **kwargs))
@@ -147,21 +157,29 @@ async def MutateRow(self, request, **kwargs):
     @error_safe
     async def BulkMutateRows(self, request, **kwargs):
         from google.cloud.bigtable.data.mutations import RowMutationEntry
+
         table_id = request["table_name"].split("/")[-1]
         app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
-        entry_list = [RowMutationEntry._from_dict(entry) for entry in request["entries"]]
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
+        entry_list = [
+            RowMutationEntry._from_dict(entry) for entry in request["entries"]
+        ]
         CrossSync.rm_aio(await table.bulk_mutate_rows(entry_list, **kwargs))
         return "OK"

     @error_safe
     async def CheckAndMutateRow(self, request, **kwargs):
         from google.cloud.bigtable.data.mutations import Mutation, SetCell
+
         table_id = request["table_name"].split("/")[-1]
         app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
         row_key = request["row_key"]
         # add default values for incomplete dicts, so they can still be parsed to objects
         true_mutations = []
@@ -180,33 +198,44 @@ async def CheckAndMutateRow(self, request, **kwargs):
                 # invalid mutation type. Conformance test may be sending generic empty request
                 false_mutations.append(SetCell("", "", "", 0))
         predicate_filter = request.get("predicate_filter", None)
-        result = CrossSync.rm_aio(await table.check_and_mutate_row(
-            row_key,
-            predicate_filter,
-            true_case_mutations=true_mutations,
-            false_case_mutations=false_mutations,
-            **kwargs,
-        ))
+        result = CrossSync.rm_aio(
+            await table.check_and_mutate_row(
+                row_key,
+                predicate_filter,
+                true_case_mutations=true_mutations,
+                false_case_mutations=false_mutations,
+                **kwargs,
+            )
+        )
         return result

     @error_safe
     async def ReadModifyWriteRow(self, request, **kwargs):
         from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
         from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+
         table_id = request["table_name"].split("/")[-1]
         app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
         row_key = request["row_key"]
         rules = []
         for rule_dict in request.get("rules", []):
             qualifier = rule_dict["column_qualifier"]
             if "append_value" in rule_dict:
-                new_rule = AppendValueRule(rule_dict["family_name"], qualifier, rule_dict["append_value"])
+                new_rule = AppendValueRule(
+                    rule_dict["family_name"], qualifier, rule_dict["append_value"]
+                )
             else:
-                new_rule = IncrementRule(rule_dict["family_name"], qualifier, rule_dict["increment_amount"])
+                new_rule = IncrementRule(
+                    rule_dict["family_name"], qualifier, rule_dict["increment_amount"]
+                )
             rules.append(new_rule)
-        result = CrossSync.rm_aio(await table.read_modify_write_row(row_key, rules, **kwargs))
+        result = CrossSync.rm_aio(
+            await table.read_modify_write_row(row_key, rules, **kwargs)
+        )
         # pack results back into protobuf-parsable format
         if result:
             return result._to_dict()
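The rule-building loop above dispatches on which key the request dict carries. With the real rule classes, the two branches amount to the following (the sample dicts are ours, shaped like the conformance-test JSON):

```python
from google.cloud.bigtable.data.read_modify_write_rules import (
    AppendValueRule,
    IncrementRule,
)

# hypothetical request fragments
append_dict = {"family_name": "cf", "column_qualifier": b"q", "append_value": b"-suffix"}
increment_dict = {"family_name": "cf", "column_qualifier": b"q", "increment_amount": 1}

append_rule = AppendValueRule(
    append_dict["family_name"], append_dict["column_qualifier"], append_dict["append_value"]
)
increment_rule = IncrementRule(
    increment_dict["family_name"], increment_dict["column_qualifier"], increment_dict["increment_amount"]
)
```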
@@ -218,6 +247,55 @@ async def SampleRowKeys(self, request, **kwargs):
         table_id = request["table_name"].split("/")[-1]
         app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
         table = self.client.get_table(self.instance_id, table_id, app_profile_id)
-        kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        kwargs["operation_timeout"] = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
         result = CrossSync.rm_aio(await table.sample_row_keys(**kwargs))
         return result
+
+    @error_safe
+    async def ExecuteQuery(self, request, **kwargs):
+        app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
+        query = request.get("query")
+        params = request.get("params") or {}
+        # Note that the request has been converted to json, and the code for this converts
+        # query param names to snake case. convert_params reverses this conversion. For this
+        # reason, snake case params will have issues if they're used in the conformance tests.
+        formatted_params, parameter_types = sql_encoding_helpers.convert_params(params)
+        operation_timeout = (
+            kwargs.get("operation_timeout", self.per_operation_timeout) or 20
+        )
+        result = CrossSync.rm_aio(
+            await self.client.execute_query(
+                query,
+                self.instance_id,
+                parameters=formatted_params,
+                parameter_types=parameter_types,
+                app_profile_id=app_profile_id,
+                operation_timeout=operation_timeout,
+                prepare_operation_timeout=operation_timeout,
+            )
+        )
+        rows = [r async for r in result]
+        md = result.metadata
+        proto_rows = []
+        for r in rows:
+            vals = []
+            for c in md.columns:
+                vals.append(sql_encoding_helpers.convert_value(c.column_type, r[c.column_name]))
+
+            proto_rows.append({"values": vals})
+
+        proto_columns = []
+        for c in md.columns:
+            proto_columns.append(
+                {
+                    "name": c.column_name,
+                    "type": sql_encoding_helpers.convert_type(c.column_type),
+                }
+            )
+
+        return {
+            "metadata": {"columns": proto_columns},
+            "rows": proto_rows,
+        }
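`helpers/sql_encoding_helpers.py` is not part of this hunk, so the following is purely a hypothetical sketch of the name round-trip the comment above describes; the real `convert_params` also has to produce Bigtable `SqlType` objects for `parameter_types`:

```python
# Hypothetical illustration only, not the real helpers module.
def to_camel_case(name: str) -> str:
    head, *rest = name.split("_")
    return head + "".join(part.title() for part in rest)

def convert_params(params: dict):
    formatted, types = {}, {}
    for name, value in params.items():
        restored = to_camel_case(name)  # undo the JSON snake_casing
        formatted[restored] = value
        types[restored] = type(value)  # stand-in for a real SqlType lookup
    return formatted, types

# a param the tests send as "paramOne" arrives here as "param_one"
assert convert_params({"param_one": 1})[0] == {"paramOne": 1}
```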