summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-10-13 15:52:12 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-10-15 09:28:49 -0400
commit639cf972f15c8fbf77980b04fff8e5dbc82af7b6 (patch)
tree162aafe94f82df3e34675ba26b5c88ce4f1b2044 /lib/sqlalchemy/sql/compiler.py
parentfec2b6560c14bb28ee7fc9d21028844acf700b04 (diff)
downloadsqlalchemy-639cf972f15c8fbf77980b04fff8e5dbc82af7b6.tar.gz
support bind expressions w/ expanding IN; apply to psycopg2
Fixed issue where "expanding IN" would fail to function correctly with datatypes that use the :meth:`_types.TypeEngine.bind_expression` method, where the method would need to be applied to each element of the IN expression rather than the overall IN expression itself. Fixed issue where IN expressions against a series of array elements, as can be done with PostgreSQL, would fail to function correctly due to multiple issues within the "expanding IN" feature of SQLAlchemy Core that was standardized in version 1.4. The psycopg2 dialect now makes use of the :meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY` to portably apply the correct casts to elements. The asyncpg dialect was not affected by this issue as it applies bind-level casts at the driver level rather than at the compiler level. as part of this commit the "bind translate" feature has been simplified and also applies to the names in the POSTCOMPILE tag to accommodate for brackets. Fixes: #7177 Change-Id: I08c703adb0a9bd6f5aeee5de3ff6f03cccdccdc5
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py106
1 files changed, 76 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index efcfe0e51..0cd568fcc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -165,11 +165,8 @@ BIND_TEMPLATES = {
"named": ":%(name)s",
}
-BIND_TRANSLATE = {
- "pyformat": re.compile(r"[%\(\)]"),
- "named": re.compile(r"[\:]"),
-}
-_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"}
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
OPERATORS = {
# binary
@@ -746,7 +743,6 @@ class SQLCompiler(Compiled):
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
- self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None)
self.ctes = None
@@ -1113,7 +1109,6 @@ class SQLCompiler(Compiled):
N as a bound parameter.
"""
-
if parameters is None:
parameters = self.construct_params()
@@ -1141,22 +1136,36 @@ class SQLCompiler(Compiled):
replacement_expressions = {}
to_update_sets = {}
+ # notes:
+ # *unescaped* parameter names in:
+ # self.bind_names, self.binds, self._bind_processors
+ #
+ # *escaped* parameter names in:
+ # construct_params(), replacement_expressions
+
for name in (
self.positiontup if self.positional else self.bind_names.values()
):
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if self.escaped_bind_names
+ else name
+ )
parameter = self.binds[name]
if parameter in self.literal_execute_params:
- if name not in replacement_expressions:
- value = parameters.pop(name)
+ if escaped_name not in replacement_expressions:
+ value = parameters.pop(escaped_name)
- replacement_expressions[name] = self.render_literal_bindparam(
+ replacement_expressions[
+ escaped_name
+ ] = self.render_literal_bindparam(
parameter, render_literal_value=value
)
continue
if parameter in self.post_compile_params:
- if name in replacement_expressions:
- to_update = to_update_sets[name]
+ if escaped_name in replacement_expressions:
+ to_update = to_update_sets[escaped_name]
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
@@ -1164,13 +1173,15 @@ class SQLCompiler(Compiled):
# process it. the single name is being replaced with
# individual numbered parameters for each value in the
# param.
- values = parameters.pop(name)
+ values = parameters.pop(escaped_name)
leep = self._literal_execute_expanding_parameter
- to_update, replacement_expr = leep(name, parameter, values)
+ to_update, replacement_expr = leep(
+ escaped_name, parameter, values
+ )
- to_update_sets[name] = to_update
- replacement_expressions[name] = replacement_expr
+ to_update_sets[escaped_name] = to_update
+ replacement_expressions[escaped_name] = replacement_expr
if not parameter.literal_execute:
parameters.update(to_update)
@@ -1200,10 +1211,24 @@ class SQLCompiler(Compiled):
positiontup.append(name)
def process_expanding(m):
- return replacement_expressions[m.group(1)]
+ key = m.group(1)
+ expr = replacement_expressions[key]
+
+ # if POSTCOMPILE included a bind_expression, render that
+ # around each element
+ if m.group(2):
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ expr = ", ".join(
+ "%s%s%s" % (be_left, exp, be_right)
+ for exp in expr.split(", ")
+ )
+ return expr
statement = re.sub(
- r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string
+ r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+ process_expanding,
+ self.string,
)
expanded_state = ExpandedState(
@@ -1963,8 +1988,10 @@ class SQLCompiler(Compiled):
self, parameter, values
):
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
if not values:
- if parameter.type._is_tuple_type:
+ if typ_dialect_impl._is_tuple_type:
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + self.visit_empty_set_op_expr(
@@ -1977,7 +2004,7 @@ class SQLCompiler(Compiled):
)
elif isinstance(values[0], (tuple, list)):
- assert parameter.type._is_tuple_type
+ assert typ_dialect_impl._is_tuple_type
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + ", ".join(
@@ -1993,7 +2020,7 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values)
)
else:
- assert not parameter.type._is_tuple_type
+ assert not typ_dialect_impl._is_tuple_type
replacement_expression = ", ".join(
self.render_literal_value(value, parameter.type)
for value in values
@@ -2008,9 +2035,11 @@ class SQLCompiler(Compiled):
parameter, values
)
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
if not values:
to_update = []
- if parameter.type._is_tuple_type:
+ if typ_dialect_impl._is_tuple_type:
replacement_expression = self.visit_empty_set_op_expr(
parameter.type.types, parameter.expand_op
@@ -2020,7 +2049,10 @@ class SQLCompiler(Compiled):
[parameter.type], parameter.expand_op
)
- elif isinstance(values[0], (tuple, list)):
+ elif (
+ isinstance(values[0], (tuple, list))
+ and not typ_dialect_impl._is_array
+ ):
to_update = [
("%s_%s_%s" % (name, i, j), value)
for i, tuple_element in enumerate(values, 1)
@@ -2299,14 +2331,27 @@ class SQLCompiler(Compiled):
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
- return self.process(
+ wrapped = self.process(
bind_expression,
skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
literal_execute=literal_execute,
+ render_postcompile=render_postcompile,
**kwargs
)
+ if bindparam.expanding:
+ # for postcompile w/ expanding, move the "wrapped" part
+ # of this into the inside
+ m = re.match(
+ r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
+ )
+ wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
+ m.group(2),
+ m.group(1),
+ m.group(3),
+ )
+ return wrapped
if not literal_binds:
literal_execute = (
@@ -2489,12 +2534,13 @@ class SQLCompiler(Compiled):
positional_names.append(name)
else:
self.positiontup.append(name)
- elif not post_compile and not escaped_from:
- tr_reg = self._bind_translate
- if tr_reg.search(name):
- # i'd rather use translate() here but I can't get it to work
- # in all cases under Python 2, not worth it right now
- new_name = tr_reg.sub(
+ elif not escaped_from:
+
+ if _BIND_TRANSLATE_RE.search(name):
+ # not quite the translate use case as we want to
+ # also get a quick boolean if we even found
+ # unusual characters in the name
+ new_name = _BIND_TRANSLATE_RE.sub(
lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
name,
)