diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 23cd778d0..8718e15ea 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -663,6 +663,12 @@ class SQLCompiler(Compiled): """ + escaped_bind_names = util.EMPTY_DICT + """Late escaping of bound parameter names that has to be converted + to the original name when looking in the parameter dictionary. + + """ + has_out_parameters = False """if True, there are bindparam() objects that have the isoutparam flag set.""" @@ -879,6 +885,8 @@ class SQLCompiler(Compiled): ): """return a dictionary of bind parameter keys and values""" + has_escaped_names = bool(self.escaped_bind_names) + if extracted_parameters: # related the bound parameters collected in the original cache key # to those collected in the incoming cache key. They will not have @@ -908,10 +916,16 @@ class SQLCompiler(Compiled): if params: pd = {} for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + if bindparam.key in params: - pd[name] = params[bindparam.key] + pd[escaped_name] = params[bindparam.key] elif name in params: - pd[name] = params[name] + pd[escaped_name] = params[name] elif _check and bindparam.required: if _group_number: @@ -936,13 +950,19 @@ class SQLCompiler(Compiled): value_param = bindparam if bindparam.callable: - pd[name] = value_param.effective_value + pd[escaped_name] = value_param.effective_value else: - pd[name] = value_param.value + pd[escaped_name] = value_param.value return pd else: pd = {} for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + if _check and bindparam.required: if _group_number: raise exc.InvalidRequestError( @@ -964,9 +984,9 @@ class SQLCompiler(Compiled): value_param = bindparam if bindparam.callable: - pd[name] = value_param.effective_value + pd[escaped_name] = value_param.effective_value else: - pd[name] = value_param.value + pd[escaped_name] = value_param.value return pd @util.memoized_instancemethod @@ -2316,6 +2336,7 @@ class SQLCompiler(Compiled): positional_names=None, post_compile=False, expanding=False, + escaped_from=None, **kw ): if self.positional: @@ -2323,6 +2344,11 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) + + if escaped_from: + if not self.escaped_bind_names: + self.escaped_bind_names = {} + self.escaped_bind_names[escaped_from] = name if post_compile: return "[POSTCOMPILE_%s]" % name else: |