Skip to content

Commit 7a91bbf

Browse files
feat: Implement SQL support in test proxy (#1106)
1 parent adf816c commit 7a91bbf

12 files changed

+1300
-331
lines changed

test_proxy/handlers/client_handler_data_async.py

Lines changed: 96 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.cloud.environment_vars import BIGTABLE_EMULATOR
2020
from google.cloud.bigtable.data import BigtableDataClientAsync
2121
from google.cloud.bigtable.data._cross_sync import CrossSync
22+
from helpers import sql_encoding_helpers
2223

2324
if not CrossSync.is_async:
2425
from client_handler_data_async import error_safe
@@ -32,6 +33,7 @@ def error_safe(func):
3233
Catch and pass errors back to the grpc_server_process
3334
Also check if client is closed before processing requests
3435
"""
36+
3537
async def wrapper(self, *args, **kwargs):
3638
try:
3739
if self.closed:
@@ -50,6 +52,7 @@ def encode_exception(exc):
5052
Encode an exception or chain of exceptions to pass back to grpc_handler
5153
"""
5254
from google.api_core.exceptions import GoogleAPICallError
55+
5356
error_msg = f"{type(exc).__name__}: {exc}"
5457
result = {"error": error_msg}
5558
if exc.__cause__:
@@ -113,7 +116,9 @@ async def ReadRows(self, request, **kwargs):
113116
table_id = request.pop("table_name").split("/")[-1]
114117
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
115118
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
116-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
119+
kwargs["operation_timeout"] = (
120+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
121+
)
117122
result_list = CrossSync.rm_aio(await table.read_rows(request, **kwargs))
118123
# pack results back into protobuf-parsable format
119124
serialized_response = [row._to_dict() for row in result_list]
@@ -124,7 +129,9 @@ async def ReadRow(self, row_key, **kwargs):
124129
table_id = kwargs.pop("table_name").split("/")[-1]
125130
app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None)
126131
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
127-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
132+
kwargs["operation_timeout"] = (
133+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
134+
)
128135
result_row = CrossSync.rm_aio(await table.read_row(row_key, **kwargs))
129136
# pack results back into protobuf-parsable format
130137
if result_row:
@@ -135,10 +142,13 @@ async def ReadRow(self, row_key, **kwargs):
135142
@error_safe
136143
async def MutateRow(self, request, **kwargs):
137144
from google.cloud.bigtable.data.mutations import Mutation
145+
138146
table_id = request["table_name"].split("/")[-1]
139147
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
140148
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
141-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
149+
kwargs["operation_timeout"] = (
150+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
151+
)
142152
row_key = request["row_key"]
143153
mutations = [Mutation._from_dict(d) for d in request["mutations"]]
144154
CrossSync.rm_aio(await table.mutate_row(row_key, mutations, **kwargs))
@@ -147,21 +157,29 @@ async def MutateRow(self, request, **kwargs):
147157
@error_safe
148158
async def BulkMutateRows(self, request, **kwargs):
149159
from google.cloud.bigtable.data.mutations import RowMutationEntry
160+
150161
table_id = request["table_name"].split("/")[-1]
151162
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
152163
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
153-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
154-
entry_list = [RowMutationEntry._from_dict(entry) for entry in request["entries"]]
164+
kwargs["operation_timeout"] = (
165+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
166+
)
167+
entry_list = [
168+
RowMutationEntry._from_dict(entry) for entry in request["entries"]
169+
]
155170
CrossSync.rm_aio(await table.bulk_mutate_rows(entry_list, **kwargs))
156171
return "OK"
157172

