Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support partitioned dml in dbapi #1103

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
partition_ids,
)
if statement_type == ClientSideStatementType.RUN_PARTITION:
return connection.run_partition(
parsed_statement.client_side_statement_params[0]
)
return connection.run_partition(parsed_statement)
if statement_type == ClientSideStatementType.RUN_PARTITIONED_QUERY:
return connection.run_partitioned_query(parsed_statement)
if statement_type == ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE:
return connection.set_autocommit_dml_mode(parsed_statement)


def _get_streamed_result_set(column_name, type_code, column_values):
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
# Matches "RUN PARTITIONED QUERY <sql>" case-insensitively, allowing leading
# whitespace. Groups 1-3 capture the keywords; group 4 captures everything
# that follows the QUERY keyword.
RE_RUN_PARTITIONED_QUERY = re.compile(
    r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)
# Matches "SET AUTOCOMMIT_DML_MODE = <mode>" case-insensitively, allowing
# leading whitespace. Group 4 captures the requested mode. Whitespace around
# "=" is optional so the compact form "SET AUTOCOMMIT_DML_MODE=<mode>" is
# recognized as well; group numbering is unchanged.
RE_SET_AUTOCOMMIT_DML_MODE = re.compile(
    r"^\s*(SET)\s+(AUTOCOMMIT_DML_MODE)\s*(=)\s*(.+)", re.IGNORECASE
)


def parse_stmt(query):
Expand Down Expand Up @@ -82,6 +85,10 @@ def parse_stmt(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params.append(match.group(3))
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
elif RE_SET_AUTOCOMMIT_DML_MODE.match(query):
match = re.search(RE_SET_AUTOCOMMIT_DML_MODE, query)
client_side_statement_params.append(match.group(4))
client_side_statement_type = ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE,
Expand Down
31 changes: 30 additions & 1 deletion google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import (
StatementType,
AutocommitDmlMode,
)
from google.cloud.spanner_dbapi.partition_helper import PartitionId
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(self, instance, database=None, read_only=False):
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionRetryHelper(self)
self._autocommit_dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL

@property
def spanner_client(self):
Expand Down Expand Up @@ -155,6 +157,14 @@ def database(self):
"""
return self._database

@property
def autocommit_dml_mode(self):
    """The autocommit DML mode currently configured on this connection.

    :rtype: :class:`~google.cloud.spanner_dbapi.parsed_statement.AutocommitDmlMode`
    :returns: The value held in ``_autocommit_dml_mode``.
    """
    return self._autocommit_dml_mode

@property
@deprecated(
reason="This method is deprecated. Use _spanner_transaction_started field"
Expand Down Expand Up @@ -540,7 +550,8 @@ def partition_query(
return partition_ids

@check_not_closed
def run_partition(self, encoded_partition_id):
def run_partition(self, parsed_statement: ParsedStatement):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change of a public method

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted

encoded_partition_id = parsed_statement.client_side_statement_params[0]
partition_id: PartitionId = partition_helper.decode_from_string(
encoded_partition_id
)
Expand All @@ -565,6 +576,24 @@ def run_partitioned_query(
partitioned_query, statement.params, statement.param_types
)

@check_not_closed
def set_autocommit_dml_mode(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it would make more sense to have a public method that takes an AutocommitDmlMode enum value as an argument, rather than one that takes a parsed statement as an argument. This method looks like a private method that should only be called by our own parser.

Instead, we should:

  1. Have a private method that is basically the same as this.
  2. Have a public method that takes an AutocommitDmlMode enum as an input argument and that actually changes the value of the flag. That method would then also be usable for anyone using the dbapi driver programmatically, as they could just call that method to change the value.

See https://github.com/googleapis/java-spanner-jdbc/blob/910a130a02f72c4d8764f12e347f6c3d1bd51b2b/src/main/java/com/google/cloud/spanner/jdbc/CloudSpannerJdbcConnection.java#L136 for how the API in the JDBC driver looks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

ankiaga marked this conversation as resolved.
Show resolved Hide resolved
self,
parsed_statement: ParsedStatement,
):
if self._client_transaction_started is True:
raise ProgrammingError(
"Cannot set autocommit DML mode while not in autocommit mode or while a transaction is active."
)
if self.read_only is True:
raise ProgrammingError(
"Cannot set autocommit DML mode for a read-only connection."
)
if self._batch_mode is not BatchMode.NONE:
raise ProgrammingError("Cannot set autocommit DML mode while in a batch.")
autocommit_dml_mode_str = parsed_statement.client_side_statement_params[0]
self._autocommit_dml_mode = AutocommitDmlMode[autocommit_dml_mode_str.upper()]

def _partitioned_query_validation(self, partitioned_query, statement):
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
StatementType,
Statement,
ParsedStatement,
AutocommitDmlMode,
)
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
Expand Down Expand Up @@ -272,6 +273,17 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
self._batch_DDLs(sql)
if not self.connection._client_transaction_started:
self.connection.run_prior_DDL_statements()
elif (
self.connection.autocommit_dml_mode
is AutocommitDmlMode.PARTITIONED_NON_ATOMIC
):
self._row_count = self.connection.database.execute_partitioned_dml(
sql,
params=args,
param_types=self._parsed_statement.statement.param_types,
request_options=self.connection.request_options,
)
self._result_set = None
else:
self._execute_in_rw_transaction()

Expand Down
6 changes: 6 additions & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class ClientSideStatementType(Enum):
PARTITION_QUERY = 9
RUN_PARTITION = 10
RUN_PARTITIONED_QUERY = 11
SET_AUTOCOMMIT_DML_MODE = 12


class AutocommitDmlMode(Enum):
    """How a DML statement is executed while the connection is in autocommit mode."""

    # The statement runs in a regular read/write transaction.
    TRANSACTIONAL = 1
    # The statement is sent through the database's partitioned DML API.
    PARTITIONED_NON_ATOMIC = 2


@dataclass
Expand Down
22 changes: 22 additions & 0 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
OperationalError,
RetryAborted,
)
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1 import gapic_version as package_version
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
Expand Down Expand Up @@ -669,6 +670,27 @@ def test_run_partitioned_query(self):
assert len(rows) == 10
self._conn.commit()

def test_partitioned_dml_query(self):
    """Test that a partitioned DML statement works in autocommit mode."""
    # Seed ten rows (contact ids 1..10) through a DML batch and commit them.
    self._cursor.execute("start batch dml")
    for i in range(1, 11):
        self._insert_row(i)
    self._cursor.execute("run batch")
    self._conn.commit()

    # With autocommit on, switch to partitioned DML and delete ids 4..10.
    self._conn.autocommit = True
    self._cursor.execute("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC")
    self._cursor.execute("DELETE FROM contacts WHERE contact_id > 3")
    assert self._cursor.rowcount == 7

    # Switching back must be reflected by the connection property.
    self._cursor.execute("set autocommit_dml_mode = TRANSACTIONAL")
    assert self._conn.autocommit_dml_mode == AutocommitDmlMode.TRANSACTIONAL

    self._conn.autocommit = False
    # Changing autocommit_dml_mode is not allowed when the connection is NOT in autocommit mode
    with pytest.raises(ProgrammingError):
        self._cursor.execute("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC")

def _insert_row(self, i):
self._cursor.execute(
f"""
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
ParsedStatement,
StatementType,
Statement,
ClientSideStatementType,
AutocommitDmlMode,
)

