diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-09-01 17:12:11 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-09-01 17:12:11 +0000 |
commit | 7439697ca3c8b07af79de6a146a15c34fc85d9d6 (patch) | |
tree | ec0f4a477276017f2b8e48646e47ca64c13c83ec | |
parent | d61cf0a9ad91edfcb569214a19122a9572fbb29b (diff) | |
parent | 516131c40da9c8cd304061850e2d98e309966dd5 (diff) | |
download | sqlalchemy-7439697ca3c8b07af79de6a146a15c34fc85d9d6.tar.gz |
Merge "Improve reflection for mssql temporary tables"
-rw-r--r-- | doc/build/changelog/unreleased_14/5506.rst | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 37 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/information_schema.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/provision.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/provision.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 14 | ||||
-rw-r--r-- | test/dialect/mssql/test_reflection.py | 49 | ||||
-rw-r--r-- | test/requirements.py | 10 |
8 files changed, 154 insertions, 12 deletions
diff --git a/doc/build/changelog/unreleased_14/5506.rst b/doc/build/changelog/unreleased_14/5506.rst new file mode 100644 index 000000000..71b57322d --- /dev/null +++ b/doc/build/changelog/unreleased_14/5506.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: usecase, mssql + :tickets: 5506 + + Added support for reflection of temporary tables with the SQL Server dialect. + Table names that are prefixed by a pound sign "#" are now introspected from + the MSSQL "tempdb" system catalog. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index f38c537fd..ed17fb863 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect): view_def = rp.scalar() return view_def + def _get_internal_temp_table_name(self, connection, tablename): + result = connection.execute( + sql.text( + "select table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + { + "p1": tablename + + (("___%") if not tablename.startswith("##") else "") + }, + ).fetchall() + if len(result) > 1: + raise exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ) + elif len(result) == 0: + raise exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ) + else: + return result[0][0] + @reflection.cache @_db_plus_owner def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + is_temp_table = tablename.startswith("#") + if is_temp_table: + tablename = self._get_internal_temp_table_name( + connection, tablename + ) # Get base columns - columns = ischema.columns + columns = ( + ischema.mssql_temp_table_columns + if is_temp_table + else ischema.columns + ) computed_cols = ischema.computed_columns if owner: whereclause = sql.and_( diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 6cdde8386..f80110b7d 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -5,9 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -# TODO: should be using the sys. catalog with SQL Server, not information -# schema - from ... import cast from ... import Column from ... import MetaData @@ -93,6 +90,25 @@ columns = Table( schema="INFORMATION_SCHEMA", ) +mssql_temp_table_columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", +) + constraints = Table( "TABLE_CONSTRAINTS", ischema, diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index a5131eae6..269eb164f 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -2,8 +2,10 @@ from ... import create_engine from ... import exc from ...testing.provision import create_db from ...testing.provision import drop_db +from ...testing.provision import get_temp_table_name from ...testing.provision import log from ...testing.provision import run_reap_dbs +from ...testing.provision import temp_table_keyword_args @create_db.for_db("mssql") @@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents): log.info( "Dropped %d out of %d stale databases detected", dropped, total ) + + +@temp_table_keyword_args.for_db("mssql") +def _mssql_temp_table_keyword_args(cfg, eng): + return {} + + +@get_temp_table_name.for_db("mssql") +def _mssql_get_temp_table_name(cfg, eng, base_name): + return "#" + base_name diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 0edaae490..8bdad357c 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -296,3 +296,18 @@ def temp_table_keyword_args(cfg, eng): raise NotImplementedError( "no temp table keyword args routine for cfg: %s" % eng.url ) + + +@register.init +def get_temp_table_name(cfg, eng, base_name): + """Specify table name for creating a temporary Table. + + Dialect-specific implementations of this method will return the + name to use when creating a temporary table for testing, + e.g., in the define_temp_tables method of the + ComponentReflectionTest class in suite/test_reflection.py + + Default to just the base name since that's what most dialects will + use. The mssql dialect's implementation will need a "#" prepended. + """ + return base_name diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 151be757a..94ec22c1e 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -8,6 +8,7 @@ from .. import eq_ from .. import expect_warnings from .. import fixtures from .. import is_ +from ..provision import get_temp_table_name from ..provision import temp_table_keyword_args from ..schema import Column from ..schema import Table @@ -442,8 +443,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @classmethod def define_temp_tables(cls, metadata): kw = temp_table_keyword_args(config, config.db) + table_name = get_temp_table_name(config, config.db, "user_tmp") user_tmp = Table( - "user_tmp", + table_name, metadata, Column("id", sa.INT, primary_key=True), Column("name", sa.VARCHAR(50)), @@ -736,10 +738,11 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection def test_get_temp_table_columns(self): + table_name = get_temp_table_name(config, config.db, "user_tmp") meta = MetaData(self.bind) - user_tmp = self.tables.user_tmp + user_tmp = self.tables[table_name] insp = inspect(meta.bind) - cols = insp.get_columns("user_tmp") + cols = insp.get_columns(table_name) self.assert_(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): @@ -1051,10 +1054,11 @@ class ComponentReflectionTest(fixtures.TablesTest): refl.pop("duplicates_index", None) eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}]) - @testing.requires.temp_table_reflection + @testing.requires.temp_table_reflect_indexes def test_get_temp_table_indexes(self): insp = inspect(self.bind) - indexes = insp.get_indexes("user_tmp") + table_name = get_temp_table_name(config, config.db, "user_tmp") + indexes = insp.get_indexes(table_name) for ind in indexes: ind.pop("dialect_options", None) eq_( diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 0bd8f7a5a..67bde6fb3 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -1,7 +1,10 @@ # -*- encoding: utf-8 +import datetime + from sqlalchemy import Column from sqlalchemy import DDL from sqlalchemy import event +from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Index from sqlalchemy import inspect @@ -246,6 +249,52 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): eq_(t.name, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") @testing.provide_metadata + @testing.combinations( + ("local_temp", "#tmp", True), + ("global_temp", "##tmp", True), + ("nonexistent", "#no_es_bueno", False), + id_="iaa", + ) + def test_temporary_table(self, table_name, exists): + metadata = self.metadata + if exists: + # TODO: why this test hangs when using the connection fixture? + with testing.db.connect() as conn: + tran = conn.begin() + conn.execute( + ( + "CREATE TABLE %s " + "(id int primary key, txt nvarchar(50), dt2 datetime2)" # noqa + ) + % table_name + ) + conn.execute( + ( + "INSERT INTO %s (id, txt, dt2) VALUES " + "(1, N'foo', '2020-01-01 01:01:01'), " + "(2, N'bar', '2020-02-02 02:02:02') " + ) + % table_name + ) + tran.commit() + tran = conn.begin() + try: + tmp_t = Table( + table_name, metadata, autoload_with=testing.db, + ) + tran.commit() + result = conn.execute( + tmp_t.select().where(tmp_t.c.id == 2) + ).fetchall() + eq_( + result, + [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))], + ) + except exc.NoSuchTableError: + if exists: + raise + + @testing.provide_metadata def test_db_qualified_items(self): metadata = self.metadata Table("foo", metadata, Column("id", Integer, primary_key=True)) diff --git a/test/requirements.py b/test/requirements.py index 1c2561bbb..145d87d75 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -257,15 +257,19 @@ class DefaultRequirements(SuiteRequirements): @property def temporary_tables(self): """target database supports temporary tables""" - return skip_if( - ["mssql", "firebird", self._sqlite_file_db], "not supported (?)" - ) + return skip_if(["firebird", self._sqlite_file_db], "not supported (?)") @property def temp_table_reflection(self): return self.temporary_tables @property + def temp_table_reflect_indexes(self): + return skip_if( + ["mssql", "firebird", self._sqlite_file_db], "not supported (?)" + ) + + @property def reflectable_autoincrement(self): """Target database must support tables that can automatically generate PKs assuming they were reflected. |