158173
@error_safe
159174
async def CheckAndMutateRow(self, request, **kwargs):
160175
from google.cloud.bigtable.data.mutations import Mutation, SetCell
176+
161177
table_id = request["table_name"].split("/")[-1]
162178
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
163179
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
164-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
180+
kwargs["operation_timeout"] = (
181+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
182+
)
165183
row_key = request["row_key"]
166184
# add default values for incomplete dicts, so they can still be parsed to objects
167185
true_mutations = []
@@ -180,33 +198,44 @@ async def CheckAndMutateRow(self, request, **kwargs):
180198
# invalid mutation type. Conformance test may be sending generic empty request
181199
false_mutations.append(SetCell("", "", "", 0))
182200
predicate_filter = request.get("predicate_filter", None)
183-
result = CrossSync.rm_aio(await table.check_and_mutate_row(
184-
row_key,
185-
predicate_filter,
186-
true_case_mutations=true_mutations,
187-
false_case_mutations=false_mutations,
188-
**kwargs,
189-
))
201+
result = CrossSync.rm_aio(
202+
await table.check_and_mutate_row(
203+
row_key,
204+
predicate_filter,
205+
true_case_mutations=true_mutations,
206+
false_case_mutations=false_mutations,
207+
**kwargs,
208+
)
209+
)
190210
return result
191211

192212
@error_safe
193213
async def ReadModifyWriteRow(self, request, **kwargs):
194214
from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
195215
from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
216+
196217
table_id = request["table_name"].split("/")[-1]
197218
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
198219
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
199-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
220+
kwargs["operation_timeout"] = (
221+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
222+
)
200223
row_key = request["row_key"]
201224
rules = []
202225
for rule_dict in request.get("rules", []):
203226
qualifier = rule_dict["column_qualifier"]
204227
if "append_value" in rule_dict:
205-
new_rule = AppendValueRule(rule_dict["family_name"], qualifier, rule_dict["append_value"])
228+
new_rule = AppendValueRule(
229+
rule_dict["family_name"], qualifier, rule_dict["append_value"]
230+
)
206231
else:
207-
new_rule = IncrementRule(rule_dict["family_name"], qualifier, rule_dict["increment_amount"])
232+
new_rule = IncrementRule(
233+
rule_dict["family_name"], qualifier, rule_dict["increment_amount"]
234+
)
208235
rules.append(new_rule)
209-
result = CrossSync.rm_aio(await table.read_modify_write_row(row_key, rules, **kwargs))
236+
result = CrossSync.rm_aio(
237+
await table.read_modify_write_row(row_key, rules, **kwargs)
238+
)
210239
# pack results back into protobuf-parsable format
211240
if result:
212241
return result._to_dict()
@@ -218,6 +247,55 @@ async def SampleRowKeys(self, request, **kwargs):
218247
table_id = request["table_name"].split("/")[-1]
219248
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
220249
table = self.client.get_table(self.instance_id, table_id, app_profile_id)
221-
kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20
250+
kwargs["operation_timeout"] = (
251+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
252+
)
222253
result = CrossSync.rm_aio(await table.sample_row_keys(**kwargs))
223254
return result
255+
256+
@error_safe
257+
async def ExecuteQuery(self, request, **kwargs):
258+
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
259+
query = request.get("query")
260+
params = request.get("params") or {}
261+
# Note that the request has been converted to json, and the code for this converts
262+
# query param names to snake case. convert_params reverses this conversion. For this
263+
# reason, snake case params will have issues if they're used in the conformance tests.
264+
formatted_params, parameter_types = sql_encoding_helpers.convert_params(params)
265+
operation_timeout = (
266+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
267+
)
268+
result = CrossSync.rm_aio(
269+
await self.client.execute_query(
270+
query,
271+
self.instance_id,
272+
parameters=formatted_params,
273+
parameter_types=parameter_types,
274+
app_profile_id=app_profile_id,
275+
operation_timeout=operation_timeout,
276+
prepare_operation_timeout=operation_timeout,
277+
)
278+
)
279+
rows = [r async for r in result]
280+
md = result.metadata
281+
proto_rows = []
282+
for r in rows:
283+
vals = []
284+
for c in md.columns:
285+
vals.append(sql_encoding_helpers.convert_value(c.column_type, r[c.column_name]))
286+
287+
proto_rows.append({"values": vals})
288+
289+
proto_columns = []
290+
for c in md.columns:
291+
proto_columns.append(
292+
{
293+
"name": c.column_name,
294+
"type": sql_encoding_helpers.convert_type(c.column_type),
295+
}
296+
)
297+
298+
return {
299+
"metadata": {"columns": proto_columns},
300+
"rows": proto_rows,
301+
}

