diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 106 |
1 files changed, 76 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index efcfe0e51..0cd568fcc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -165,11 +165,8 @@ BIND_TEMPLATES = { "named": ":%(name)s", } -BIND_TRANSLATE = { - "pyformat": re.compile(r"[%\(\)]"), - "named": re.compile(r"[\:]"), -} -_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"} +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) OPERATORS = { # binary @@ -746,7 +743,6 @@ class SQLCompiler(Compiled): self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None) self.ctes = None @@ -1113,7 +1109,6 @@ class SQLCompiler(Compiled): N as a bound parameter. """ - if parameters is None: parameters = self.construct_params() @@ -1141,22 +1136,36 @@ class SQLCompiler(Compiled): replacement_expressions = {} to_update_sets = {} + # notes: + # *unescaped* parameter names in: + # self.bind_names, self.binds, self._bind_processors + # + # *escaped* parameter names in: + # construct_params(), replacement_expressions + for name in ( self.positiontup if self.positional else self.bind_names.values() ): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if self.escaped_bind_names + else name + ) parameter = self.binds[name] if parameter in self.literal_execute_params: - if name not in replacement_expressions: - value = parameters.pop(name) + if escaped_name not in replacement_expressions: + value = parameters.pop(escaped_name) - replacement_expressions[name] = self.render_literal_bindparam( + replacement_expressions[ + escaped_name + ] = self.render_literal_bindparam( parameter, render_literal_value=value ) continue if parameter in self.post_compile_params: - if name in replacement_expressions: - to_update = to_update_sets[name] + if escaped_name in replacement_expressions: + to_update = to_update_sets[escaped_name] else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1164,13 +1173,15 @@ class SQLCompiler(Compiled): # process it. the single name is being replaced with # individual numbered parameters for each value in the # param. - values = parameters.pop(name) + values = parameters.pop(escaped_name) leep = self._literal_execute_expanding_parameter - to_update, replacement_expr = leep(name, parameter, values) + to_update, replacement_expr = leep( + escaped_name, parameter, values + ) - to_update_sets[name] = to_update - replacement_expressions[name] = replacement_expr + to_update_sets[escaped_name] = to_update + replacement_expressions[escaped_name] = replacement_expr if not parameter.literal_execute: parameters.update(to_update) @@ -1200,10 +1211,24 @@ class SQLCompiler(Compiled): positiontup.append(name) def process_expanding(m): - return replacement_expressions[m.group(1)] + key = m.group(1) + expr = replacement_expressions[key] + + # if POSTCOMPILE included a bind_expression, render that + # around each element + if m.group(2): + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + expr = ", ".join( + "%s%s%s" % (be_left, exp, be_right) + for exp in expr.split(", ") + ) + return expr statement = re.sub( - r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string + r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", + process_expanding, + self.string, ) expanded_state = ExpandedState( @@ -1963,8 +1988,10 @@ class SQLCompiler(Compiled): self, parameter, values ): + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + if not values: - if parameter.type._is_tuple_type: + if typ_dialect_impl._is_tuple_type: replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + self.visit_empty_set_op_expr( @@ -1977,7 +2004,7 @@ class SQLCompiler(Compiled): ) elif isinstance(values[0], (tuple, list)): - assert parameter.type._is_tuple_type + assert typ_dialect_impl._is_tuple_type replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + ", ".join( @@ -1993,7 +2020,7 @@ class SQLCompiler(Compiled): for i, tuple_element in enumerate(values) ) else: - assert not parameter.type._is_tuple_type + assert not typ_dialect_impl._is_tuple_type replacement_expression = ", ".join( self.render_literal_value(value, parameter.type) for value in values @@ -2008,9 +2035,11 @@ class SQLCompiler(Compiled): parameter, values ) + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + if not values: to_update = [] - if parameter.type._is_tuple_type: + if typ_dialect_impl._is_tuple_type: replacement_expression = self.visit_empty_set_op_expr( parameter.type.types, parameter.expand_op @@ -2020,7 +2049,10 @@ class SQLCompiler(Compiled): [parameter.type], parameter.expand_op ) - elif isinstance(values[0], (tuple, list)): + elif ( + isinstance(values[0], (tuple, list)) + and not typ_dialect_impl._is_array + ): to_update = [ ("%s_%s_%s" % (name, i, j), value) for i, tuple_element in enumerate(values, 1) @@ -2299,14 +2331,27 @@ class SQLCompiler(Compiled): impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) - return self.process( + wrapped = self.process( bind_expression, skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, literal_execute=literal_execute, + render_postcompile=render_postcompile, **kwargs ) + if bindparam.expanding: + # for postcompile w/ expanding, move the "wrapped" part + # of this into the inside + m = re.match( + r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped + ) + wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( + m.group(2), + m.group(1), + m.group(3), + ) + return wrapped if not literal_binds: literal_execute = ( @@ -2489,12 +2534,13 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) - elif not post_compile and not escaped_from: - tr_reg = self._bind_translate - if tr_reg.search(name): - # i'd rather use translate() here but I can't get it to work - # in all cases under Python 2, not worth it right now - new_name = tr_reg.sub( + elif not escaped_from: + + if _BIND_TRANSLATE_RE.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = _BIND_TRANSLATE_RE.sub( lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], name, ) |