PROJECT = "test-project"
Expand Down Expand Up @@ -414,6 +416,62 @@ def test_abort_dml_batch(self, mock_batch_dml_executor):
self.assertEqual(self._under_test._batch_mode, BatchMode.NONE)
self.assertEqual(self._under_test._batch_dml_executor, None)

def test_set_autocommit_dml_mode_with_autocommit_false(self):
    """Setting the autocommit DML mode must fail when autocommit is off."""
    statement = ParsedStatement(
        StatementType.CLIENT_SIDE,
        Statement("sql"),
        ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
        ["PARTITIONED_NON_ATOMIC"],
    )
    self._under_test.autocommit = False

    self.assertRaises(
        ProgrammingError,
        self._under_test.set_autocommit_dml_mode,
        statement,
    )

def test_set_autocommit_dml_mode_with_readonly(self):
    """Setting the autocommit DML mode must fail on a read-only connection."""
    connection = self._under_test
    connection.autocommit = True
    connection.read_only = True
    statement = ParsedStatement(
        StatementType.CLIENT_SIDE,
        Statement("sql"),
        ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
        ["PARTITIONED_NON_ATOMIC"],
    )

    self.assertRaises(
        ProgrammingError,
        connection.set_autocommit_dml_mode,
        statement,
    )

def test_set_autocommit_dml_mode_with_batch_mode(self):
    """Setting the autocommit DML mode must fail while a batch is active."""
    self._under_test.autocommit = True
    # Simulate an active batch. Without this the test merely repeats the
    # success case and never reaches the batch-mode guard.
    self._under_test._batch_mode = BatchMode.DML
    parsed_statement = ParsedStatement(
        StatementType.CLIENT_SIDE,
        Statement("sql"),
        ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
        ["PARTITIONED_NON_ATOMIC"],
    )

    with self.assertRaises(ProgrammingError):
        self._under_test.set_autocommit_dml_mode(parsed_statement)

    # The rejected statement must leave the default mode untouched.
    self.assertEqual(
        self._under_test.autocommit_dml_mode,
        AutocommitDmlMode.TRANSACTIONAL,
    )

def test_set_autocommit_dml_mode(self):
    """In autocommit mode the requested DML mode is stored on the connection."""
    connection = self._under_test
    connection.autocommit = True
    statement = ParsedStatement(
        StatementType.CLIENT_SIDE,
        Statement("sql"),
        ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
        ["PARTITIONED_NON_ATOMIC"],
    )

    connection.set_autocommit_dml_mode(statement)

    self.assertEqual(
        connection.autocommit_dml_mode,
        AutocommitDmlMode.PARTITIONED_NON_ATOMIC,
    )

@mock.patch("google.cloud.spanner_v1.database.Database", autospec=True)
def test_run_prior_DDL_statements(self, mock_database):
from google.cloud.spanner_dbapi import Connection, InterfaceError
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ def test_run_partitioned_query_classify_stmt(self):
),
)

def test_set_autocommit_dml_mode_stmt(self):
    """A padded SET AUTOCOMMIT_DML_MODE statement is classified as client-side."""
    expected = ParsedStatement(
        StatementType.CLIENT_SIDE,
        Statement("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC"),
        ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
        ["PARTITIONED_NON_ATOMIC"],
    )

    actual = classify_statement(
        " set autocommit_dml_mode = PARTITIONED_NON_ATOMIC "
    )

    self.assertEqual(actual, expected)

@unittest.skipIf(skip_condition, skip_message)
def test_sql_pyformat_args_to_spanner(self):
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
Expand Down