diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 148 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 96 |
2 files changed, 180 insertions, 64 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b140f9297..77bc1ea38 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -31,6 +31,13 @@ import itertools import operator import re from time import perf_counter +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import Tuple from . import base from . import coercions @@ -47,6 +54,12 @@ from .elements import quoted_name from .. import exc from .. import util +if typing.TYPE_CHECKING: + from .selectable import CTE + from .selectable import FromClause + +_FromHintsType = Dict["FromClause", str] + RESERVED_WORDS = set( [ "all", @@ -842,7 +855,7 @@ class SQLCompiler(Compiled): return {} @util.memoized_instancemethod - def _init_cte_state(self): + def _init_cte_state(self) -> None: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -850,19 +863,21 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes = util.OrderedDict() + self.ctes: MutableMapping[CTE, str] = util.OrderedDict() # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name = {} + self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} # To retrieve key/level in ctes_by_level_name - - # Dict[cte_reference, (level, cte_name)] - self.level_name_by_cte = {} + # Dict[cte_reference, (level, cte_name, cte_opts)] + self.level_name_by_cte: Dict[ + CTE, Tuple[int, str, selectable._CTEOpts] + ] = {} - self.ctes_recursive = False + self.ctes_recursive: bool = False if self.positional: - self.cte_positional = {} + self.cte_positional: Dict[CTE, List[str]] = {} @contextlib.contextmanager def _nested_result(self): @@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled): self.stack.append(new_entry) if taf._independent_ctes: - for cte in taf._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(taf, kw) populate_result_map = ( toplevel @@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled): ) if compound_stmt._independent_ctes: - for cte in compound_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) + self._dispatch_independent_ctes(compound_stmt, kwargs) keyword = self.compound_keywords.get(cs.keyword) @@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled): return ret + def _dispatch_independent_ctes(self, stmt, kw): + local_kw = kw.copy() + local_kw.pop("cte_opts", None) + for cte, opt in zip( + stmt._independent_ctes, stmt._independent_ctes_opts + ): + cte._compiler_dispatch(self, cte_opts=opt, **local_kw) + def visit_cte( self, - cte, - asfrom=False, - ashint=False, - fromhints=None, - visiting_cte=None, - from_linter=None, - **kwargs, - ): + cte: CTE, + asfrom: bool = False, + ashint: bool = False, + fromhints: Optional[_FromHintsType] = None, + visiting_cte: Optional[CTE] = None, + from_linter: Optional[FromLinter] = None, + cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), + **kwargs: Any, + ) -> Optional[str]: self._init_cte_state() kwargs["visiting_cte"] = cte @@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled): _reference_cte = cte._get_reference_cte() + nesting = cte.nesting or cte_opts.nesting + + # check for CTE already encountered if _reference_cte in self.level_name_by_cte: - cte_level, _ = self.level_name_by_cte[_reference_cte] + cte_level, _, existing_cte_opts = self.level_name_by_cte[ + _reference_cte + ] assert _ == cte_name - else: - cte_level = len(self.stack) if cte.nesting else 1 - cte_level_name = (cte_level, cte_name) - if cte_level_name in self.ctes_by_level_name: + cte_level_name = (cte_level, cte_name) existing_cte = self.ctes_by_level_name[cte_level_name] + + # check if we are receiving it here with a specific + # "nest_here" location; if so, move it to this location + + if cte_opts.nesting: + if existing_cte_opts.nesting: + raise exc.CompileError( + "CTE is stated as 'nest_here' in " + "more than one location" + ) + + old_level_name = (cte_level, cte_name) + cte_level = len(self.stack) if nesting else 1 + cte_level_name = new_level_name = (cte_level, cte_name) + + del self.ctes_by_level_name[old_level_name] + self.ctes_by_level_name[new_level_name] = existing_cte + self.level_name_by_cte[_reference_cte] = new_level_name + ( + cte_opts, + ) + + else: + cte_level = len(self.stack) if nesting else 1 + cte_level_name = (cte_level, cte_name) + + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] + else: + existing_cte = None + + if existing_cte is not None: embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, @@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled): existing_cte_reference_cte = existing_cte._get_reference_cte() - # TODO: determine if these assertions are correct. they - # pass for current test cases - # assert existing_cte_reference_cte is _reference_cte - # assert existing_cte_reference_cte is existing_cte + assert existing_cte_reference_cte is _reference_cte + assert existing_cte_reference_cte is existing_cte del self.level_name_by_cte[existing_cte_reference_cte] else: @@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_level_name[cte_level_name] = cte - self.level_name_by_cte[_reference_cte] = cte_level_name - - if ( - "autocommit" in cte.element._execution_options - and "autocommit" not in self.execution_options - ): - self.execution_options = self.execution_options.union( - { - "autocommit": cte.element._execution_options[ - "autocommit" - ] - } - ) + self.level_name_by_cte[_reference_cte] = cte_level_name + ( + cte_opts, + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled): byfrom = None if select_stmt._independent_ctes: - for cte in select_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) + self._dispatch_independent_ctes(select_stmt, kwargs) if select_stmt._prefixes: text += self._generate_prefixes( @@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled): return text - def _setup_select_hints(self, select): + def _setup_select_hints( + self, select: Select + ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ ( @@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level, cte_name = self.level_name_by_cte[ + cte_level, cte_name, cte_opts = self.level_name_by_cte[ cte._get_reference_cte() ] + nesting = cte.nesting or cte_opts.nesting is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) - if not (cte.nesting and is_rendered_level): + if not (nesting and is_rendered_level): continue ctes[cte] = self.ctes[cte] @@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: for cte in list(ctes.keys()): - cte_level, cte_name = self.level_name_by_cte[ + cte_level, cte_name, cte_opts = self.level_name_by_cte[ cte._get_reference_cte() ] del self.ctes[cte] @@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled): _, table_text = self._setup_crud_hints(insert_stmt, table_text) if insert_stmt._independent_ctes: - for cte in insert_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(insert_stmt, kw) text += table_text @@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled): dialect_hints = None if update_stmt._independent_ctes: - for cte in update_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(update_stmt, kw) text += table_text @@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled): dialect_hints = None if delete_stmt._independent_ctes: - for cte in delete_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(delete_stmt, kw) text += table_text diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 7f6360edb..836c30af7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -19,6 +19,7 @@ import itertools from operator import attrgetter import typing from typing import Any as TODO_Any +from typing import NamedTuple from typing import Optional from typing import Tuple @@ -1809,6 +1810,10 @@ class CTE( SelfHasCTE = typing.TypeVar("SelfHasCTE", bound="HasCTE") +class _CTEOpts(NamedTuple): + nesting: bool + + class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. @@ -1818,20 +1823,36 @@ class HasCTE(roles.HasCTERole): _has_ctes_traverse_internals = [ ("_independent_ctes", InternalTraversal.dp_clauseelement_list), + ("_independent_ctes_opts", InternalTraversal.dp_plain_obj), ] _independent_ctes = () + _independent_ctes_opts = () @_generative - def add_cte(self: SelfHasCTE, cte) -> SelfHasCTE: - """Add a :class:`_sql.CTE` to this statement object that will be - independently rendered even if not referenced in the statement - otherwise. + def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE: + r"""Add one or more :class:`_sql.CTE` constructs to this statement. + + This method will associate the given :class:`_sql.CTE` constructs with + the parent statement such that they will each be unconditionally + rendered in the WITH clause of the final statement, even if not + referenced elsewhere within the statement or any sub-selects. + + The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set + to True will have the effect that each given :class:`_sql.CTE` will + render in a WITH clause rendered directly along with this statement, + rather than being moved to the top of the ultimate rendered statement, + even if this statement is rendered as a subquery within a larger + statement. - This feature is useful for the use case of embedding a DML statement - such as an INSERT or UPDATE as a CTE inline with a primary statement - that may draw from its results indirectly; while PostgreSQL is known - to support this usage, it may not be supported by other backends. + This method has two general uses. One is to embed CTE statements that + serve some purpose without being referenced explicitly, such as the use + case of embedding a DML statement such as an INSERT or UPDATE as a CTE + inline with a primary statement that may draw from its results + indirectly. The other is to provide control over the exact placement + of a particular series of CTE constructs that should remain rendered + directly in terms of a particular statement that may be nested in a + larger statement. E.g.:: @@ -1885,9 +1906,32 @@ class HasCTE(roles.HasCTERole): .. versionadded:: 1.4.21 + :param \*ctes: zero or more :class:`.CTE` constructs. + + .. versionchanged:: 2.0 Multiple CTE instances are accepted + + :param nest_here: if True, the given CTE or CTEs will be rendered + as though they specified the :paramref:`.HasCTE.cte.nesting` flag + to ``True`` when they were added to this :class:`.HasCTE`. + Assuming the given CTEs are not referenced in an outer-enclosing + statement as well, the CTEs given should render at the level of + this statement when this flag is given. + + .. versionadded:: 2.0 + + .. seealso:: + + :paramref:`.HasCTE.cte.nesting` + + """ - cte = coercions.expect(roles.IsCTERole, cte) - self._independent_ctes += (cte,) + opt = _CTEOpts( + nest_here, + ) + for cte in ctes: + cte = coercions.expect(roles.IsCTERole, cte) + self._independent_ctes += (cte,) + self._independent_ctes_opts += (opt,) return self def cte(self, name=None, recursive=False, nesting=False): @@ -1931,10 +1975,18 @@ class HasCTE(roles.HasCTERole): conjunction with UNION ALL in order to derive rows from those already selected. :param nesting: if ``True``, will render the CTE locally to the - actual statement. + statement in which it is referenced. For more complex scenarios, + the :meth:`.HasCTE.add_cte` method using the + :paramref:`.HasCTE.add_cte.nest_here` + parameter may also be used to more carefully + control the exact placement of a particular CTE. .. versionadded:: 1.4.24 + .. seealso:: + + :meth:`.HasCTE.add_cte` + The following examples include two from PostgreSQL's documentation at https://www.postgresql.org/docs/current/static/queries-with.html, as well as additional examples. @@ -2084,6 +2136,28 @@ class HasCTE(roles.HasCTERole): SELECT value_a.n AS a, value_b.n AS b FROM value_a, value_b + The same CTE can be set up using the :meth:`.HasCTE.add_cte` method + as follows (SQLAlchemy 2.0 and above):: + + value_a = select( + literal("root").label("n") + ).cte("value_a") + + # A nested CTE with the same name as the root one + value_a_nested = select( + literal("nesting").label("n") + ).cte("value_a") + + # Nesting CTEs takes ascendency locally + # over the CTEs at a higher level + value_b = ( + select(value_a_nested.c.n). + add_cte(value_a_nested, nest_here=True). + cte("value_b") + ) + + value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b")) + Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above):: edge = Table( |