Skip to content

Commit

Permalink
chore: support named schemas (#1073)
Browse files Browse the repository at this point in the history
* chore: support named schemas

* chore: import type and typecode

* fix: use magic string instead of method reference as default value

* fix: dialect property now also reloads the database

* Comment addressed

* Fix test

---------

Co-authored-by: Ankit Agarwal <[email protected]>
Co-authored-by: ankiaga <[email protected]>
  • Loading branch information
3 people committed Feb 12, 2024
1 parent 9299212 commit 2bf0319
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 37 deletions.
4 changes: 2 additions & 2 deletions google/cloud/spanner_dbapi/_helpers.py
Expand Up @@ -18,13 +18,13 @@
# Lists the tables of a single schema. Callers must bind the ``table_schema``
# query parameter (STRING) to the schema name.
SQL_LIST_TABLES = """
SELECT table_name
FROM information_schema.tables
WHERE table_catalog = '' AND table_schema = @table_schema
"""

# Describes the columns of one table. The placeholder names must match the
# parameter keys bound by the DB-API cursor's get_table_column_schema():
# ``schema_name`` and ``table_name`` (both STRING). A placeholder without a
# bound value makes the query fail at execution time.
SQL_GET_TABLE_COLUMN_SCHEMA = """
SELECT COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = @schema_name AND TABLE_NAME = @table_name
"""

# This table maps spanner_types to Spanner's data type sizes as per
Expand Down
17 changes: 12 additions & 5 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -510,13 +510,17 @@ def __iter__(self):
raise ProgrammingError("no results to return")
return self._itr

def list_tables(self, schema_name=""):
    """List the tables of the linked Database.

    :type schema_name: str
    :param schema_name: (Optional) Name of the schema whose tables are
        listed. Defaults to the empty string.

    :rtype: list
    :returns: The list of tables within the Database.
    """
    # The schema is passed as a bound query parameter rather than being
    # formatted into the SQL text.
    return self.run_sql_in_snapshot(
        sql=_helpers.SQL_LIST_TABLES,
        params={"table_schema": schema_name},
        param_types={"table_schema": spanner.param_types.STRING},
    )

def run_sql_in_snapshot(self, sql, params=None, param_types=None):
# Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
Expand All @@ -528,11 +532,14 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None):
with self.connection.database.snapshot() as snapshot:
return list(snapshot.execute_sql(sql, params, param_types))

def get_table_column_schema(self, table_name):
def get_table_column_schema(self, table_name, schema_name=""):
rows = self.run_sql_in_snapshot(
sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA,
params={"table_name": table_name},
param_types={"table_name": spanner.param_types.STRING},
params={"schema_name": schema_name, "table_name": table_name},
param_types={
"schema_name": spanner.param_types.STRING,
"table_name": spanner.param_types.STRING,
},
)

column_details = {}
Expand Down
47 changes: 41 additions & 6 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""User friendly container for Cloud Spanner Database."""
"""User-friendly container for Cloud Spanner Database."""

import copy
import functools
Expand Down Expand Up @@ -42,6 +42,8 @@
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_dbapi.partition_helper import BatchTransactionId
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import Type
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import TransactionSelector
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -334,8 +336,21 @@ def database_dialect(self):
:rtype: :class:`google.cloud.spanner_admin_database_v1.types.DatabaseDialect`
:returns: the dialect of the database
"""
if self._database_dialect == DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED:
self.reload()
return self._database_dialect

@property
def default_schema_name(self):
    """Name of the schema that is used when no schema is given explicitly.

    :rtype: str
    :returns: ``"public"`` when the database dialect is PostgreSQL,
        otherwise ``""`` (GoogleSQL).
    """
    uses_postgresql = self.database_dialect == DatabaseDialect.POSTGRESQL
    return "public" if uses_postgresql else ""

@property
def database_role(self):
"""User-assigned database_role for sessions created by the pool.
Expand Down Expand Up @@ -961,20 +976,40 @@ def table(self, table_id):
"""
return Table(table_id, self)

