diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 3 |
4 files changed, 33 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 09e38a5ab..423c3d446 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -34,6 +34,7 @@ import re from time import perf_counter import typing from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Mapping @@ -629,11 +630,11 @@ class SQLCompiler(Compiled): """list of columns that can be post-fetched after INSERT or UPDATE to receive server-updated values""" - insert_prefetch: Optional[List[Column[Any]]] + insert_prefetch: Sequence[Column[Any]] = () """list of columns for which default values should be evaluated before an INSERT takes place""" - update_prefetch: Optional[List[Column[Any]]] + update_prefetch: Sequence[Column[Any]] = () """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" @@ -739,8 +740,6 @@ class SQLCompiler(Compiled): """if True, there are bindparam() objects that have the isoutparam flag set.""" - insert_prefetch = update_prefetch = () - postfetch_lastrowid = False """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ @@ -1340,7 +1339,7 @@ class SQLCompiler(Compiled): ) @util.memoized_property - def _within_exec_param_key_getter(self): + def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] if self.escaped_bind_names: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 4c38c4efa..168da17cc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -58,12 +58,13 @@ from ..util.langhelpers import TypingOnly if typing.TYPE_CHECKING: from decimal import Decimal + from .compiler import Compiled + from .compiler import SQLCompiler from .operators import OperatorType from .selectable import FromClause from .selectable import Select from .sqltypes import Boolean # noqa from .type_api import TypeEngine - from ..engine import Compiled from ..engine import Connection from ..engine import Dialect from ..engine import Engine @@ -573,6 +574,25 @@ class ClauseElement( ) +class DQLDMLClauseElement(ClauseElement): + """represents a :class:`.ClauseElement` that compiles to a DQL or DML + expression, not DDL. + + .. versionadded:: 2.0 + + """ + + if typing.TYPE_CHECKING: + + def compile( # noqa: A001 + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> SQLCompiler: + ... + + class CompilerColumnElement( roles.DMLColumnRole, roles.DDLConstraintColumnRole, @@ -955,7 +975,7 @@ class ColumnElement( roles.DDLExpressionRole, SQLCoreOperations[_T], operators.ColumnOperators[SQLCoreOperations], - ClauseElement, + DQLDMLClauseElement, ): """Represent a column-oriented SQL expression suitable for usage in the "columns" clause, WHERE clause etc. of a statement. @@ -1820,7 +1840,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ) -class TypeClause(ClauseElement): +class TypeClause(DQLDMLClauseElement): """Handle a type keyword in a SQL statement. Used by the ``Case`` statement. @@ -1849,7 +1869,7 @@ class TextClause( roles.BinaryElementRole, roles.InElementRole, Executable, - ClauseElement, + DQLDMLClauseElement, ): """Represent a literal SQL text fragment. @@ -2285,7 +2305,7 @@ class ClauseList( roles.OrderByRole, roles.ColumnsClauseRole, roles.DMLColumnRole, - ClauseElement, + DQLDMLClauseElement, ): """Describe a list of clauses, separated by an operator. @@ -3205,7 +3225,7 @@ class IndexExpression(BinaryExpression): inherit_cache = True -class GroupedElement(ClauseElement): +class GroupedElement(DQLDMLClauseElement): """Represent any parenthesized expression""" __visit_name__ = "grouping" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index fdae4d7b0..c270e1564 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1131,6 +1131,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): __visit_name__ = "column" inherit_cache = True + key: str @overload def __init__( diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a5cbffb5e..e5c2bef68 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -62,6 +62,7 @@ from .elements import ClauseElement from .elements import ClauseList from .elements import ColumnClause from .elements import ColumnElement +from .elements import DQLDMLClauseElement from .elements import GroupedElement from .elements import Grouping from .elements import literal_column @@ -85,7 +86,7 @@ class _OffsetLimitParam(BindParameter): return self.effective_value -class ReturnsRows(roles.ReturnsRowsRole, ClauseElement): +class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """The base-most class for Core constructs that have some concept of columns that can represent rows. |