test_proxy/handlers/client_handler_data_sync_autogen.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
from google.cloud.environment_vars import BIGTABLE_EMULATOR
2222
from google.cloud.bigtable.data._cross_sync import CrossSync
23+
from helpers import sql_encoding_helpers
2324
from client_handler_data_async import error_safe
2425

2526

@@ -183,3 +184,43 @@ async def SampleRowKeys(self, request, **kwargs):
183184
)
184185
result = table.sample_row_keys(**kwargs)
185186
return result
187+
188+
@error_safe
189+
async def ExecuteQuery(self, request, **kwargs):
190+
app_profile_id = self.app_profile_id or request.get("app_profile_id", None)
191+
query = request.get("query")
192+
params = request.get("params") or {}
193+
(formatted_params, parameter_types) = sql_encoding_helpers.convert_params(
194+
params
195+
)
196+
operation_timeout = (
197+
kwargs.get("operation_timeout", self.per_operation_timeout) or 20
198+
)
199+
result = self.client.execute_query(
200+
query,
201+
self.instance_id,
202+
parameters=formatted_params,
203+
parameter_types=parameter_types,
204+
app_profile_id=app_profile_id,
205+
operation_timeout=operation_timeout,
206+
prepare_operation_timeout=operation_timeout,
207+
)
208+
rows = [r async for r in result]
209+
md = result.metadata
210+
proto_rows = []
211+
for r in rows:
212+
vals = []
213+
for c in md.columns:
214+
vals.append(
215+
sql_encoding_helpers.convert_value(c.column_type, r[c.column_name])
216+
)
217+
proto_rows.append({"values": vals})
218+
proto_columns = []
219+
for c in md.columns:
220+
proto_columns.append(
221+
{
222+
"name": c.column_name,
223+
"type": sql_encoding_helpers.convert_type(c.column_type),
224+
}
225+
)
226+
return {"metadata": {"columns": proto_columns}, "rows": proto_rows}

test_proxy/handlers/grpc_handler.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import time
32

43
import test_proxy_pb2
@@ -59,7 +58,6 @@ def wrapper(self, request, context, **kwargs):
5958

6059
return wrapper
6160

62-
6361
@delegate_to_client_handler
6462
def CreateClient(self, request, context, client_response=None):
6563
return test_proxy_pb2.CreateClientResponse()
@@ -80,15 +78,18 @@ def ReadRows(self, request, context, client_response=None):
8078
status = Status(code=5, message=client_response["error"])
8179
else:
8280
rows = [data_pb2.Row(**d) for d in client_response]
83-
result = test_proxy_pb2.RowsResult(row=rows, status=status)
81+
result = test_proxy_pb2.RowsResult(rows=rows, status=status)
8482
return result
8583

8684
@delegate_to_client_handler
8785
def ReadRow(self, request, context, client_response=None):
8886
status = Status()
8987
row = None
9088
if isinstance(client_response, dict) and "error" in client_response:
91-
status=Status(code=client_response.get("code", 5), message=client_response.get("error"))
89+
status = Status(
90+
code=client_response.get("code", 5),
91+
message=client_response.get("error"),
92+
)
9293
elif client_response != "None":
9394
row = data_pb2.Row(**client_response)
9495
result = test_proxy_pb2.RowResult(row=row, status=status)
@@ -98,38 +99,57 @@ def ReadRow(self, request, context, client_response=None):
9899
def MutateRow(self, request, context, client_response=None):
99100
status = Status()
100101
if isinstance(client_response, dict) and "error" in client_response:
101-
status = Status(code=client_response.get("code", 5), message=client_response["error"])
102+
status = Status(
103+
code=client_response.get("code", 5), message=client_response["error"]
104+
)
102105
return test_proxy_pb2.MutateRowResult(status=status)
103106