def list_tables(self, schema="_default"):
    """List tables within the database.

    :type schema: str
    :param schema: The schema to search for tables, or None for all schemas.
        Use the special string "_default" to search for tables in the
        default schema of the database.

    :rtype: Iterable
    :returns:
        Iterable of :class:`~google.cloud.spanner_v1.table.Table`
        resources within the current database.
    """
    if schema == "_default":
        schema = self.default_schema_name

    with self.snapshot() as snapshot:
        if schema is None:
            # No schema filter at all.
            # NOTE(review): this branch also omits the
            # SPANNER_STATE = 'COMMITTED' filter that the GoogleSQL branch
            # below applies -- confirm that is intended.
            results = snapshot.execute_sql(
                sql=_LIST_TABLES_QUERY.format(""),
            )
        else:
            # Read the dialect through the property (not the private
            # attribute): the property reloads the database while the
            # dialect is still unspecified, so an explicit schema on a
            # not-yet-loaded Database selects the right parameter syntax.
            if self.database_dialect == DatabaseDialect.POSTGRESQL:
                where_clause = "WHERE TABLE_SCHEMA = $1"
                param_name = "p1"
            else:
                where_clause = (
                    "WHERE TABLE_SCHEMA = @schema AND SPANNER_STATE = 'COMMITTED'"
                )
                param_name = "schema"
            results = snapshot.execute_sql(
                sql=_LIST_TABLES_QUERY.format(where_clause),
                params={param_name: schema},
                param_types={param_name: Type(code=TypeCode.STRING)},
            )
        for row in results:
            yield self.table(row[0])

