From 1750328bbc7f8a1125f8e0c38024ced8e195a1b9 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Tue, 13 Feb 2024 17:15:47 +0530 Subject: [PATCH] feat: Untyped param (#1001) * changes * change * tests * tests * changes * change * lint * lint --------- Co-authored-by: surbhigarg92 --- google/cloud/spanner_v1/database.py | 2 -- google/cloud/spanner_v1/snapshot.py | 4 --- google/cloud/spanner_v1/transaction.py | 5 --- tests/system/test_session_api.py | 50 ++++++++++++++++++++++++-- tests/unit/test_database.py | 4 --- tests/unit/test_snapshot.py | 22 ------------ tests/unit/test_transaction.py | 24 ------------- 7 files changed, 48 insertions(+), 63 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1ef2754a6e..650b4fda4c 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -648,8 +648,6 @@ def execute_partitioned_dml( if params is not None: from google.cloud.spanner_v1.transaction import Transaction - if param_types is None: - raise ValueError("Specify 'param_types' when passing 'params'.") params_pb = Transaction._make_params_pb(params, param_types) else: params_pb = {} diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 491ff37d4a..2b6e1ce924 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -410,8 +410,6 @@ def execute_sql( raise ValueError("Transaction ID pending.") if params is not None: - if param_types is None: - raise ValueError("Specify 'param_types' when passing 'params'.") params_pb = Struct( fields={key: _make_value_pb(value) for key, value in params.items()} ) @@ -646,8 +644,6 @@ def partition_query( raise ValueError("Transaction not started.") if params is not None: - if param_types is None: - raise ValueError("Specify 'param_types' when passing 'params'.") params_pb = Struct( fields={key: _make_value_pb(value) for (key, value) in params.items()} ) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 3c950401ac..1f5ff1098a 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -276,14 +276,9 @@ def _make_params_pb(params, param_types): If ``params`` is None but ``param_types`` is not None. """ if params is not None: - if param_types is None: - raise ValueError("Specify 'param_types' when passing 'params'.") return Struct( fields={key: _make_value_pb(value) for key, value in params.items()} ) - else: - if param_types is not None: - raise ValueError("Specify 'params' when passing 'param_types'.") return {} diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 9ea66b65ec..29d196b011 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -90,6 +90,8 @@ "jsonb_array", ) +QUERY_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[1:17:2] + AllTypesRowData = collections.namedtuple("AllTypesRowData", LIVE_ALL_TYPES_COLUMNS) AllTypesRowData.__new__.__defaults__ = tuple([None for colum in LIVE_ALL_TYPES_COLUMNS]) EmulatorAllTypesRowData = collections.namedtuple( @@ -211,6 +213,17 @@ PostGresAllTypesRowData(pkey=309, jsonb_array=[JSON_1, JSON_2, None]), ) +QUERY_ALL_TYPES_DATA = ( + 123, + False, + BYTES_1, + SOME_DATE, + 1.4142136, + "VALUE", + SOME_TIME, + NUMERIC_1, +) + if _helpers.USE_EMULATOR: ALL_TYPES_COLUMNS = EMULATOR_ALL_TYPES_COLUMNS ALL_TYPES_ROWDATA = EMULATOR_ALL_TYPES_ROWDATA @@ -475,6 +488,39 @@ def test_batch_insert_or_update_then_query(sessions_database): sd._check_rows_data(rows) +def test_batch_insert_then_read_wo_param_types( + sessions_database, database_dialect, not_emulator +): + sd = _sample_data + + with sessions_database.batch() as batch: + batch.delete(ALL_TYPES_TABLE, sd.ALL) + batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA) + + with sessions_database.snapshot(multi_use=True) as snapshot: + for column_type, value in list( + zip(QUERY_ALL_TYPES_COLUMNS, QUERY_ALL_TYPES_DATA) + ): + placeholder = ( + "$1" if database_dialect == DatabaseDialect.POSTGRESQL else "@value" + ) + sql = ( + "SELECT * FROM " + + ALL_TYPES_TABLE + + " WHERE " + + column_type + + " = " + + placeholder + ) + param = ( + {"p1": value} + if database_dialect == DatabaseDialect.POSTGRESQL + else {"value": value} + ) + rows = list(snapshot.execute_sql(sql, params=param)) + assert len(rows) == 1 + + def test_batch_insert_w_commit_timestamp(sessions_database, not_postgres): table = "users_history" columns = ["id", "commit_ts", "name", "email", "deleted"] @@ -1930,8 +1976,8 @@ def _check_sql_results( database, sql, params, - param_types, - expected, + param_types=None, + expected=None, order=True, recurse_into_lists=True, ): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 00c57797ef..6bcacd379b 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1136,10 +1136,6 @@ def _execute_partitioned_dml_helper( def test_execute_partitioned_dml_wo_params(self): self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) - def test_execute_partitioned_dml_w_params_wo_param_types(self): - with self.assertRaises(ValueError): - self._execute_partitioned_dml_helper(dml=DML_W_PARAM, params=PARAMS) - def test_execute_partitioned_dml_w_params_and_param_types(self): self._execute_partitioned_dml_helper( dml=DML_W_PARAM, params=PARAMS, param_types=PARAM_TYPES diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index aec20c2f54..bf5563dcfd 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -868,16 +868,6 @@ def test_execute_sql_other_error(self): attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) - def test_execute_sql_w_params_wo_param_types(self): - database = _Database() - session = _Session(database) - derived = self._makeDerived(session) - - with self.assertRaises(ValueError): - derived.execute_sql(SQL_QUERY_WITH_PARAM, PARAMS) - - self.assertNoSpans() - def _execute_sql_helper( self, multi_use, @@ -1397,18 +1387,6 @@ def test_partition_query_other_error(self): attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) - def test_partition_query_w_params_wo_param_types(self): - database = _Database() - session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = True - derived._transaction_id = TXN_ID - - with self.assertRaises(ValueError): - list(derived.partition_query(SQL_QUERY_WITH_PARAM, PARAMS)) - - self.assertNoSpans() - def test_partition_query_single_use_raises(self): with self.assertRaises(ValueError): self._partition_query_helper(multi_use=False, w_txn=True) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d391fe4c13..a673eabb83 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -471,20 +471,6 @@ def test_commit_w_incorrect_tag_dictionary_error(self): with self.assertRaises(ValueError): self._commit_helper(request_options=request_options) - def test__make_params_pb_w_params_wo_param_types(self): - session = _Session() - transaction = self._make_one(session) - - with self.assertRaises(ValueError): - transaction._make_params_pb(PARAMS, None) - - def test__make_params_pb_wo_params_w_param_types(self): - session = _Session() - transaction = self._make_one(session) - - with self.assertRaises(ValueError): - transaction._make_params_pb(None, PARAM_TYPES) - def test__make_params_pb_w_params_w_param_types(self): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1._helpers import _make_value_pb @@ -510,16 +496,6 @@ def test_execute_update_other_error(self): with self.assertRaises(RuntimeError): transaction.execute_update(DML_QUERY) - def test_execute_update_w_params_wo_param_types(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - - with self.assertRaises(ValueError): - transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) - def _execute_update_helper( self, count=0,