diff options
-rw-r--r-- | doc/build/changelog/unreleased_13/4883.rst | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 15 | ||||
-rw-r--r-- | test/dialect/mssql/test_compiler.py | 16 | ||||
-rw-r--r-- | test/dialect/mssql/test_reflection.py | 83 |
4 files changed, 106 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_13/4883.rst b/doc/build/changelog/unreleased_13/4883.rst new file mode 100644 index 000000000..161dbf146 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4883.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mssql + :tickets: 4883 + + Added identifier quoting to the schema name applied to the "use" statement + which is invoked when a SQL Server multipart schema name is used within a + :class:`.Table` that is being reflected, as well as for :class:`.Inspector` + methods such as :meth:`.Inspector.get_table_names`; this accommodates for + special characters or spaces in the database name. Additionally, the "use" + statement is not emitted if the current database matches the target owner + database name being passed. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 54f0043c4..d4d303d5d 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2174,12 +2174,21 @@ def _db_plus_owner(fn): def _switch_db(dbname, connection, fn, *arg, **kw): if dbname: current_db = connection.scalar("select db_name()") - connection.execute("use %s" % dbname) + if current_db != dbname: + connection.execute( + "use %s" + % connection.dialect.identifier_preparer.quote_schema(dbname) + ) try: return fn(*arg, **kw) finally: - if dbname: - connection.execute("use %s" % current_db) + if dbname and current_db != dbname: + connection.execute( + "use %s" + % connection.dialect.identifier_preparer.quote_schema( + current_db + ) + ) def _owner_plus_db(dialect, schema): diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 4f656a36c..00a8a08fc 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -20,7 +20,6 @@ from sqlalchemy import union from sqlalchemy import UniqueConstraint from sqlalchemy import update from sqlalchemy.dialects import mssql -from sqlalchemy.dialects.mssql import base from sqlalchemy.dialects.mssql import mxodbc from sqlalchemy.dialects.mssql.base import try_cast from sqlalchemy.sql import column @@ -480,21 +479,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test" ) - def test_owner_database_pairs(self): - dialect = mssql.dialect() - - for identifier, expected_schema, expected_owner in [ - ("foo", None, "foo"), - ("foo.bar", "foo", "bar"), - ("Foo.Bar", "Foo", "Bar"), - ("[Foo.Bar]", None, "Foo.Bar"), - ("[Foo.Bar].[bat]", "Foo.Bar", "bat"), - ]: - schema, owner = base._owner_plus_db(dialect, identifier) - - eq_(owner, expected_owner) - eq_(schema, expected_schema) - def test_delete_schema(self): metadata = MetaData() tbl = Table( diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 937f2bb2a..24c4a6455 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -418,3 +418,86 @@ class ReflectHugeViewTest(fixtures.TestBase): inspector = Inspector.from_engine(testing.db) view_def = inspector.get_view_definition("huge_named_view") eq_(view_def, self.view_str) + + +class OwnerPlusDBTest(fixtures.TestBase): + def test_owner_database_pairs_dont_use_for_same_db(self): + dialect = mssql.dialect() + + identifier = "my_db.some_schema" + schema, owner = base._owner_plus_db(dialect, identifier) + + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="my_db") + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + eq_(mock_connection.mock_calls, [mock.call.scalar("select db_name()")]) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + def test_owner_database_pairs_switch_for_different_db(self): + dialect = mssql.dialect() + + identifier = "my_other_db.some_schema" + schema, owner = base._owner_plus_db(dialect, identifier) + + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="my_db") + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute("use my_other_db"), + mock.call.execute("use my_db"), + ], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + def test_owner_database_pairs(self): + dialect = mssql.dialect() + + for identifier, expected_schema, expected_owner, use_stmt in [ + ("foo", None, "foo", "use foo"), + ("foo.bar", "foo", "bar", "use foo"), + ("Foo.Bar", "Foo", "Bar", "use [Foo]"), + ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"), + ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"), + ( + "[foo].]do something; select [foo", + "foo", + "do something; select foo", + "use foo", + ), + ( + "something; select [foo].bar", + "something; select foo", + "bar", + "use [something; select foo]", + ), + ]: + schema, owner = base._owner_plus_db(dialect, identifier) + + eq_(owner, expected_owner) + eq_(schema, expected_schema) + + mock_connection = mock.Mock( + dialect=dialect, + scalar=mock.Mock(return_value="Some ] Database"), + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + if schema is None: + eq_(mock_connection.mock_calls, []) + else: + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute(use_stmt), + mock.call.execute("use [Some Database]"), + ], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) |