Expand Down
68 changes: 55 additions & 13 deletions google/cloud/spanner_v1/table.py
Expand Up @@ -43,13 +43,26 @@ class Table(object):
:param database: The database that owns the table.
"""

def __init__(self, table_id, database, schema_name=None):
    """Create a table reference.

    :type table_id: str
    :param table_id: The ID of the table.

    :param database: The database that owns the table.

    :type schema_name: str
    :param schema_name: (Optional) Name of the schema that contains the
        table. ``None`` selects the database's ``default_schema_name``;
        an empty string is kept as given.
    """
    # Only None falls back to the database default: "" is itself a valid
    # schema name, so a plain truthiness test would be wrong here.
    if schema_name is None:
        self._schema_name = database.default_schema_name
    else:
        self._schema_name = schema_name
    self._table_id = table_id
    self._database = database

    # Calculated properties.
    self._schema = None

@property
def schema_name(self):
    """Name of the schema that contains this table, as used in SQL.

    :rtype: str
    :returns: The table schema name.
    """
    name = self._schema_name
    return name

@property
def table_id(self):
"""The ID of the table used in SQL.
Expand All @@ -59,6 +72,30 @@ def table_id(self):
"""
return self._table_id

@property
def qualified_table_name(self):
    """Quoted table name for use in SQL, schema-qualified when needed.

    :rtype: str
    :returns: The quoted table ID alone when the table is in the
        database's default schema, otherwise the quoted schema name and
        the quoted table ID joined by a dot.
    """
    in_default_schema = self.schema_name == self._database.default_schema_name
    if in_default_schema:
        name_parts = [self.table_id]
    else:
        name_parts = [self.schema_name, self.table_id]
    return ".".join(self._quote_identifier(part) for part in name_parts)

def _quote_identifier(self, identifier):
    """Quote an identifier using the rules of this table's database dialect.

    Quote characters embedded in the identifier are escaped so the result
    is always exactly one quoted identifier, even when the value is later
    formatted into SQL text.

    :type identifier: str
    :param identifier: The unquoted identifier.

    :rtype: str
    :returns: The quoted identifier.
    """
    text = str(identifier)
    if self._database.database_dialect == DatabaseDialect.POSTGRESQL:
        # PostgreSQL: a double quote inside a quoted identifier is written
        # as two double quotes.
        return '"{}"'.format(text.replace('"', '""'))
    # GoogleSQL: backtick-quoted identifiers use the string-literal escape
    # sequences, so escape backslashes first and then backticks.
    escaped = text.replace("\\", "\\\\").replace("`", "\\`")
    return "`{}`".format(escaped)

def exists(self):
"""Test whether this table exists.
Expand All @@ -77,22 +114,27 @@ def _exists(self, snapshot):
:rtype: bool
:returns: True if the table exists, else false.
"""
if (
self._database.database_dialect
== DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED
):
self._database.reload()
if self._database.database_dialect == DatabaseDialect.POSTGRESQL:
results = snapshot.execute_sql(
_EXISTS_TEMPLATE.format("WHERE TABLE_NAME = $1"),
params={"p1": self.table_id},
param_types={"p1": Type(code=TypeCode.STRING)},
sql=_EXISTS_TEMPLATE.format(
"WHERE TABLE_SCHEMA=$1 AND TABLE_NAME = $2"
),
params={"p1": self.schema_name, "p2": self.table_id},
param_types={
"p1": Type(code=TypeCode.STRING),
"p2": Type(code=TypeCode.STRING),
},
)
else:
results = snapshot.execute_sql(
_EXISTS_TEMPLATE.format("WHERE TABLE_NAME = @table_id"),
params={"table_id": self.table_id},
param_types={"table_id": Type(code=TypeCode.STRING)},
sql=_EXISTS_TEMPLATE.format(
"WHERE TABLE_SCHEMA = @schema_name AND TABLE_NAME = @table_id"
),
params={"schema_name": self.schema_name, "table_id": self.table_id},
param_types={
"schema_name": Type(code=TypeCode.STRING),
"table_id": Type(code=TypeCode.STRING),
},
)
return next(iter(results))[0]

Expand All @@ -117,7 +159,7 @@ def _get_schema(self, snapshot):
:rtype: list of :class:`~google.cloud.spanner_v1.types.StructType.Field`
:returns: The table schema.
"""
query = _GET_SCHEMA_TEMPLATE.format(self.table_id)
query = _GET_SCHEMA_TEMPLATE.format(self.qualified_table_name)
results = snapshot.execute_sql(query)
# Start iterating to force the schema to download.
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/system/test_table_api.py
Expand Up @@ -33,7 +33,7 @@ def test_table_exists_reload_database_dialect(
shared_instance, shared_database, not_emulator
):
database = shared_instance.database(shared_database.database_id)
assert database.database_dialect == DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED
assert database.database_dialect != DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED
table = database.table("all_types")
assert table.exists()
assert database.database_dialect != DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED
Expand Down
14 changes: 11 additions & 3 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -936,6 +936,7 @@ def test_iter(self):

def test_list_tables(self):
from google.cloud.spanner_dbapi import _helpers
from google.cloud.spanner_v1 import param_types

connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
Expand All @@ -946,7 +947,11 @@ def test_list_tables(self):
return_value=table_list,
) as mock_run_sql:
cursor.list_tables()
mock_run_sql.assert_called_once_with(_helpers.SQL_LIST_TABLES)
mock_run_sql.assert_called_once_with(
sql=_helpers.SQL_LIST_TABLES,
params={"table_schema": ""},
param_types={"table_schema": param_types.STRING},
)

def test_run_sql_in_snapshot(self):
connection = self._make_connection(self.INSTANCE, mock.MagicMock())
Expand Down Expand Up @@ -987,8 +992,11 @@ def test_get_table_column_schema(self):
result = cursor.get_table_column_schema(table_name=table_name)
mock_run_sql.assert_called_once_with(
sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA,
params={"table_name": table_name},
param_types={"table_name": param_types.STRING},
params={"schema_name": "", "table_name": table_name},
param_types={
"schema_name": param_types.STRING,
"table_name": param_types.STRING,
},
)
self.assertEqual(result, expected)

Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_database.py
Expand Up @@ -17,7 +17,10 @@

