summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-11-11 21:00:21 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-11-11 21:00:21 +0000
commit1dd0f23e8d74aa7edc8dd309093a95171e2e8f09 (patch)
treedda63840955464c225ecf0b09481275bbbf73ac9
parent604611e7e522269ee11b314fb6fb75873a465494 (diff)
parent8e91cfe529b9b0150c16e52e22e4590bfbbe79fd (diff)
downloadsqlalchemy-1dd0f23e8d74aa7edc8dd309093a95171e2e8f09.tar.gz
Merge "establish consistency for RETURNING column labels" into main
-rw-r--r--doc/build/changelog/unreleased_14/8770.rst23
-rw-r--r--doc/build/changelog/unreleased_20/8770.rst10
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py19
-rw-r--r--lib/sqlalchemy/sql/compiler.py23
-rw-r--r--lib/sqlalchemy/sql/dml.py4
-rw-r--r--lib/sqlalchemy/sql/selectable.py8
-rw-r--r--test/dialect/mssql/test_compiler.py29
-rw-r--r--test/dialect/oracle/test_compiler.py31
-rw-r--r--test/dialect/postgresql/test_compiler.py30
-rw-r--r--test/sql/test_labels.py97
-rw-r--r--test/sql/test_returning.py44
11 files changed, 310 insertions, 8 deletions
diff --git a/doc/build/changelog/unreleased_14/8770.rst b/doc/build/changelog/unreleased_14/8770.rst
new file mode 100644
index 000000000..8968b0361
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8770.rst
@@ -0,0 +1,23 @@
+.. change::
+ :tags: bug, postgresql, mssql
+ :tickets: 8770
+
+ For the PostgreSQL and SQL Server dialects only, adjusted the compiler so
+ that when rendering column expressions in the RETURNING clause, the "non
+ anon" label that's used in SELECT statements is suggested for SQL
+ expression elements that generate a label; the primary example is a SQL
+ function that may be emitting as part of the column's type, where the label
+ name should match the column's name by default. This restores a not-well
+ defined behavior that had changed in version 1.4.21 due to :ticket:`6718`,
+ :ticket:`6710`. The Oracle dialect has a different RETURNING implementation
+ and was not affected by this issue. Version 2.0 features an across the
+ board change for its widely expanded support of RETURNING on other
+ backends.
+
+
+.. change::
+ :tags: bug, oracle
+
+ Fixed issue in the Oracle dialect where an INSERT statement that used
+ ``insert(some_table).values(...).returning(some_table)`` against a full
+ :class:`.Table` object at once would fail to execute, raising an exception.
diff --git a/doc/build/changelog/unreleased_20/8770.rst b/doc/build/changelog/unreleased_20/8770.rst
new file mode 100644
index 000000000..59b94d658
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8770.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 8770
+
+ The RETURNING clause now renders columns using the routine as that of the
+ :class:`.Select` to generate labels, which will include disambiguating
+ labels, as well as that a SQL function surrounding a named column will be
+ labeled using the column name itself. This is a more comprehensive change
+ than a similar one made for the 1.4 series that adjusted the function label
+ issue only.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index a338ba27a..53fe96c9a 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2295,11 +2295,24 @@ class MSSQLCompiler(compiler.SQLCompiler):
columns = [
self._label_returning_column(
stmt,
- adapter.traverse(c),
+ adapter.traverse(column),
populate_result_map,
- {"result_map_targets": (c,)},
+ {"result_map_targets": (column,)},
+ fallback_label_name=fallback_label_name,
+ column_is_repeated=repeated,
+ name=name,
+ proxy_name=proxy_name,
+ **kw,
+ )
+ for (
+ name,
+ proxy_name,
+ fallback_label_name,
+ column,
+ repeated,
+ ) in stmt._generate_columns_plus_names(
+ True, cols=expression._select_iterables(returning_cols)
)
- for c in expression._select_iterables(returning_cols)
]
return "OUTPUT " + ", ".join(columns)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 3e62cb350..97397e9cf 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -3760,7 +3760,6 @@ class SQLCompiler(Compiled):
"_label_select_column is only relevant within "
"the columns clause of a SELECT or RETURNING"
)
-
if isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
@@ -4416,9 +4415,27 @@ class SQLCompiler(Compiled):
populate_result_map: bool,
**kw: Any,
) -> str:
+
columns = [
- self._label_returning_column(stmt, c, populate_result_map, **kw)
- for c in base._select_iterables(returning_cols)
+ self._label_returning_column(
+ stmt,
+ column,
+ populate_result_map,
+ fallback_label_name=fallback_label_name,
+ column_is_repeated=repeated,
+ name=name,
+ proxy_name=proxy_name,
+ **kw,
+ )
+ for (
+ name,
+ proxy_name,
+ fallback_label_name,
+ column,
+ repeated,
+ ) in stmt._generate_columns_plus_names(
+ True, cols=base._select_iterables(returning_cols)
+ )
]
return "RETURNING " + ", ".join(columns)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 5145a4a16..2d3e3598b 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -59,6 +59,7 @@ from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import Join
+from .selectable import SelectLabelStyle
from .selectable import TableClause
from .selectable import TypedReturnsRows
from .sqltypes import NullType
@@ -399,6 +400,9 @@ class UpdateBase(
] = util.EMPTY_DICT
named_with_column = False
+ _label_style: SelectLabelStyle = (
+ SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY
+ )
table: _DMLTableElement
_return_defaults = False
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 9de015774..488dfe721 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -2193,7 +2193,9 @@ class SelectsRows(ReturnsRows):
_label_style: SelectLabelStyle = LABEL_STYLE_NONE
def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
+ self,
+ anon_for_dupe_key: bool,
+ cols: Optional[_SelectIterable] = None,
) -> List[_ColumnsPlusNames]:
"""Generate column names as rendered in a SELECT statement by
the compiler.
@@ -2204,7 +2206,9 @@ class SelectsRows(ReturnsRows):
_column_naming_convention as well.
"""
- cols = self._all_selected_columns
+
+ if cols is None:
+ cols = self._all_selected_columns
key_naming_convention = SelectState._column_naming_convention(
self._label_style
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 8605ea9c0..b575595ac 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -36,6 +36,7 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing.assertions import eq_ignore_whitespace
+from sqlalchemy.types import TypeEngine
tbl = table("t", column("a"))
@@ -119,6 +120,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"Latin1_General_CS_AS_KS_WS_CI ASC",
)
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) OUTPUT inserted.name, "
+ "lower(inserted.value) AS value VALUES (:name, :value)",
+ )
+
def test_join_with_hint(self):
t1 = table(
"t1",
diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py
index 2973c6e39..8981e74e8 100644
--- a/test/dialect/oracle/test_compiler.py
+++ b/test/dialect/oracle/test_compiler.py
@@ -9,6 +9,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy import Index
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import literal
from sqlalchemy import literal_column
@@ -42,6 +43,7 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import eq_ignore_whitespace
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -1359,6 +1361,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"t1.c2, t1.c3 INTO :ret_0, :ret_1",
)
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, lower(some_table.value) "
+ "INTO :ret_0, :ret_1",
+ )
+
def test_returning_insert_computed(self):
m = MetaData()
t1 = Table(
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 96a8e7d5a..338d0da4e 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -61,6 +61,7 @@ from sqlalchemy.testing.assertions import AssertsCompiledSQL
from sqlalchemy.testing.assertions import eq_ignore_whitespace
from sqlalchemy.testing.assertions import expect_warnings
from sqlalchemy.testing.assertions import is_
+from sqlalchemy.types import TypeEngine
from sqlalchemy.util import OrderedDict
@@ -200,6 +201,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
dialect=dialect,
)
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) "
+ "VALUES (%(name)s, %(value)s) RETURNING some_table.name, "
+ "lower(some_table.value) AS value",
+ )
+
def test_create_drop_enum(self):
# test escaping and unicode within CREATE TYPE for ENUM
typ = postgresql.ENUM("val1", "val2", "val's 3", "méil", name="myname")
diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py
index 42d9c5f00..a74c5811c 100644
--- a/test/sql/test_labels.py
+++ b/test/sql/test_labels.py
@@ -2,6 +2,8 @@ from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import cast
from sqlalchemy import exc as exceptions
+from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy import MetaData
@@ -32,6 +34,7 @@ from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
IDENT_LENGTH = 29
@@ -827,6 +830,100 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
return SomeColThing
+ @testing.fixture
+ def compiler_column_fixture(self):
+ return self._fixture()
+
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ def test_plain_select_compiler_expression(self, compiler_column_fixture):
+ expr = compiler_column_fixture
+ table1 = self.table1
+
+ self.assert_compile(
+ select(
+ table1.c.name,
+ expr(table1.c.value),
+ ),
+ "SELECT some_table.name, SOME_COL_THING(some_table.value) "
+ "AS value FROM some_table",
+ )
+
+ def test_plain_select_column_expression(self, column_expression_fixture):
+ table1 = column_expression_fixture
+
+ self.assert_compile(
+ select(table1),
+ "SELECT some_table.name, lower(some_table.value) AS value "
+ "FROM some_table",
+ )
+
+ def test_plain_returning_compiler_expression(
+ self, compiler_column_fixture
+ ):
+ expr = compiler_column_fixture
+ table1 = self.table1
+
+ self.assert_compile(
+ insert(table1).returning(
+ table1.c.name,
+ expr(table1.c.value),
+ ),
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, "
+ "SOME_COL_THING(some_table.value) AS value",
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, lower(some_table.value) AS value",
+ )
+
+ def test_select_dupes_column_expression(self, column_expression_fixture):
+ table1 = column_expression_fixture
+
+ self.assert_compile(
+ select(table1.c.name, table1.c.value, table1.c.value),
+ "SELECT some_table.name, lower(some_table.value) AS value, "
+ "lower(some_table.value) AS value__1 FROM some_table",
+ )
+
+ def test_returning_dupes_column_expression(
+ self, column_expression_fixture
+ ):
+ table1 = column_expression_fixture
+
+ stmt = insert(table1).returning(
+ table1.c.name, table1.c.value, table1.c.value
+ )
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, lower(some_table.value) AS value, "
+ "lower(some_table.value) AS value__1",
+ )
+
def test_column_auto_label_dupes_label_style_none(self):
expr = self._fixture()
table1 = self.table1
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 32d4c7740..e0299e334 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -415,6 +415,50 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
result = connection.execute(ins)
eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
+ @testing.fixture
+ def column_expression_fixture(self, metadata, connection):
+ class MyString(TypeDecorator):
+ cache_ok = True
+ impl = String(50)
+
+ def column_expression(self, column):
+ return func.lower(column)
+
+ t1 = Table(
+ "some_table",
+ metadata,
+ Column("name", String(50)),
+ Column("value", MyString(50)),
+ )
+ metadata.create_all(connection)
+ return t1
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns, connection
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = (
+ insert(table1)
+ .values(name="n1", value="ValUE1")
+ .returning(table1)
+ )
+ else:
+ stmt = (
+ insert(table1)
+ .values(name="n1", value="ValUE1")
+ .returning(table1.c.name, table1.c.value)
+ )
+
+ result = connection.execute(stmt)
+ row = result.first()
+
+ eq_(row._mapping["name"], "n1")
+ eq_(row._mapping["value"], "value1")
+
@testing.fails_on_everything_except(
"postgresql", "mariadb>=10.5", "sqlite>=3.34"
)