diff options
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 17 | ||||
-rw-r--r-- | test/sql/test_cte.py | 36 |
3 files changed, 55 insertions, 7 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index c7cacefe8..d678862e7 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -16,6 +16,15 @@ .. change:: :tags: bug, sql + :tickets: 3090 + :versions: 1.0.0 + + Fixed bug in common table expressions whereby positional bound + parameters could be expressed in the wrong final order + when CTEs were nested in certain ways. + + .. change:: + :tags: bug, sql :tickets: 3069 :versions: 1.0.0 diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 32ecb2eae..7a8b07f8f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -407,7 +407,7 @@ class SQLCompiler(Compiled): self.ctes_by_name = {} self.ctes_recursive = False if self.positional: - self.cte_positional = [] + self.cte_positional = {} def _apply_numbered_params(self): poscount = itertools.count(1) @@ -1089,8 +1089,6 @@ class SQLCompiler(Compiled): fromhints=None, **kwargs): self._init_cte_state() - if self.positional: - kwargs['positional_names'] = self.cte_positional if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) @@ -1144,10 +1142,15 @@ class SQLCompiler(Compiled): text += "(%s)" % (", ".join( self.preparer.format_column(ident) for ident in recur_cols)) + + if self.positional: + kwargs['positional_names'] = self.cte_positional[cte] = [] + text += " AS \n" + \ cte.original._compiler_dispatch( self, asfrom=True, **kwargs ) + self.ctes[cte] = text if asfrom: @@ -1416,7 +1419,6 @@ class SQLCompiler(Compiled): iswrapper=False, fromhints=None, compound_index=0, force_result_map=False, - positional_names=None, nested_join_translation=False, **kwargs): @@ -1433,7 +1435,6 @@ class SQLCompiler(Compiled): iswrapper=iswrapper, fromhints=fromhints, compound_index=compound_index, force_result_map=force_result_map, - positional_names=positional_names, nested_join_translation=True, **kwargs ) @@ -1479,7 +1480,6 @@ class SQLCompiler(Compiled): column_clause_args = kwargs.copy() column_clause_args.update({ - 'positional_names': positional_names, 'within_label_clause': False, 'within_columns_clause': False }) @@ -1590,7 +1590,10 @@ class SQLCompiler(Compiled): def _render_cte_clause(self): if self.positional: - self.positiontup = self.cte_positional + self.positiontup + self.positiontup = sum([ + self.cte_positional[cte] + for cte in self.ctes], []) + \ + self.positiontup cte_text = self.get_cte_preamble(self.ctes_recursive) + " " cte_text += ", \n".join( [txt for txt in self.ctes.values()] diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 887d56710..18c85f9e6 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -367,6 +367,42 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect ) + def test_positional_binds_2(self): + orders = table('orders', + column('order'), + ) + s = select([orders.c.order, literal("x")]).cte("regional_sales") + s = select([s.c.order, literal("y")]) + dialect = default.DefaultDialect() + dialect.positional = True + dialect.paramstyle = 'numeric' + s1 = select([orders.c.order]).where(orders.c.order == 'x').\ + cte("regional_sales_1") + + s1a = s1.alias() + + s2 = select([orders.c.order == 'y', s1a.c.order, + orders.c.order, s1.c.order]).\ + where(orders.c.order == 'z').\ + cte("regional_sales_2") + + + s3 = select([s2]) + + self.assert_compile( + s3, + 'WITH regional_sales_1 AS (SELECT orders."order" AS "order" ' + 'FROM orders WHERE orders."order" = :1), regional_sales_2 AS ' + '(SELECT orders."order" = :2 AS anon_1, ' + 'anon_2."order" AS "order", ' + 'orders."order" AS "order", ' + 'regional_sales_1."order" AS "order" FROM orders, ' + 'regional_sales_1 ' + 'AS anon_2, regional_sales_1 ' + 'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, ' + 'regional_sales_2."order" FROM regional_sales_2', + checkpositional=('x', 'y', 'z'), dialect=dialect) + def test_all_aliases(self): orders = table('order', column('order')) |