import mock
from google.api_core import gapic_v1
from google.cloud.spanner_admin_database_v1 import Database as DatabasePB
from google.cloud.spanner_admin_database_v1 import (
Database as DatabasePB,
DatabaseDialect,
)
from google.cloud.spanner_v1.param_types import INT64
from google.api_core.retry import Retry
from google.protobuf.field_mask_pb2 import FieldMask
Expand Down Expand Up @@ -1680,6 +1683,7 @@ def test_table_factory_defaults(self):
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
database._database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
my_table = database.table("my_table")
self.assertIsInstance(my_table, Table)
self.assertIs(my_table._database, database)
Expand Down Expand Up @@ -3011,6 +3015,12 @@ def _make_instance_api():
return mock.create_autospec(InstanceAdminClient)


def _make_database_admin_api():
    """Build an autospec'd stand-in for ``DatabaseAdminClient``."""
    from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient

    admin_api = mock.create_autospec(DatabaseAdminClient)
    return admin_api


class _Client(object):
def __init__(
self,
Expand All @@ -3023,6 +3033,7 @@ def __init__(
self.project = project
self.project_name = "projects/" + self.project
self._endpoint_cache = {}
self.database_admin_api = _make_database_admin_api()
self.instance_admin_api = _make_instance_api()
self._client_info = mock.Mock()
self._client_options = mock.Mock()
Expand Down
19 changes: 13 additions & 6 deletions tests/unit/test_table.py
Expand Up @@ -26,6 +26,7 @@

class _BaseTest(unittest.TestCase):
TABLE_ID = "test_table"
TABLE_SCHEMA = ""

def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)
Expand Down Expand Up @@ -55,13 +56,18 @@ def test_exists_executes_query(self):
db.snapshot.return_value = checkout
checkout.__enter__.return_value = snapshot
snapshot.execute_sql.return_value = [[False]]
table = self._make_one(self.TABLE_ID, db)
table = self._make_one(self.TABLE_ID, db, schema_name=self.TABLE_SCHEMA)
exists = table.exists()
self.assertFalse(exists)
snapshot.execute_sql.assert_called_with(
_EXISTS_TEMPLATE.format("WHERE TABLE_NAME = @table_id"),
params={"table_id": self.TABLE_ID},
param_types={"table_id": Type(code=TypeCode.STRING)},
_EXISTS_TEMPLATE.format(
"WHERE TABLE_SCHEMA = @schema_name AND TABLE_NAME = @table_id"
),
params={"schema_name": self.TABLE_SCHEMA, "table_id": self.TABLE_ID},
param_types={
"schema_name": Type(code=TypeCode.STRING),
"table_id": Type(code=TypeCode.STRING),
},
)

def test_schema_executes_query(self):
Expand All @@ -70,14 +76,15 @@ def test_schema_executes_query(self):
from google.cloud.spanner_v1.table import _GET_SCHEMA_TEMPLATE

db = mock.create_autospec(Database, instance=True)
db.default_schema_name = ""
checkout = mock.create_autospec(SnapshotCheckout, instance=True)
snapshot = mock.create_autospec(Snapshot, instance=True)
db.snapshot.return_value = checkout
checkout.__enter__.return_value = snapshot
table = self._make_one(self.TABLE_ID, db)
table = self._make_one(self.TABLE_ID, db, schema_name=self.TABLE_SCHEMA)
schema = table.schema
self.assertIsInstance(schema, list)
expected_query = _GET_SCHEMA_TEMPLATE.format(self.TABLE_ID)
expected_query = _GET_SCHEMA_TEMPLATE.format("`{}`".format(self.TABLE_ID))
snapshot.execute_sql.assert_called_with(expected_query)

def test_schema_returns_cache(self):
Expand Down

0 comments on commit 2bf0319

Please sign in to comment.