diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 70 |
1 files changed, 60 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 87ae5232e..799fca2f5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -470,7 +470,7 @@ class Compiled(object): return self.string or "" - def construct_params(self, params=None): + def construct_params(self, params=None, extracted_parameters=None): """Return the bind params for this compiled object. :param params: a dict of string/object pairs whose values will @@ -664,6 +664,7 @@ class SQLCompiler(Compiled): self, dialect, statement, + cache_key=None, column_keys=None, inline=False, linting=NO_LINTING, @@ -687,6 +688,8 @@ class SQLCompiler(Compiled): """ self.column_keys = column_keys + self.cache_key = cache_key + # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) self.inline = inline or getattr(statement, "_inline", False) @@ -818,9 +821,38 @@ class SQLCompiler(Compiled): def sql_compiler(self): return self - def construct_params(self, params=None, _group_number=None, _check=True): + def construct_params( + self, + params=None, + _group_number=None, + _check=True, + extracted_parameters=None, + ): """return a dictionary of bind parameter keys and values""" + if extracted_parameters: + # related the bound parameters collected in the original cache key + # to those collected in the incoming cache key. They will not have + # matching names but they will line up positionally in the same + # way. The parameters present in self.bind_names may be clones of + # these original cache key params in the case of DML but the .key + # will be guaranteed to match. + try: + orig_extracted = self.cache_key[1] + except TypeError as err: + util.raise_( + exc.CompileError( + "This compiled object has no original cache key; " + "can't pass extracted_parameters to construct_params" + ), + replace_context=err, + ) + resolved_extracted = dict( + zip([b.key for b in orig_extracted], extracted_parameters) + ) + else: + resolved_extracted = None + if params: pd = {} for bindparam in self.bind_names: @@ -844,11 +876,18 @@ class SQLCompiler(Compiled): % bindparam.key, code="cd3x", ) - - elif bindparam.callable: - pd[name] = bindparam.effective_value else: - pd[name] = bindparam.value + if resolved_extracted: + value_param = resolved_extracted.get( + bindparam.key, bindparam + ) + else: + value_param = bindparam + + if bindparam.callable: + pd[name] = value_param.effective_value + else: + pd[name] = value_param.value return pd else: pd = {} @@ -868,10 +907,19 @@ class SQLCompiler(Compiled): code="cd3x", ) + if resolved_extracted: + value_param = resolved_extracted.get( + bindparam.key, bindparam + ) + else: + value_param = bindparam + if bindparam.callable: - pd[self.bind_names[bindparam]] = bindparam.effective_value + pd[ + self.bind_names[bindparam] + ] = value_param.effective_value else: - pd[self.bind_names[bindparam]] = bindparam.value + pd[self.bind_names[bindparam]] = value_param.value return pd @property @@ -2144,7 +2192,9 @@ class SQLCompiler(Compiled): assert False recur_cols = [ c - for c in util.unique_list(col_source.inner_columns) + for c in util.unique_list( + col_source._exported_columns_iterator() + ) if c is not None ] @@ -3375,7 +3425,7 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - def construct_params(self, params=None): + def construct_params(self, params=None, extracted_parameters=None): return None def visit_ddl(self, ddl, **kwargs): |