summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-09-01 17:12:11 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-09-01 17:12:11 +0000
commit7439697ca3c8b07af79de6a146a15c34fc85d9d6 (patch)
treeec0f4a477276017f2b8e48646e47ca64c13c83ec
parentd61cf0a9ad91edfcb569214a19122a9572fbb29b (diff)
parent516131c40da9c8cd304061850e2d98e309966dd5 (diff)
downloadsqlalchemy-7439697ca3c8b07af79de6a146a15c34fc85d9d6.tar.gz
Merge "Improve reflection for mssql temporary tables"
-rw-r--r--doc/build/changelog/unreleased_14/5506.rst7
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py37
-rw-r--r--lib/sqlalchemy/dialects/mssql/information_schema.py22
-rw-r--r--lib/sqlalchemy/dialects/mssql/provision.py12
-rw-r--r--lib/sqlalchemy/testing/provision.py15
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py14
-rw-r--r--test/dialect/mssql/test_reflection.py49
-rw-r--r--test/requirements.py10
8 files changed, 154 insertions, 12 deletions
diff --git a/doc/build/changelog/unreleased_14/5506.rst b/doc/build/changelog/unreleased_14/5506.rst
new file mode 100644
index 000000000..71b57322d
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5506.rst
@@ -0,0 +1,7 @@
+.. change::
+ :tags: usecase, mssql
+ :tickets: 5506
+
+ Added support for reflection of temporary tables with the SQL Server dialect.
+ Table names that are prefixed by a pound sign "#" are now introspected from
+ the MSSQL "tempdb" system catalog.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index f38c537fd..ed17fb863 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect):
view_def = rp.scalar()
return view_def
+ def _get_internal_temp_table_name(self, connection, tablename):
+ result = connection.execute(
+ sql.text(
+ "select table_name "
+ "from tempdb.information_schema.tables "
+ "where table_name like :p1"
+ ),
+ {
+ "p1": tablename
+ + (("___%") if not tablename.startswith("##") else "")
+ },
+ ).fetchall()
+ if len(result) > 1:
+ raise exc.UnreflectableTableError(
+ "Found more than one temporary table named '%s' in tempdb "
+ "at this time. Cannot reliably resolve that name to its "
+ "internal table name." % tablename
+ )
+ elif len(result) == 0:
+ raise exc.NoSuchTableError(
+ "Unable to find a temporary table named '%s' in tempdb."
+ % tablename
+ )
+ else:
+ return result[0][0]
+
@reflection.cache
@_db_plus_owner
def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
+ is_temp_table = tablename.startswith("#")
+ if is_temp_table:
+ tablename = self._get_internal_temp_table_name(
+ connection, tablename
+ )
# Get base columns
- columns = ischema.columns
+ columns = (
+ ischema.mssql_temp_table_columns
+ if is_temp_table
+ else ischema.columns
+ )
computed_cols = ischema.computed_columns
if owner:
whereclause = sql.and_(
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
index 6cdde8386..f80110b7d 100644
--- a/lib/sqlalchemy/dialects/mssql/information_schema.py
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -5,9 +5,6 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-# TODO: should be using the sys. catalog with SQL Server, not information
-# schema
-
from ... import cast
from ... import Column
from ... import MetaData
@@ -93,6 +90,25 @@ columns = Table(
schema="INFORMATION_SCHEMA",
)
+mssql_temp_table_columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="tempdb.INFORMATION_SCHEMA",
+)
+
constraints = Table(
"TABLE_CONSTRAINTS",
ischema,
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
index a5131eae6..269eb164f 100644
--- a/lib/sqlalchemy/dialects/mssql/provision.py
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -2,8 +2,10 @@ from ... import create_engine
from ... import exc
from ...testing.provision import create_db
from ...testing.provision import drop_db
+from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import run_reap_dbs
+from ...testing.provision import temp_table_keyword_args
@create_db.for_db("mssql")
@@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents):
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
+
+
+@temp_table_keyword_args.for_db("mssql")
+def _mssql_temp_table_keyword_args(cfg, eng):
+ return {}
+
+
+@get_temp_table_name.for_db("mssql")
+def _mssql_get_temp_table_name(cfg, eng, base_name):
+ return "#" + base_name
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 0edaae490..8bdad357c 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -296,3 +296,18 @@ def temp_table_keyword_args(cfg, eng):
raise NotImplementedError(
"no temp table keyword args routine for cfg: %s" % eng.url
)
+
+
+@register.init
+def get_temp_table_name(cfg, eng, base_name):
+ """Specify table name for creating a temporary Table.
+
+ Dialect-specific implementations of this method will return the
+ name to use when creating a temporary table for testing,
+ e.g., in the define_temp_tables method of the
+ ComponentReflectionTest class in suite/test_reflection.py
+
+ Default to just the base name since that's what most dialects will
+ use. The mssql dialect's implementation will need a "#" prepended.
+ """
+ return base_name
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 151be757a..94ec22c1e 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -8,6 +8,7 @@ from .. import eq_
from .. import expect_warnings
from .. import fixtures
from .. import is_
+from ..provision import get_temp_table_name
from ..provision import temp_table_keyword_args
from ..schema import Column
from ..schema import Table
@@ -442,8 +443,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@classmethod
def define_temp_tables(cls, metadata):
kw = temp_table_keyword_args(config, config.db)
+ table_name = get_temp_table_name(config, config.db, "user_tmp")
user_tmp = Table(
- "user_tmp",
+ table_name,
metadata,
Column("id", sa.INT, primary_key=True),
Column("name", sa.VARCHAR(50)),
@@ -736,10 +738,11 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.temp_table_reflection
def test_get_temp_table_columns(self):
+ table_name = get_temp_table_name(config, config.db, "user_tmp")
meta = MetaData(self.bind)
- user_tmp = self.tables.user_tmp
+ user_tmp = self.tables[table_name]
insp = inspect(meta.bind)
- cols = insp.get_columns("user_tmp")
+ cols = insp.get_columns(table_name)
self.assert_(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
@@ -1051,10 +1054,11 @@ class ComponentReflectionTest(fixtures.TablesTest):
refl.pop("duplicates_index", None)
eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}])
- @testing.requires.temp_table_reflection
+ @testing.requires.temp_table_reflect_indexes
def test_get_temp_table_indexes(self):
insp = inspect(self.bind)
- indexes = insp.get_indexes("user_tmp")
+ table_name = get_temp_table_name(config, config.db, "user_tmp")
+ indexes = insp.get_indexes(table_name)
for ind in indexes:
ind.pop("dialect_options", None)
eq_(
diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py
index 0bd8f7a5a..67bde6fb3 100644
--- a/test/dialect/mssql/test_reflection.py
+++ b/test/dialect/mssql/test_reflection.py
@@ -1,7 +1,10 @@
# -*- encoding: utf-8
+import datetime
+
from sqlalchemy import Column
from sqlalchemy import DDL
from sqlalchemy import event
+from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Index
from sqlalchemy import inspect
@@ -246,6 +249,52 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
eq_(t.name, "ABCDEFGHIJKLMNOPQRSTUVWXYZ")
@testing.provide_metadata
+ @testing.combinations(
+ ("local_temp", "#tmp", True),
+ ("global_temp", "##tmp", True),
+ ("nonexistent", "#no_es_bueno", False),
+ id_="iaa",
+ )
+ def test_temporary_table(self, table_name, exists):
+ metadata = self.metadata
+ if exists:
+ # TODO: why this test hangs when using the connection fixture?
+ with testing.db.connect() as conn:
+ tran = conn.begin()
+ conn.execute(
+ (
+ "CREATE TABLE %s "
+ "(id int primary key, txt nvarchar(50), dt2 datetime2)" # noqa
+ )
+ % table_name
+ )
+ conn.execute(
+ (
+ "INSERT INTO %s (id, txt, dt2) VALUES "
+ "(1, N'foo', '2020-01-01 01:01:01'), "
+ "(2, N'bar', '2020-02-02 02:02:02') "
+ )
+ % table_name
+ )
+ tran.commit()
+ tran = conn.begin()
+ try:
+ tmp_t = Table(
+ table_name, metadata, autoload_with=testing.db,
+ )
+ tran.commit()
+ result = conn.execute(
+ tmp_t.select().where(tmp_t.c.id == 2)
+ ).fetchall()
+ eq_(
+ result,
+ [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))],
+ )
+ except exc.NoSuchTableError:
+ if exists:
+ raise
+
+ @testing.provide_metadata
def test_db_qualified_items(self):
metadata = self.metadata
Table("foo", metadata, Column("id", Integer, primary_key=True))
diff --git a/test/requirements.py b/test/requirements.py
index 1c2561bbb..145d87d75 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -257,15 +257,19 @@ class DefaultRequirements(SuiteRequirements):
@property
def temporary_tables(self):
"""target database supports temporary tables"""
- return skip_if(
- ["mssql", "firebird", self._sqlite_file_db], "not supported (?)"
- )
+ return skip_if(["firebird", self._sqlite_file_db], "not supported (?)")
@property
def temp_table_reflection(self):
return self.temporary_tables
@property
+ def temp_table_reflect_indexes(self):
+ return skip_if(
+ ["mssql", "firebird", self._sqlite_file_db], "not supported (?)"
+ )
+
+ @property
def reflectable_autoincrement(self):
"""Target database must support tables that can automatically generate
PKs assuming they were reflected.