104107
@delegate_to_client_handler
105108
def BulkMutateRows(self, request, context, client_response=None):
106109
status = Status()
107110
entries = []
108111
if isinstance(client_response, dict) and "error" in client_response:
109-
entries = [bigtable_pb2.MutateRowsResponse.Entry(index=exc_dict.get("index",1), status=Status(code=exc_dict.get("code", 5))) for exc_dict in client_response.get("subexceptions", [])]
112+
entries = [
113+
bigtable_pb2.MutateRowsResponse.Entry(
114+
index=exc_dict.get("index", 1),
115+
status=Status(code=exc_dict.get("code", 5)),
116+
)
117+
for exc_dict in client_response.get("subexceptions", [])
118+
]
110119
if not entries:
111120
# only return failure on the overall request if there are failed entries
112-
status = Status(code=client_response.get("code", 5), message=client_response["error"])
113-
# TODO: protos were updated. entry is now entries: https://github.com/googleapis/cndb-client-testing-protos/commit/e6205a2bba04acc10d12421a1402870b4a525fb3
114-
response = test_proxy_pb2.MutateRowsResult(status=status, entry=entries)
121+
status = Status(
122+
code=client_response.get("code", 5),
123+
message=client_response["error"],
124+
)
125+
response = test_proxy_pb2.MutateRowsResult(status=status, entries=entries)
115126
return response
116127

117128
@delegate_to_client_handler
118129
def CheckAndMutateRow(self, request, context, client_response=None):
119130
if isinstance(client_response, dict) and "error" in client_response:
120-
status = Status(code=client_response.get("code", 5), message=client_response["error"])
131+
status = Status(
132+
code=client_response.get("code", 5), message=client_response["error"]
133+
)
121134
response = test_proxy_pb2.CheckAndMutateRowResult(status=status)
122135
else:
123-
result = bigtable_pb2.CheckAndMutateRowResponse(predicate_matched=client_response)
124-
response = test_proxy_pb2.CheckAndMutateRowResult(result=result, status=Status())
136+
result = bigtable_pb2.CheckAndMutateRowResponse(
137+
predicate_matched=client_response
138+
)
139+
response = test_proxy_pb2.CheckAndMutateRowResult(
140+
result=result, status=Status()
141+
)
125142
return response
126143

127144
@delegate_to_client_handler
128145
def ReadModifyWriteRow(self, request, context, client_response=None):
129146
status = Status()
130147
row = None
131148
if isinstance(client_response, dict) and "error" in client_response:
132-
status = Status(code=client_response.get("code", 5), message=client_response.get("error"))
149+
status = Status(
150+
code=client_response.get("code", 5),
151+
message=client_response.get("error"),
152+
)
133153
elif client_response != "None":
134154
row = data_pb2.Row(**client_response)
135155
result = test_proxy_pb2.RowResult(row=row, status=status)
@@ -140,9 +160,26 @@ def SampleRowKeys(self, request, context, client_response=None):
140160
status = Status()
141161
sample_list = []
142162
if isinstance(client_response, dict) and "error" in client_response:
143-
status = Status(code=client_response.get("code", 5), message=client_response.get("error"))
163+
status = Status(
164+
code=client_response.get("code", 5),
165+
message=client_response.get("error"),
166+
)
144167
else:
145168
for sample in client_response:
146-
sample_list.append(bigtable_pb2.SampleRowKeysResponse(offset_bytes=sample[1], row_key=sample[0]))
147-
# TODO: protos were updated. sample is now samples: https://github.com/googleapis/cndb-client-testing-protos/commit/e6205a2bba04acc10d12421a1402870b4a525fb3
148-
return test_proxy_pb2.SampleRowKeysResult(status=status, sample=sample_list)
169+
sample_list.append(
170+
bigtable_pb2.SampleRowKeysResponse(
171+
offset_bytes=sample[1], row_key=sample[0]
172+
)
173+
)
174+
return test_proxy_pb2.SampleRowKeysResult(status=status, samples=sample_list)
175+
176+
@delegate_to_client_handler
177+
def ExecuteQuery(self, request, context, client_response=None):
178+
if isinstance(client_response, dict) and "error" in client_response:
179+
return test_proxy_pb2.ExecuteQueryResult(
180+
status=Status(code=13, message=client_response["error"])
181+
)
182+
else:
183+
return test_proxy_pb2.ExecuteQueryResult(
184+
metadata=client_response["metadata"], rows=client_response["rows"]
185+
)

0 commit comments

Comments
 (0)