diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-04-07 15:19:10 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-04-07 15:19:10 +0000 |
commit | 735792d75681e3bc6cdd2d97a903909c6f993b7f (patch) | |
tree | 9b060eda44fc2a6025e6951af70e49f614d2890f /lib/sqlalchemy/sql | |
parent | 9edcdd03e8cc69e439e89c45c5083edcd28a23af (diff) | |
parent | 2acc9ec1281b2818bd44804f040d94ec46215688 (diff) | |
download | sqlalchemy-735792d75681e3bc6cdd2d97a903909c6f993b7f.tar.gz |
Merge "cx_Oracle modernize" into main
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 23 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 116 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 65 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 24 |
7 files changed, 180 insertions, 64 deletions
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 0a72a93c5..bc1e0672c 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -14,6 +14,11 @@ from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from .compiler import Compiled + from .compiler import DDLCompiler + from .compiler import SQLCompiler + from .dml import UpdateBase + from .dml import ValuesBase from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -38,6 +43,7 @@ if TYPE_CHECKING: from .type_api import TypeEngine from ..util.typing import TypeGuard + _T = TypeVar("_T", bound=Any) @@ -153,6 +159,12 @@ _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] if TYPE_CHECKING: + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: + ... + + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: + ... + def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: ... @@ -183,7 +195,13 @@ if TYPE_CHECKING: def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ... + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: + ... + else: + + is_sql_compiler = operator.attrgetter("is_sql") + is_ddl_compiler = operator.attrgetter("is_ddl") is_named_from_clause = operator.attrgetter("named_with_column") is_column_element = operator.attrgetter("_is_column_element") is_text_clause = operator.attrgetter("_is_text_clause") @@ -194,6 +212,7 @@ else: is_select_statement = operator.attrgetter("_is_select_statement") is_table = operator.attrgetter("_is_table") is_subquery = operator.attrgetter("_is_subquery") + is_dml = operator.attrgetter("is_dml") def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: @@ -206,3 +225,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: return hasattr(s, "__clause_element__") + + +def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: + return c.is_dml and (c.is_insert or c.is_update) # type: ignore diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6b25d8fcd..f766a5ac5 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -937,7 +937,7 @@ class Executable(roles.StatementRole, Generative): _with_context_options: Tuple[ Tuple[Callable[[CompileState], None], Any], ... ] = () - _compile_options: Optional[CacheableOptions] + _compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]] _executable_traverse_internals = [ ("_with_options", InternalTraversal.dp_executable_options), @@ -982,7 +982,7 @@ class Executable(roles.StatementRole, Generative): ) -> Result: ... - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self): raise NotImplementedError() diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6ecfbf986..522a0bd4a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -61,6 +61,8 @@ from . import operators from . import schema from . import selectable from . import sqltypes +from ._typing import is_column_element +from ._typing import is_dml from .base import _from_objects from .base import Executable from .base import NO_ARG @@ -90,6 +92,7 @@ if typing.TYPE_CHECKING: from .elements import _truncated_label from .elements import BindParameter from .elements import ColumnClause + from .elements import ColumnElement from .elements import Label from .functions import Function from .selectable import Alias @@ -492,6 +495,9 @@ class Compiled: defaults. """ + is_sql = False + is_ddl = False + _cached_metadata: Optional[CursorResultMetaData] = None _result_columns: Optional[List[ResultColumnsEntry]] = None @@ -701,6 +707,8 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + is_sql = True + _result_columns: List[ResultColumnsEntry] compound_keywords = COMPOUND_KEYWORDS @@ -725,9 +733,14 @@ class SQLCompiler(Compiled): """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" - returning: Optional[Sequence[roles.ColumnsClauseRole]] - """list of columns that will be delivered to cursor.description or - dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE + implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None + """list of "implicit" returning columns for a toplevel INSERT or UPDATE + statement, used to receive newly generated values of columns. + + .. versionadded:: 2.0 ``implicit_returning`` replaces the previous + ``returning`` collection, which was not a generalized RETURNING + collection and instead was in fact specific to the "implicit returning" + feature. """ @@ -750,12 +763,6 @@ class SQLCompiler(Compiled): TypeEngine. CursorResult uses this for type processing and column targeting""" - returning = None - """holds the "returning" collection of columns if - the statement is CRUD and defines returning columns - either implicitly or explicitly - """ - returning_precedes_values: bool = False """set to True classwide to generate RETURNING clauses before the VALUES or WHERE clause (i.e. MSSQL) @@ -978,9 +985,6 @@ class SQLCompiler(Compiled): if TYPE_CHECKING: assert isinstance(statement, UpdateBase) - if statement._returning: - self.returning = statement._returning - if self.isinsert or self.isupdate: if TYPE_CHECKING: assert isinstance(statement, ValuesBase) @@ -1001,6 +1005,39 @@ class SQLCompiler(Compiled): if self._render_postcompile: self._process_parameters_for_postcompile(_populate_self=True) + @util.ro_memoized_property + def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]: + """The effective "returning" columns for INSERT, UPDATE or DELETE. + + This is either the so-called "implicit returning" columns which are + calculated by the compiler on the fly, or those present based on what's + present in ``self.statement._returning`` (expanded into individual + columns using the ``._all_selected_columns`` attribute) i.e. those set + explicitly using the :meth:`.UpdateBase.returning` method. + + .. versionadded:: 2.0 + + """ + if self.implicit_returning: + return self.implicit_returning + elif is_dml(self.statement): + return [ + c + for c in self.statement._all_selected_columns + if is_column_element(c) + ] + + else: + return None + + @property + def returning(self): + """backwards compatibility; returns the + effective_returning collection. + + """ + return self.effective_returning + @property def current_executable(self): """Return the current 'executable' that is being compiled. @@ -1569,7 +1606,7 @@ class SQLCompiler(Compiled): param_key_getter = self._within_exec_param_key_getter table = self.statement.table - returning = self.returning + returning = self.implicit_returning assert returning is not None ret = {col: idx for idx, col in enumerate(returning)} @@ -3373,7 +3410,9 @@ class SQLCompiler(Compiled): ResultColumnsEntry(keyname, name, objects, type_) ) - def _label_returning_column(self, stmt, column, column_clause_args=None): + def _label_returning_column( + self, stmt, column, populate_result_map, column_clause_args=None + ): """Render a column with necessary labels inside of a RETURNING clause. This method is provided for individual dialects in place of calling @@ -3386,7 +3425,7 @@ class SQLCompiler(Compiled): return self._label_select_column( None, column, - True, + populate_result_map, False, {} if column_clause_args is None else column_clause_args, ) @@ -4103,7 +4142,10 @@ class SQLCompiler(Compiled): def returning_clause( self, stmt: UpdateBase, - returning_cols: Sequence[roles.ColumnsClauseRole], + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, ) -> str: raise exc.CompileError( "RETURNING is not supported by this " @@ -4228,7 +4270,6 @@ class SQLCompiler(Compiled): return dialect_hints, table_text def visit_insert(self, insert_stmt, **kw): - compile_state = insert_stmt._compile_state_factory( insert_stmt, self, **kw ) @@ -4250,7 +4291,7 @@ class SQLCompiler(Compiled): ) crud_params_struct = crud._get_crud_params( - self, insert_stmt, compile_state, **kw + self, insert_stmt, compile_state, toplevel, **kw ) crud_params_single = crud_params_struct.single_params @@ -4303,9 +4344,11 @@ class SQLCompiler(Compiled): [expr for _, expr, _ in crud_params_single] ) - if self.returning or insert_stmt._returning: + if self.implicit_returning or insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning or insert_stmt._returning + insert_stmt, + self.implicit_returning or insert_stmt._returning, + populate_result_map=toplevel, ) if self.returning_precedes_values: @@ -4449,7 +4492,7 @@ class SQLCompiler(Compiled): update_stmt, update_stmt.table, render_extra_froms, **kw ) crud_params_struct = crud._get_crud_params( - self, update_stmt, compile_state, **kw + self, update_stmt, compile_state, toplevel, **kw ) crud_params = crud_params_struct.single_params @@ -4473,10 +4516,12 @@ class SQLCompiler(Compiled): ) ) - if self.returning or update_stmt._returning: + if self.implicit_returning or update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, ) if extra_froms: @@ -4502,10 +4547,12 @@ class SQLCompiler(Compiled): text += " " + limit_clause if ( - self.returning or update_stmt._returning + self.implicit_returning or update_stmt._returning ) and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, ) if self.ctes: @@ -4585,7 +4632,9 @@ class SQLCompiler(Compiled): if delete_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning + delete_stmt, + delete_stmt._returning, + populate_result_map=toplevel, ) if extra_froms: @@ -4608,7 +4657,9 @@ class SQLCompiler(Compiled): if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning + delete_stmt, + delete_stmt._returning, + populate_result_map=toplevel, ) if self.ctes: @@ -4685,7 +4736,14 @@ class StrSQLCompiler(SQLCompiler): def visit_sequence(self, seq, **kw): return "<next sequence value: %s>" % self.preparer.format_sequence(seq) - def returning_clause(self, stmt, returning_cols): + def returning_clause( + self, + stmt: UpdateBase, + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, + ) -> str: columns = [ self._label_select_column(None, c, True, False, {}) for c in base._select_iterables(returning_cols) @@ -4733,6 +4791,8 @@ class StrSQLCompiler(SQLCompiler): class DDLCompiler(Compiled): + is_ddl = True + if TYPE_CHECKING: def __init__( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 91a3f70c9..f6db2c4b2 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -87,6 +87,7 @@ def _get_crud_params( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, + toplevel: bool, **kw: Any, ) -> _CrudParams: """create a set of tuples representing column/string pairs for use @@ -99,10 +100,33 @@ def _get_crud_params( """ + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. + + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + compiler.postfetch = [] compiler.insert_prefetch = [] compiler.update_prefetch = [] - compiler.returning = [] + compiler.implicit_returning = [] # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for @@ -213,6 +237,7 @@ def _get_crud_params( _col_bind_name, check_columns, values, + toplevel, kw, ) else: @@ -226,6 +251,7 @@ def _get_crud_params( _col_bind_name, check_columns, values, + toplevel, kw, ) @@ -419,6 +445,7 @@ def _scan_insert_from_select_cols( _col_bind_name, check_columns, values, + toplevel, kw, ): @@ -427,7 +454,7 @@ def _scan_insert_from_select_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state) + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] @@ -472,6 +499,7 @@ def _scan_cols( _col_bind_name, check_columns, values, + toplevel, kw, ): ( @@ -479,7 +507,7 @@ def _scan_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state) + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) if compile_state._parameter_ordering: parameter_ordering = [ @@ -556,11 +584,11 @@ def _scan_cols( # column has a DDL-level default, and is either not a pk # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif ( c.primary_key and c is not stmt.table._autoincrement_column @@ -628,7 +656,7 @@ def _append_param_parameter( if compile_state.isupdate: if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) @@ -636,12 +664,12 @@ def _append_param_parameter( if c.primary_key: if implicit_returning: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif compiler.dialect.postfetch_lastrowid: compiler.postfetch_lastrowid = True elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: # postfetch specifically means, "we can SELECT the row we just @@ -674,7 +702,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.process(c.default, **kw), ) ) - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif c.default.is_clause_element: values.append( ( @@ -683,7 +711,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.process(c.default.arg.self_group(), **kw), ) ) - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: # client side default. OK we can't use RETURNING, need to # do a "prefetch", which in fact fetches the default value @@ -696,7 +724,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): ) ) elif c is stmt.table._autoincrement_column or c.server_default is not None: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -794,7 +822,7 @@ def _append_param_insert_hasdefault( ) ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: @@ -807,7 +835,7 @@ def _append_param_insert_hasdefault( ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch compiler.postfetch.append(c) @@ -870,7 +898,7 @@ def _append_param_update( ) ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) else: @@ -886,7 +914,7 @@ def _append_param_update( ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) elif ( @@ -894,7 +922,7 @@ def _append_param_update( and (stmt._return_defaults_columns or not stmt._return_defaults) and c in implicit_return_defaults ): - compiler.returning.append(c) + compiler.implicit_returning.append(c) @overload @@ -1195,10 +1223,11 @@ def _get_stmt_parameter_tuples_params( values.append((k, col_expr, v)) -def _get_returning_modifiers(compiler, stmt, compile_state): +def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): need_pks = ( - compile_state.isinsert + toplevel + and compile_state.isinsert and not stmt._inline and ( not compiler.for_executemany diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 8a3a1b38f..f23ba2e6e 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -463,11 +463,18 @@ class UpdateBase( ) return self - @util.non_memoized_property + def corresponding_column( + self, column: ColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: + return self.exported_columns.corresponding_column( + column, require_embedded=require_embedded + ) + + @util.ro_memoized_property def _all_selected_columns(self) -> _SelectIterable: return [c for c in _select_iterables(self._returning)] - @property + @util.ro_memoized_property def exported_columns( self, ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 309c01e40..7ea09e758 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -305,6 +305,7 @@ class ClauseElement( is_clause_element = True is_selectable = False + is_dml = False _is_column_element = False _is_table = False _is_textual = False diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6504449f1..292225ce2 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -127,6 +127,9 @@ if TYPE_CHECKING: _ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"] +_LabelConventionCallable = Callable[ + [Union["ColumnElement[Any]", "TextClause"]], Optional[str] +] class _JoinTargetProtocol(Protocol): @@ -183,7 +186,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): def selectable(self) -> ReturnsRows: return self - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: """A sequence of column expression objects that represents the "selected" columns of this :class:`_expression.ReturnsRows`. @@ -3277,7 +3280,7 @@ class SelectBase( """ raise NotImplementedError() - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: """A sequence of expressions that correspond to what is rendered in the columns clause, including :class:`_sql.TextClause` @@ -3586,7 +3589,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): ) -> None: self.element._generate_fromclause_column_proxies(subquery) - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.element._all_selected_columns @@ -4297,7 +4300,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect): for select in self.selects: select._refresh_for_new_column(column) - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.selects[0]._all_selected_columns @@ -4408,7 +4411,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def _column_naming_convention( cls, label_style: SelectLabelStyle - ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]: + ) -> _LabelConventionCallable: table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL dedupe = label_style is not LABEL_STYLE_NONE @@ -5984,7 +5987,7 @@ class Select( ) return cc.as_readonly() - @HasMemoized.memoized_attribute + @HasMemoized_ro_memoized_attribute def _all_selected_columns(self) -> _SelectIterable: meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) @@ -6537,14 +6540,7 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_readonly() - # def _generate_columns_plus_names( - # self, anon_for_dupe_key: bool - # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - # return Select._generate_columns_plus_names( - # self, anon_for_dupe_key=anon_for_dupe_key - # ) - - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.column_args |