diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 29 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 106 |
5 files changed, 101 insertions, 52 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index dc3da224c..3d195e691 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -362,7 +362,6 @@ class AsyncAdapt_asyncpg_cursor: if not self._inputsizes: return tuple("$%d" % idx for idx, _ in enumerate(params, 1)) else: - return tuple( "$%d::%s" % (idx, typ) if typ else "$%d" % idx for idx, typ in enumerate( diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 2e28b45ca..c1a2cf81d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2047,6 +2047,15 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): self.drop(bind=bind, checkfirst=checkfirst) +class _ColonCast(elements.Cast): + __visit_name__ = "colon_cast" + + def __init__(self, expression, type_): + self.type = type_ + self.clause = expression + self.typeclause = elements.TypeClause(type_) + + colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -2102,6 +2111,12 @@ ischema_names = { class PGCompiler(compiler.SQLCompiler): + def visit_colon_cast(self, element, **kw): + return "%s::%s" % ( + element.clause._compiler_dispatch(self, **kw), + element.typeclause._compiler_dispatch(self, **kw), + ) + def visit_array(self, element, **kw): return "ARRAY[%s]" % self.visit_clauselist(element, **kw) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index a71bdf760..4143dd041 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -473,6 +473,8 @@ import logging import re from uuid import UUID as _python_UUID +from .array import ARRAY as PGARRAY +from .base import _ColonCast from .base import _DECIMAL_TYPES from .base import _FLOAT_TYPES from .base import _INT_TYPES @@ -490,7 +492,6 @@ from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor -from ...sql import elements from ...util import collections_abc @@ -556,6 +557,11 @@ class _PGHStore(HSTORE): return super(_PGHStore, self).result_processor(dialect, coltype) +class _PGARRAY(PGARRAY): + def bind_expression(self, bindvalue): + return _ColonCast(bindvalue, self) + + class _PGJSON(JSON): def result_processor(self, dialect, coltype): return None @@ -638,25 +644,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw): - - text = super(PGCompiler_psycopg2, self).visit_bindparam( - bindparam, skip_bind_expression=skip_bind_expression, **kw - ) - # note that if the type has a bind_expression(), we will get a - # double compile here - if not skip_bind_expression and ( - bindparam.type._is_array or bindparam.type._is_type_decorator - ): - typ = bindparam.type._unwrapped_dialect_impl(self.dialect) - - if typ._is_array: - text += "::%s" % ( - elements.TypeClause(typ)._compiler_dispatch( - self, skip_bind_expression=skip_bind_expression, **kw - ), - ) - return text + pass class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): @@ -713,6 +701,7 @@ class PGDialect_psycopg2(PGDialect): sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, UUID: _PGUUID, + sqltypes.ARRAY: _PGARRAY, }, ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index eff28e340..75bca1905 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1584,7 +1584,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): from the bind parameter's ``TypeEngine`` objects. This method only called by those dialects which require it, - currently cx_oracle. + currently cx_oracle, asyncpg and pg8000. """ if self.isddl or self.is_text: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index efcfe0e51..0cd568fcc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -165,11 +165,8 @@ BIND_TEMPLATES = { "named": ":%(name)s", } -BIND_TRANSLATE = { - "pyformat": re.compile(r"[%\(\)]"), - "named": re.compile(r"[\:]"), -} -_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"} +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) OPERATORS = { # binary @@ -746,7 +743,6 @@ class SQLCompiler(Compiled): self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None) self.ctes = None @@ -1113,7 +1109,6 @@ class SQLCompiler(Compiled): N as a bound parameter. """ - if parameters is None: parameters = self.construct_params() @@ -1141,22 +1136,36 @@ class SQLCompiler(Compiled): replacement_expressions = {} to_update_sets = {} + # notes: + # *unescaped* parameter names in: + # self.bind_names, self.binds, self._bind_processors + # + # *escaped* parameter names in: + # construct_params(), replacement_expressions + for name in ( self.positiontup if self.positional else self.bind_names.values() ): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if self.escaped_bind_names + else name + ) parameter = self.binds[name] if parameter in self.literal_execute_params: - if name not in replacement_expressions: - value = parameters.pop(name) + if escaped_name not in replacement_expressions: + value = parameters.pop(escaped_name) - replacement_expressions[name] = self.render_literal_bindparam( + replacement_expressions[ + escaped_name + ] = self.render_literal_bindparam( parameter, render_literal_value=value ) continue if parameter in self.post_compile_params: - if name in replacement_expressions: - to_update = to_update_sets[name] + if escaped_name in replacement_expressions: + to_update = to_update_sets[escaped_name] else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1164,13 +1173,15 @@ class SQLCompiler(Compiled): # process it. the single name is being replaced with # individual numbered parameters for each value in the # param. - values = parameters.pop(name) + values = parameters.pop(escaped_name) leep = self._literal_execute_expanding_parameter - to_update, replacement_expr = leep(name, parameter, values) + to_update, replacement_expr = leep( + escaped_name, parameter, values + ) - to_update_sets[name] = to_update - replacement_expressions[name] = replacement_expr + to_update_sets[escaped_name] = to_update + replacement_expressions[escaped_name] = replacement_expr if not parameter.literal_execute: parameters.update(to_update) @@ -1200,10 +1211,24 @@ class SQLCompiler(Compiled): positiontup.append(name) def process_expanding(m): - return replacement_expressions[m.group(1)] + key = m.group(1) + expr = replacement_expressions[key] + + # if POSTCOMPILE included a bind_expression, render that + # around each element + if m.group(2): + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + expr = ", ".join( + "%s%s%s" % (be_left, exp, be_right) + for exp in expr.split(", ") + ) + return expr statement = re.sub( - r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string + r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", + process_expanding, + self.string, ) expanded_state = ExpandedState( @@ -1963,8 +1988,10 @@ class SQLCompiler(Compiled): self, parameter, values ): + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + if not values: - if parameter.type._is_tuple_type: + if typ_dialect_impl._is_tuple_type: replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + self.visit_empty_set_op_expr( @@ -1977,7 +2004,7 @@ class SQLCompiler(Compiled): ) elif isinstance(values[0], (tuple, list)): - assert parameter.type._is_tuple_type + assert typ_dialect_impl._is_tuple_type replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + ", ".join( @@ -1993,7 +2020,7 @@ class SQLCompiler(Compiled): for i, tuple_element in enumerate(values) ) else: - assert not parameter.type._is_tuple_type + assert not typ_dialect_impl._is_tuple_type replacement_expression = ", ".join( self.render_literal_value(value, parameter.type) for value in values @@ -2008,9 +2035,11 @@ class SQLCompiler(Compiled): parameter, values ) + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + if not values: to_update = [] - if parameter.type._is_tuple_type: + if typ_dialect_impl._is_tuple_type: replacement_expression = self.visit_empty_set_op_expr( parameter.type.types, parameter.expand_op @@ -2020,7 +2049,10 @@ class SQLCompiler(Compiled): [parameter.type], parameter.expand_op ) - elif isinstance(values[0], (tuple, list)): + elif ( + isinstance(values[0], (tuple, list)) + and not typ_dialect_impl._is_array + ): to_update = [ ("%s_%s_%s" % (name, i, j), value) for i, tuple_element in enumerate(values, 1) @@ -2299,14 +2331,27 @@ class SQLCompiler(Compiled): impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) - return self.process( + wrapped = self.process( bind_expression, skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, literal_execute=literal_execute, + render_postcompile=render_postcompile, **kwargs ) + if bindparam.expanding: + # for postcompile w/ expanding, move the "wrapped" part + # of this into the inside + m = re.match( + r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped + ) + wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( + m.group(2), + m.group(1), + m.group(3), + ) + return wrapped if not literal_binds: literal_execute = ( @@ -2489,12 +2534,13 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) - elif not post_compile and not escaped_from: - tr_reg = self._bind_translate - if tr_reg.search(name): - # i'd rather use translate() here but I can't get it to work - # in all cases under Python 2, not worth it right now - new_name = tr_reg.sub( + elif not escaped_from: + + if _BIND_TRANSLATE_RE.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = _BIND_TRANSLATE_RE.sub( lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], name, ) |