diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-12-08 11:00:31 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-12-08 17:37:25 -0500 |
commit | 8a152ec391118a05ac54974d0f013cf0e99c7832 (patch) | |
tree | 8be0d84b96e38ecb4f60f178507f8805e0b30309 /lib/sqlalchemy/sql/compiler.py | |
parent | caccf151f2e1b357fa2a5d37135580ce9931eec2 (diff) | |
download | sqlalchemy-8a152ec391118a05ac54974d0f013cf0e99c7832.tar.gz |
fix construct_params() for render_postcompile; add new API
The :meth:`.SQLCompiler.construct_params` method, as well as the
:attr:`.SQLCompiler.params` accessor, will now return the
exact parameters that correspond to a compiled statement that used
the ``render_postcompile`` parameter to compile. Previously,
the method returned a parameter structure that by itself didn't correspond
to either the original parameters or the expanded ones.
Passing a new dictionary of parameters to
:meth:`.SQLCompiler.construct_params` for a :class:`.SQLCompiler` that was
constructed with ``render_postcompile`` is now disallowed; instead, to make
a new SQL string and parameter set for an alternate set of parameters, a
new method :meth:`.SQLCompiler.construct_expanded_state` is added which
will produce a new expanded form for the given parameter set, using the
:class:`.ExpandedState` container which includes a new SQL statement
and new parameter dictionary, as well as a positional parameter tuple.
Fixes: #6114
Change-Id: I9874905bb90f86799b82b244d57369558b18fd93
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 177 |
1 files changed, 141 insertions, 36 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d7358ad3b..7aa89869e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -404,13 +404,53 @@ class ExpandedState(NamedTuple): will be rendered into the SQL statement at execution time, rather than being passed as separate parameters to the driver. + To create an :class:`.ExpandedState` instance, use the + :meth:`.SQLCompiler.construct_expanded_state` method on any + :class:`.SQLCompiler` instance. + """ statement: str - additional_parameters: _CoreSingleExecuteParams + """String SQL statement with parameters fully expanded""" + + parameters: _CoreSingleExecuteParams + """Parameter dictionary with parameters fully expanded. + + For a statement that uses named parameters, this dictionary will map + exactly to the names in the statement. For a statement that uses + positional parameters, the :attr:`.ExpandedState.positional_parameters` + will yield a tuple with the positional parameter set. + + """ + processors: Mapping[str, _BindProcessorType[Any]] + """mapping of bound value processors""" + positiontup: Optional[Sequence[str]] + """Sequence of string names indicating the order of positional + parameters""" + parameter_expansion: Mapping[str, List[str]] + """Mapping representing the intermediary link from original parameter + name to list of "expanded" parameter names, for those parameters that + were expanded.""" + + @property + def positional_parameters(self) -> Tuple[Any, ...]: + """Tuple of positional parameters, for statements that were compiled + using a positional paramstyle. + + """ + if self.positiontup is None: + raise exc.InvalidRequestError( + "statement does not use a positional paramstyle" + ) + return tuple(self.parameters[key] for key in self.positiontup) + + @property + def additional_parameters(self) -> _CoreSingleExecuteParams: + """synonym for :attr:`.ExpandedState.parameters`.""" + return self.parameters class _InsertManyValues(NamedTuple): @@ -956,8 +996,30 @@ class SQLCompiler(Compiled): """ whether to render out POSTCOMPILE params during the compile phase. + This attribute is used only for end-user invocation of stmt.compile(); + it's never used for actual statement execution, where instead the + dialect internals access and render the internal postcompile structure + directly. + + """ + + _post_compile_expanded_state: Optional[ExpandedState] = None + """When render_postcompile is used, the ``ExpandedState`` used to create + the "expanded" SQL is assigned here, and then used by the ``.params`` + accessor and ``.construct_params()`` methods for their return values. + + .. versionadded:: 2.0.0b5 + """ + _pre_expanded_string: Optional[str] = None + """Stores the original string SQL before 'post_compile' is applied, + for cases where 'post_compile' were used. + + """ + + _pre_expanded_positiontup: Optional[List[str]] = None + _insertmanyvalues: Optional[_InsertManyValues] = None _insert_crud_params: Optional[crud._CrudParamSequence] = None @@ -1164,7 +1226,14 @@ class SQLCompiler(Compiled): self._process_positional() if self._render_postcompile: - self._process_parameters_for_postcompile(_populate_self=True) + parameters = self.construct_params( + escape_names=False, + _no_postcompile=True, + ) + + self._process_parameters_for_postcompile( + parameters, _populate_self=True + ) @property def insert_single_values_expr(self) -> Optional[str]: @@ -1481,6 +1550,29 @@ class SQLCompiler(Compiled): def sql_compiler(self): return self + def construct_expanded_state( + self, + params: Optional[_CoreSingleExecuteParams] = None, + escape_names: bool = True, + ) -> ExpandedState: + """Return a new :class:`.ExpandedState` for a given parameter set. + + For queries that use "expanding" or other late-rendered parameters, + this method will provide for both the finalized SQL string as well + as the parameters that would be used for a particular parameter set. + + .. versionadded:: 2.0.0b5 + + """ + parameters = self.construct_params( + params, + escape_names=escape_names, + _no_postcompile=True, + ) + return self._process_parameters_for_postcompile( + parameters, + ) + def construct_params( self, params: Optional[_CoreSingleExecuteParams] = None, @@ -1488,12 +1580,26 @@ class SQLCompiler(Compiled): escape_names: bool = True, _group_number: Optional[int] = None, _check: bool = True, + _no_postcompile: bool = False, ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" + if self._render_postcompile and not _no_postcompile: + assert self._post_compile_expanded_state is not None + if not params: + return dict(self._post_compile_expanded_state.parameters) + else: + raise exc.InvalidRequestError( + "can't construct new parameters when render_postcompile " + "is used; the statement is hard-linked to the original " + "parameters. Use construct_expanded_state to generate a " + "new statement and parameters." + ) + has_escaped_names = escape_names and bool(self.escaped_bind_names) if extracted_parameters: + # related the bound parameters collected in the original cache key # to those collected in the incoming cache key. They will not have # matching names but they will line up positionally in the same @@ -1520,6 +1626,7 @@ class SQLCompiler(Compiled): resolved_extracted = None if params: + pd = {} for bindparam, name in self.bind_names.items(): escaped_name = ( @@ -1593,6 +1700,7 @@ class SQLCompiler(Compiled): pd[escaped_name] = value_param.effective_value else: pd[escaped_name] = value_param.value + return pd @util.memoized_instancemethod @@ -1649,7 +1757,7 @@ class SQLCompiler(Compiled): def _process_parameters_for_postcompile( self, - parameters: Optional[_MutableCoreSingleExecuteParams] = None, + parameters: _MutableCoreSingleExecuteParams, _populate_self: bool = False, ) -> ExpandedState: """handle special post compile parameters. @@ -1665,16 +1773,22 @@ class SQLCompiler(Compiled): """ - if parameters is None: - parameters = self.construct_params(escape_names=False) - expanded_parameters = {} - positiontup: Optional[List[str]] + new_positiontup: Optional[List[str]] + + pre_expanded_string = self._pre_expanded_string + if pre_expanded_string is None: + pre_expanded_string = self.string if self.positional: - positiontup = [] + new_positiontup = [] + + pre_expanded_positiontup = self._pre_expanded_positiontup + if pre_expanded_positiontup is None: + pre_expanded_positiontup = self.positiontup + else: - positiontup = None + new_positiontup = pre_expanded_positiontup = None processors = self._bind_processors single_processors = cast( @@ -1698,8 +1812,8 @@ class SQLCompiler(Compiled): numeric_positiontup: Optional[List[str]] = None - if self.positional and self.positiontup is not None: - names: Iterable[str] = self.positiontup + if self.positional and pre_expanded_positiontup is not None: + names: Iterable[str] = pre_expanded_positiontup if self._numeric_binds: numeric_positiontup = [] else: @@ -1772,16 +1886,16 @@ class SQLCompiler(Compiled): numeric_positiontup.extend( name for name, _ in to_update ) - elif positiontup is not None: + elif new_positiontup is not None: # to_update has escaped names, but that's ok since # these are new names, that aren't in the # escaped_bind_names dict. - positiontup.extend(name for name, _ in to_update) + new_positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ expand_key for expand_key, _ in to_update ] - elif positiontup is not None: - positiontup.append(name) + elif new_positiontup is not None: + new_positiontup.append(name) def process_expanding(m): key = m.group(1) @@ -1799,11 +1913,11 @@ class SQLCompiler(Compiled): return expr statement = re.sub( - self._post_compile_pattern, process_expanding, self.string + self._post_compile_pattern, process_expanding, pre_expanded_string ) if numeric_positiontup is not None: - assert positiontup is not None + assert new_positiontup is not None param_pos = { key: f"{self._numeric_binds_identifier_char}{num}" for num, key in enumerate( @@ -1814,13 +1928,13 @@ class SQLCompiler(Compiled): statement = self._pyformat_pattern.sub( lambda m: param_pos[m.group(1)], statement ) - positiontup.extend(numeric_positiontup) + new_positiontup.extend(numeric_positiontup) expanded_state = ExpandedState( statement, parameters, new_processors, - positiontup, + new_positiontup, expanded_parameters, ) @@ -1828,24 +1942,15 @@ class SQLCompiler(Compiled): # this is for the "render_postcompile" flag, which is not # otherwise used internally and is for end-user debugging and # special use cases. + self._pre_expanded_string = pre_expanded_string + self._pre_expanded_positiontup = pre_expanded_positiontup self.string = expanded_state.statement - self._bind_processors.update(expanded_state.processors) - self.positiontup = list(expanded_state.positiontup or ()) - self.post_compile_params = frozenset() - for key in expanded_state.parameter_expansion: - bind = self.binds.pop(key) - - if TYPE_CHECKING: - assert bind.value is not None - - self.bind_names.pop(bind) - for value, expanded_key in zip( - bind.value, expanded_state.parameter_expansion[key] - ): - self.binds[expanded_key] = new_param = bind._with_value( - value - ) - self.bind_names[new_param] = expanded_key + self.positiontup = ( + list(expanded_state.positiontup or ()) + if self.positional + else None + ) + self._post_compile_expanded_state = expanded_state return expanded_state |