summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_13/4883.rst11
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py15
-rw-r--r--test/dialect/mssql/test_compiler.py16
-rw-r--r--test/dialect/mssql/test_reflection.py83
4 files changed, 106 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_13/4883.rst b/doc/build/changelog/unreleased_13/4883.rst
new file mode 100644
index 000000000..161dbf146
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/4883.rst
@@ -0,0 +1,11 @@
+.. change::
+ :tags: bug, mssql
+ :tickets: 4883
+
+ Added identifier quoting to the schema name applied to the "use" statement
+ which is invoked when a SQL Server multipart schema name is used within a
+ :class:`.Table` that is being reflected, as well as for :class:`.Inspector`
+ methods such as :meth:`.Inspector.get_table_names`; this accommodates for
+ special characters or spaces in the database name. Additionally, the "use"
+ statement is not emitted if the current database matches the target owner
+ database name being passed.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 54f0043c4..d4d303d5d 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2174,12 +2174,21 @@ def _db_plus_owner(fn):
def _switch_db(dbname, connection, fn, *arg, **kw):
if dbname:
current_db = connection.scalar("select db_name()")
- connection.execute("use %s" % dbname)
+ if current_db != dbname:
+ connection.execute(
+ "use %s"
+ % connection.dialect.identifier_preparer.quote_schema(dbname)
+ )
try:
return fn(*arg, **kw)
finally:
- if dbname:
- connection.execute("use %s" % current_db)
+ if dbname and current_db != dbname:
+ connection.execute(
+ "use %s"
+ % connection.dialect.identifier_preparer.quote_schema(
+ current_db
+ )
+ )
def _owner_plus_db(dialect, schema):
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 4f656a36c..00a8a08fc 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -20,7 +20,6 @@ from sqlalchemy import union
from sqlalchemy import UniqueConstraint
from sqlalchemy import update
from sqlalchemy.dialects import mssql
-from sqlalchemy.dialects.mssql import base
from sqlalchemy.dialects.mssql import mxodbc
from sqlalchemy.dialects.mssql.base import try_cast
from sqlalchemy.sql import column
@@ -480,21 +479,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test"
)
- def test_owner_database_pairs(self):
- dialect = mssql.dialect()
-
- for identifier, expected_schema, expected_owner in [
- ("foo", None, "foo"),
- ("foo.bar", "foo", "bar"),
- ("Foo.Bar", "Foo", "Bar"),
- ("[Foo.Bar]", None, "Foo.Bar"),
- ("[Foo.Bar].[bat]", "Foo.Bar", "bat"),
- ]:
- schema, owner = base._owner_plus_db(dialect, identifier)
-
- eq_(owner, expected_owner)
- eq_(schema, expected_schema)
-
def test_delete_schema(self):
metadata = MetaData()
tbl = Table(
diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py
index 937f2bb2a..24c4a6455 100644
--- a/test/dialect/mssql/test_reflection.py
+++ b/test/dialect/mssql/test_reflection.py
@@ -418,3 +418,86 @@ class ReflectHugeViewTest(fixtures.TestBase):
inspector = Inspector.from_engine(testing.db)
view_def = inspector.get_view_definition("huge_named_view")
eq_(view_def, self.view_str)
+
+
+class OwnerPlusDBTest(fixtures.TestBase):
+ def test_owner_database_pairs_dont_use_for_same_db(self):
+ dialect = mssql.dialect()
+
+ identifier = "my_db.some_schema"
+ schema, owner = base._owner_plus_db(dialect, identifier)
+
+ mock_connection = mock.Mock(
+ dialect=dialect, scalar=mock.Mock(return_value="my_db")
+ )
+ mock_lambda = mock.Mock()
+ base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+ eq_(mock_connection.mock_calls, [mock.call.scalar("select db_name()")])
+ eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+
+ def test_owner_database_pairs_switch_for_different_db(self):
+ dialect = mssql.dialect()
+
+ identifier = "my_other_db.some_schema"
+ schema, owner = base._owner_plus_db(dialect, identifier)
+
+ mock_connection = mock.Mock(
+ dialect=dialect, scalar=mock.Mock(return_value="my_db")
+ )
+ mock_lambda = mock.Mock()
+ base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+ eq_(
+ mock_connection.mock_calls,
+ [
+ mock.call.scalar("select db_name()"),
+ mock.call.execute("use my_other_db"),
+ mock.call.execute("use my_db"),
+ ],
+ )
+ eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+
+ def test_owner_database_pairs(self):
+ dialect = mssql.dialect()
+
+ for identifier, expected_schema, expected_owner, use_stmt in [
+ ("foo", None, "foo", "use foo"),
+ ("foo.bar", "foo", "bar", "use foo"),
+ ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
+ ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"),
+ ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"),
+ (
+ "[foo].]do something; select [foo",
+ "foo",
+ "do something; select foo",
+ "use foo",
+ ),
+ (
+ "something; select [foo].bar",
+ "something; select foo",
+ "bar",
+ "use [something; select foo]",
+ ),
+ ]:
+ schema, owner = base._owner_plus_db(dialect, identifier)
+
+ eq_(owner, expected_owner)
+ eq_(schema, expected_schema)
+
+ mock_connection = mock.Mock(
+ dialect=dialect,
+ scalar=mock.Mock(return_value="Some ] Database"),
+ )
+ mock_lambda = mock.Mock()
+ base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+ if schema is None:
+ eq_(mock_connection.mock_calls, [])
+ else:
+ eq_(
+ mock_connection.mock_calls,
+ [
+ mock.call.scalar("select db_name()"),
+ mock.call.execute(use_stmt),
+ mock.call.execute("use [Some Database]"),
+ ],
+ )
+ eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])