diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-11-11 21:00:21 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-11-11 21:00:21 +0000 |
commit | 1dd0f23e8d74aa7edc8dd309093a95171e2e8f09 (patch) | |
tree | dda63840955464c225ecf0b09481275bbbf73ac9 | |
parent | 604611e7e522269ee11b314fb6fb75873a465494 (diff) | |
parent | 8e91cfe529b9b0150c16e52e22e4590bfbbe79fd (diff) | |
download | sqlalchemy-1dd0f23e8d74aa7edc8dd309093a95171e2e8f09.tar.gz |
Merge "establish consistency for RETURNING column labels" into main
-rw-r--r-- | doc/build/changelog/unreleased_14/8770.rst | 23 | ||||
-rw-r--r-- | doc/build/changelog/unreleased_20/8770.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 23 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 8 | ||||
-rw-r--r-- | test/dialect/mssql/test_compiler.py | 29 | ||||
-rw-r--r-- | test/dialect/oracle/test_compiler.py | 31 | ||||
-rw-r--r-- | test/dialect/postgresql/test_compiler.py | 30 | ||||
-rw-r--r-- | test/sql/test_labels.py | 97 | ||||
-rw-r--r-- | test/sql/test_returning.py | 44 |
11 files changed, 310 insertions, 8 deletions
diff --git a/doc/build/changelog/unreleased_14/8770.rst b/doc/build/changelog/unreleased_14/8770.rst new file mode 100644 index 000000000..8968b0361 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8770.rst @@ -0,0 +1,23 @@ +.. change:: + :tags: bug, postgresql, mssql + :tickets: 8770 + + For the PostgreSQL and SQL Server dialects only, adjusted the compiler so + that when rendering column expressions in the RETURNING clause, the "non + anon" label that's used in SELECT statements is suggested for SQL + expression elements that generate a label; the primary example is a SQL + function that may be emitting as part of the column's type, where the label + name should match the column's name by default. This restores a not-well + defined behavior that had changed in version 1.4.21 due to :ticket:`6718`, + :ticket:`6710`. The Oracle dialect has a different RETURNING implementation + and was not affected by this issue. Version 2.0 features an across the + board change for its widely expanded support of RETURNING on other + backends. + + +.. change:: + :tags: bug, oracle + + Fixed issue in the Oracle dialect where an INSERT statement that used + ``insert(some_table).values(...).returning(some_table)`` against a full + :class:`.Table` object at once would fail to execute, raising an exception. diff --git a/doc/build/changelog/unreleased_20/8770.rst b/doc/build/changelog/unreleased_20/8770.rst new file mode 100644 index 000000000..59b94d658 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8770.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 8770 + + The RETURNING clause now renders columns using the routine as that of the + :class:`.Select` to generate labels, which will include disambiguating + labels, as well as that a SQL function surrounding a named column will be + labeled using the column name itself. This is a more comprehensive change + than a similar one made for the 1.4 series that adjusted the function label + issue only. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index a338ba27a..53fe96c9a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2295,11 +2295,24 @@ class MSSQLCompiler(compiler.SQLCompiler): columns = [ self._label_returning_column( stmt, - adapter.traverse(c), + adapter.traverse(column), populate_result_map, - {"result_map_targets": (c,)}, + {"result_map_targets": (column,)}, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + name=name, + proxy_name=proxy_name, + **kw, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in stmt._generate_columns_plus_names( + True, cols=expression._select_iterables(returning_cols) ) - for c in expression._select_iterables(returning_cols) ] return "OUTPUT " + ", ".join(columns) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3e62cb350..97397e9cf 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3760,7 +3760,6 @@ class SQLCompiler(Compiled): "_label_select_column is only relevant within " "the columns clause of a SELECT or RETURNING" ) - if isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( @@ -4416,9 +4415,27 @@ class SQLCompiler(Compiled): populate_result_map: bool, **kw: Any, ) -> str: + columns = [ - self._label_returning_column(stmt, c, populate_result_map, **kw) - for c in base._select_iterables(returning_cols) + self._label_returning_column( + stmt, + column, + populate_result_map, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + name=name, + proxy_name=proxy_name, + **kw, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in stmt._generate_columns_plus_names( + True, cols=base._select_iterables(returning_cols) + ) ] return "RETURNING " + ", ".join(columns) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 5145a4a16..2d3e3598b 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -59,6 +59,7 @@ from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import Join +from .selectable import SelectLabelStyle from .selectable import TableClause from .selectable import TypedReturnsRows from .sqltypes import NullType @@ -399,6 +400,9 @@ class UpdateBase( ] = util.EMPTY_DICT named_with_column = False + _label_style: SelectLabelStyle = ( + SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY + ) table: _DMLTableElement _return_defaults = False diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9de015774..488dfe721 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2193,7 +2193,9 @@ class SelectsRows(ReturnsRows): _label_style: SelectLabelStyle = LABEL_STYLE_NONE def _generate_columns_plus_names( - self, anon_for_dupe_key: bool + self, + anon_for_dupe_key: bool, + cols: Optional[_SelectIterable] = None, ) -> List[_ColumnsPlusNames]: """Generate column names as rendered in a SELECT statement by the compiler. @@ -2204,7 +2206,9 @@ class SelectsRows(ReturnsRows): _column_naming_convention as well. """ - cols = self._all_selected_columns + + if cols is None: + cols = self._all_selected_columns key_naming_convention = SelectState._column_naming_convention( self._label_style diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 8605ea9c0..b575595ac 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -36,6 +36,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing.assertions import eq_ignore_whitespace +from sqlalchemy.types import TypeEngine tbl = table("t", column("a")) @@ -119,6 +120,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "Latin1_General_CS_AS_KS_WS_CI ASC", ) + @testing.fixture + def column_expression_fixture(self): + class MyString(TypeEngine): + def column_expression(self, column): + return func.lower(column) + + return table( + "some_table", column("name", String), column("value", MyString) + ) + + @testing.combinations("columns", "table", argnames="use_columns") + def test_plain_returning_column_expression( + self, column_expression_fixture, use_columns + ): + """test #8770""" + table1 = column_expression_fixture + + if use_columns == "columns": + stmt = insert(table1).returning(table1) + else: + stmt = insert(table1).returning(table1.c.name, table1.c.value) + + self.assert_compile( + stmt, + "INSERT INTO some_table (name, value) OUTPUT inserted.name, " + "lower(inserted.value) AS value VALUES (:name, :value)", + ) + def test_join_with_hint(self): t1 = table( "t1", diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 2973c6e39..8981e74e8 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -9,6 +9,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import Index +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import literal from sqlalchemy import literal_column @@ -42,6 +43,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import eq_ignore_whitespace from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +from sqlalchemy.types import TypeEngine class CompileTest(fixtures.TestBase, AssertsCompiledSQL): @@ -1359,6 +1361,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "t1.c2, t1.c3 INTO :ret_0, :ret_1", ) + @testing.fixture + def column_expression_fixture(self): + class MyString(TypeEngine): + def column_expression(self, column): + return func.lower(column) + + return table( + "some_table", column("name", String), column("value", MyString) + ) + + @testing.combinations("columns", "table", argnames="use_columns") + def test_plain_returning_column_expression( + self, column_expression_fixture, use_columns + ): + """test #8770""" + table1 = column_expression_fixture + + if use_columns == "columns": + stmt = insert(table1).returning(table1) + else: + stmt = insert(table1).returning(table1.c.name, table1.c.value) + + self.assert_compile( + stmt, + "INSERT INTO some_table (name, value) VALUES (:name, :value) " + "RETURNING some_table.name, lower(some_table.value) " + "INTO :ret_0, :ret_1", + ) + def test_returning_insert_computed(self): m = MetaData() t1 = Table( diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 96a8e7d5a..338d0da4e 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -61,6 +61,7 @@ from sqlalchemy.testing.assertions import AssertsCompiledSQL from sqlalchemy.testing.assertions import eq_ignore_whitespace from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.assertions import is_ +from sqlalchemy.types import TypeEngine from sqlalchemy.util import OrderedDict @@ -200,6 +201,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect, ) + @testing.fixture + def column_expression_fixture(self): + class MyString(TypeEngine): + def column_expression(self, column): + return func.lower(column) + + return table( + "some_table", column("name", String), column("value", MyString) + ) + + @testing.combinations("columns", "table", argnames="use_columns") + def test_plain_returning_column_expression( + self, column_expression_fixture, use_columns + ): + """test #8770""" + table1 = column_expression_fixture + + if use_columns == "columns": + stmt = insert(table1).returning(table1) + else: + stmt = insert(table1).returning(table1.c.name, table1.c.value) + + self.assert_compile( + stmt, + "INSERT INTO some_table (name, value) " + "VALUES (%(name)s, %(value)s) RETURNING some_table.name, " + "lower(some_table.value) AS value", + ) + def test_create_drop_enum(self): # test escaping and unicode within CREATE TYPE for ENUM typ = postgresql.ENUM("val1", "val2", "val's 3", "méil", name="myname") diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index 42d9c5f00..a74c5811c 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -2,6 +2,8 @@ from sqlalchemy import bindparam from sqlalchemy import Boolean from sqlalchemy import cast from sqlalchemy import exc as exceptions +from sqlalchemy import func +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import literal_column from sqlalchemy import MetaData @@ -32,6 +34,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +from sqlalchemy.types import TypeEngine IDENT_LENGTH = 29 @@ -827,6 +830,100 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL): return SomeColThing + @testing.fixture + def compiler_column_fixture(self): + return self._fixture() + + @testing.fixture + def column_expression_fixture(self): + class MyString(TypeEngine): + def column_expression(self, column): + return func.lower(column) + + return table( + "some_table", column("name", String), column("value", MyString) + ) + + def test_plain_select_compiler_expression(self, compiler_column_fixture): + expr = compiler_column_fixture + table1 = self.table1 + + self.assert_compile( + select( + table1.c.name, + expr(table1.c.value), + ), + "SELECT some_table.name, SOME_COL_THING(some_table.value) " + "AS value FROM some_table", + ) + + def test_plain_select_column_expression(self, column_expression_fixture): + table1 = column_expression_fixture + + self.assert_compile( + select(table1), + "SELECT some_table.name, lower(some_table.value) AS value " + "FROM some_table", + ) + + def test_plain_returning_compiler_expression( + self, compiler_column_fixture + ): + expr = compiler_column_fixture + table1 = self.table1 + + self.assert_compile( + insert(table1).returning( + table1.c.name, + expr(table1.c.value), + ), + "INSERT INTO some_table (name, value) VALUES (:name, :value) " + "RETURNING some_table.name, " + "SOME_COL_THING(some_table.value) AS value", + ) + + @testing.combinations("columns", "table", argnames="use_columns") + def test_plain_returning_column_expression( + self, column_expression_fixture, use_columns + ): + table1 = column_expression_fixture + + if use_columns == "columns": + stmt = insert(table1).returning(table1) + else: + stmt = insert(table1).returning(table1.c.name, table1.c.value) + + self.assert_compile( + stmt, + "INSERT INTO some_table (name, value) VALUES (:name, :value) " + "RETURNING some_table.name, lower(some_table.value) AS value", + ) + + def test_select_dupes_column_expression(self, column_expression_fixture): + table1 = column_expression_fixture + + self.assert_compile( + select(table1.c.name, table1.c.value, table1.c.value), + "SELECT some_table.name, lower(some_table.value) AS value, " + "lower(some_table.value) AS value__1 FROM some_table", + ) + + def test_returning_dupes_column_expression( + self, column_expression_fixture + ): + table1 = column_expression_fixture + + stmt = insert(table1).returning( + table1.c.name, table1.c.value, table1.c.value + ) + + self.assert_compile( + stmt, + "INSERT INTO some_table (name, value) VALUES (:name, :value) " + "RETURNING some_table.name, lower(some_table.value) AS value, " + "lower(some_table.value) AS value__1", + ) + def test_column_auto_label_dupes_label_style_none(self): expr = self._fixture() table1 = self.table1 diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 32d4c7740..e0299e334 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -415,6 +415,50 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): result = connection.execute(ins) eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) + @testing.fixture + def column_expression_fixture(self, metadata, connection): + class MyString(TypeDecorator): + cache_ok = True + impl = String(50) + + def column_expression(self, column): + return func.lower(column) + + t1 = Table( + "some_table", + metadata, + Column("name", String(50)), + Column("value", MyString(50)), + ) + metadata.create_all(connection) + return t1 + + @testing.combinations("columns", "table", argnames="use_columns") + def test_plain_returning_column_expression( + self, column_expression_fixture, use_columns, connection + ): + """test #8770""" + table1 = column_expression_fixture + + if use_columns == "columns": + stmt = ( + insert(table1) + .values(name="n1", value="ValUE1") + .returning(table1) + ) + else: + stmt = ( + insert(table1) + .values(name="n1", value="ValUE1") + .returning(table1.c.name, table1.c.value) + ) + + result = connection.execute(stmt) + row = result.first() + + eq_(row._mapping["name"], "n1") + eq_(row._mapping["value"], "value1") + @testing.fails_on_everything_except( "postgresql", "mariadb>=10.5", "sqlite>=3.34" ) |