Skip to content

Commit

Permalink
fix: Prevent sending full table scan when retrying (backport #554) (#697
Browse files Browse the repository at this point in the history
)

Update the retry logic. Don't send empty row_key and empty row_ranges
if the original message didn't ask for those.

Closes internal issue 214449800

* Create InvalidRetryRequest exception.
Raise InvalidRetryRequest instead of StopIteration
Catch the InvalidRetryRequest
Handle stop the retry request if row_limit has been reached.
  • Loading branch information
igorbernstein2 committed Nov 18, 2022
1 parent 14415e9 commit c4ae6ad
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
40 changes: 28 additions & 12 deletions google/cloud/bigtable/row_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ class InvalidChunk(RuntimeError):
"""Exception raised to to invalid chunk data from back-end."""


class InvalidRetryRequest(RuntimeError):
"""Exception raised when retry request is invalid."""


def _retry_read_rows_exception(exc):
if isinstance(exc, grpc.RpcError):
exc = exceptions.from_grpc_error(exc)
Expand Down Expand Up @@ -487,6 +491,9 @@ def __iter__(self):
if self.state != self.NEW_ROW:
raise ValueError("The row remains partial / is not committed.")
break
except InvalidRetryRequest:
self._cancelled = True
break

for chunk in response.chunks:
if self._cancelled:
Expand Down Expand Up @@ -625,29 +632,38 @@ def __init__(self, message, last_scanned_key, rows_read_so_far):

def build_updated_request(self):
"""Updates the given message request as per last scanned key"""
r_kwargs = {
"table_name": self.message.table_name,
"filter": self.message.filter,
}

resume_request = data_messages_v2_pb2.ReadRowsRequest()
data_messages_v2_pb2.ReadRowsRequest.CopyFrom(resume_request, self.message)
resume_request.rows.Clear()

if self.message.rows_limit != 0:
r_kwargs["rows_limit"] = max(
1, self.message.rows_limit - self.rows_read_so_far
)
row_limit_remaining = self.message.rows_limit - self.rows_read_so_far
if row_limit_remaining > 0:
resume_request.rows_limit = row_limit_remaining
else:
raise InvalidRetryRequest

# if neither RowSet.row_keys nor RowSet.row_ranges currently exist,
# add row_range that starts with last_scanned_key as start_key_open
# to request only rows that have not been returned yet
if not self.message.HasField("rows"):
row_range = data_v2_pb2.RowRange(start_key_open=self.last_scanned_key)
r_kwargs["rows"] = data_v2_pb2.RowSet(row_ranges=[row_range])
resume_request.rows.row_ranges.add().CopyFrom(row_range)
else:
row_keys = self._filter_rows_keys()
row_ranges = self._filter_row_ranges()
r_kwargs["rows"] = data_v2_pb2.RowSet(
row_keys=row_keys, row_ranges=row_ranges
)
return data_messages_v2_pb2.ReadRowsRequest(**r_kwargs)

if len(row_keys) == 0 and len(row_ranges) == 0:
# Avoid sending empty row_keys and row_ranges
# if that was not the intention
raise InvalidRetryRequest

resume_request.rows.row_keys[:] = row_keys
for rr in row_ranges:
resume_request.rows.row_ranges.add().CopyFrom(rr)

return resume_request

def _filter_rows_keys(self):
"""Helper for :meth:`build_updated_request`"""
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/test_row_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def test_build_updated_request_full_table(self):
request_manager = self._make_one(request, last_scanned_key, 2)

result = request_manager.build_updated_request()
expected_result = _ReadRowsRequestPB(table_name=self.table_name, filter={})
expected_result = _ReadRowsRequestPB(table_name=self.table_name)
expected_result.rows.row_ranges.add(start_key_open=last_scanned_key)
self.assertEqual(expected_result, result)

Expand Down Expand Up @@ -940,9 +940,7 @@ def test_build_updated_request_rows_limit(self):
request_manager = self._make_one(request, last_scanned_key, 2)

result = request_manager.build_updated_request()
expected_result = _ReadRowsRequestPB(
table_name=self.table_name, filter={}, rows_limit=8
)
expected_result = _ReadRowsRequestPB(table_name=self.table_name, rows_limit=8)
expected_result.rows.row_ranges.add(start_key_open=last_scanned_key)
self.assertEqual(expected_result, result)

Expand Down

0 comments on commit c4ae6ad

Please sign in to comment.