diff options
-rw-r--r-- | doc/build/changelog/unreleased_14/7177.rst | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 29 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 106 | ||||
-rw-r--r-- | test/dialect/postgresql/test_types.py | 223 | ||||
-rw-r--r-- | test/sql/test_external_traversal.py | 10 | ||||
-rw-r--r-- | test/sql/test_type_expressions.py | 23 |
9 files changed, 338 insertions, 93 deletions
diff --git a/doc/build/changelog/unreleased_14/7177.rst b/doc/build/changelog/unreleased_14/7177.rst new file mode 100644 index 000000000..7766c838e --- /dev/null +++ b/doc/build/changelog/unreleased_14/7177.rst @@ -0,0 +1,22 @@ +.. change:: + :tags: sql, bug, regression + :tickets: 7177 + + Fixed issue where "expanding IN" would fail to function correctly with + datatypes that use the :meth:`_types.TypeEngine.bind_expression` method, + where the method would need to be applied to each element of the + IN expression rather than the overall IN expression itself. + +.. change:: + :tags: postgresql, bug, regression + :tickets: 7177 + + Fixed issue where IN expressions against a series of array elements, as can + be done with PostgreSQL, would fail to function correctly due to multiple + issues within the "expanding IN" feature of SQLAlchemy Core that was + standardized in version 1.4. The psycopg2 dialect now makes use of the + :meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY` + to portably apply the correct casts to elements. The asyncpg dialect was + not affected by this issue as it applies bind-level casts at the driver + level rather than at the compiler level. + diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index dc3da224c..3d195e691 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -362,7 +362,6 @@ class AsyncAdapt_asyncpg_cursor: if not self._inputsizes: return tuple("$%d" % idx for idx, _ in enumerate(params, 1)) else: - return tuple( "$%d::%s" % (idx, typ) if typ else "$%d" % idx for idx, typ in enumerate( diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 2e28b45ca..c1a2cf81d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2047,6 +2047,15 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): self.drop(bind=bind, checkfirst=checkfirst) +class _ColonCast(elements.Cast): + __visit_name__ = "colon_cast" + + def __init__(self, expression, type_): + self.type = type_ + self.clause = expression + self.typeclause = elements.TypeClause(type_) + + colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -2102,6 +2111,12 @@ ischema_names = { class PGCompiler(compiler.SQLCompiler): + def visit_colon_cast(self, element, **kw): + return "%s::%s" % ( + element.clause._compiler_dispatch(self, **kw), + element.typeclause._compiler_dispatch(self, **kw), + ) + def visit_array(self, element, **kw): return "ARRAY[%s]" % self.visit_clauselist(element, **kw) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index a71bdf760..4143dd041 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -473,6 +473,8 @@ import logging import re from uuid import UUID as _python_UUID +from .array import ARRAY as PGARRAY +from .base import _ColonCast from .base import _DECIMAL_TYPES from .base import _FLOAT_TYPES from .base import _INT_TYPES @@ -490,7 +492,6 @@ from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor -from ...sql import elements from ...util import collections_abc @@ -556,6 +557,11 @@ class _PGHStore(HSTORE): return super(_PGHStore, self).result_processor(dialect, coltype) +class _PGARRAY(PGARRAY): + def bind_expression(self, bindvalue): + return _ColonCast(bindvalue, self) + + class _PGJSON(JSON): def result_processor(self, dialect, coltype): return None @@ -638,25 +644,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw): - - text = super(PGCompiler_psycopg2, self).visit_bindparam( - bindparam, skip_bind_expression=skip_bind_expression, **kw - ) - # note that if the type has a bind_expression(), we will get a - # double compile here - if not skip_bind_expression and ( - bindparam.type._is_array or bindparam.type._is_type_decorator - ): - typ = bindparam.type._unwrapped_dialect_impl(self.dialect) - - if typ._is_array: - text += "::%s" % ( - elements.TypeClause(typ)._compiler_dispatch( - self, skip_bind_expression=skip_bind_expression, **kw - ), - ) - return text + pass class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): @@ -713,6 +701,7 @@ class PGDialect_psycopg2(PGDialect): sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, UUID: _PGUUID, + sqltypes.ARRAY: _PGARRAY, }, ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index eff28e340..75bca1905 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1584,7 +1584,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): from the bind parameter's ``TypeEngine`` objects. This method only called by those dialects which require it, - currently cx_oracle. + currently cx_oracle, asyncpg and pg8000. """ if self.isddl or self.is_text: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index efcfe0e51..0cd568fcc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -165,11 +165,8 @@ BIND_TEMPLATES = { "named": ":%(name)s", } -BIND_TRANSLATE = { - "pyformat": re.compile(r"[%\(\)]"), - "named": re.compile(r"[\:]"), -} -_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"} +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) OPERATORS = { # binary @@ -746,7 +743,6 @@ class SQLCompiler(Compiled): self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None) self.ctes = None @@ -1113,7 +1109,6 @@ class SQLCompiler(Compiled): N as a bound parameter. """ - if parameters is None: parameters = self.construct_params() @@ -1141,22 +1136,36 @@ class SQLCompiler(Compiled): replacement_expressions = {} to_update_sets = {} + # notes: + # *unescaped* parameter names in: + # self.bind_names, self.binds, self._bind_processors + # + # *escaped* parameter names in: + # construct_params(), replacement_expressions + for name in ( self.positiontup if self.positional else self.bind_names.values() ): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if self.escaped_bind_names + else name + ) parameter = self.binds[name] if parameter in self.literal_execute_params: - if name not in replacement_expressions: - value = parameters.pop(name) + if escaped_name not in replacement_expressions: + value = parameters.pop(escaped_name) - replacement_expressions[name] = self.render_literal_bindparam( + replacement_expressions[ + escaped_name + ] = self.render_literal_bindparam( parameter, render_literal_value=value ) continue if parameter in self.post_compile_params: - if name in replacement_expressions: - to_update = to_update_sets[name] + if escaped_name in replacement_expressions: + to_update = to_update_sets[escaped_name] else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1164,13 +1173,15 @@ class SQLCompiler(Compiled): # process it. the single name is being replaced with # individual numbered parameters for each value in the # param. - values = parameters.pop(name) + values = parameters.pop(escaped_name) leep = self._literal_execute_expanding_parameter - to_update, replacement_expr = leep(name, parameter, values) + to_update, replacement_expr = leep( + escaped_name, parameter, values + ) - to_update_sets[name] = to_update - replacement_expressions[name] = replacement_expr + to_update_sets[escaped_name] = to_update + replacement_expressions[escaped_name] = replacement_expr if not parameter.literal_execute: parameters.update(to_update) @@ -1200,10 +1211,24 @@ class SQLCompiler(Compiled): positiontup.append(name) def process_expanding(m): - return replacement_expressions[m.group(1)] + key = m.group(1) + expr = replacement_expressions[key] + + # if POSTCOMPILE included a bind_expression, render that + # around each element + if m.group(2): + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + expr = ", ".join( + "%s%s%s" % (be_left, exp, be_right) + for exp in expr.split(", ") + ) + return expr statement = re.sub( - r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string + r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", + process_expanding, + self.string, ) expanded_state = ExpandedState( @@ -1963,8 +1988,10 @@ class SQLCompiler(Compiled): self, parameter, values ): + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + if not values: - if parameter.type._is_tuple_type: + if typ_dialect_impl._is_tuple_type: replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + self.visit_empty_set_op_expr( @@ -1977,7 +2004,7 @@ class SQLCompiler(Compiled): ) elif isinstance(values[0], (tuple, list)): - assert parameter.type._is_tuple_type + assert typ_dialect_impl._is_tuple_type replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + ", ".join( @@ -1993,7 +2020,7 @@ class SQLCompiler(Compiled): for i, tuple_element in enumerate(values) ) else: - assert not parameter.type._is_tuple_type + assert not typ_dialect_impl._is_tuple_type replacement_expression = ", ".join( self.render_literal_value(value, parameter.type) for value in values @@ -2008,9 +2035,11 @@ class SQLCompiler(Compiled): parameter, values ) + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + if not values: to_update = [] - if parameter.type._is_tuple_type: + if typ_dialect_impl._is_tuple_type: replacement_expression = self.visit_empty_set_op_expr( parameter.type.types, parameter.expand_op @@ -2020,7 +2049,10 @@ class SQLCompiler(Compiled): [parameter.type], parameter.expand_op ) - elif isinstance(values[0], (tuple, list)): + elif ( + isinstance(values[0], (tuple, list)) + and not typ_dialect_impl._is_array + ): to_update = [ ("%s_%s_%s" % (name, i, j), value) for i, tuple_element in enumerate(values, 1) @@ -2299,14 +2331,27 @@ class SQLCompiler(Compiled): impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) - return self.process( + wrapped = self.process( bind_expression, skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, literal_execute=literal_execute, + render_postcompile=render_postcompile, **kwargs ) + if bindparam.expanding: + # for postcompile w/ expanding, move the "wrapped" part + # of this into the inside + m = re.match( + r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped + ) + wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( + m.group(2), + m.group(1), + m.group(3), + ) + return wrapped if not literal_binds: literal_execute = ( @@ -2489,12 +2534,13 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) - elif not post_compile and not escaped_from: - tr_reg = self._bind_translate - if tr_reg.search(name): - # i'd rather use translate() here but I can't get it to work - # in all cases under Python 2, not worth it right now - new_name = tr_reg.sub( + elif not escaped_from: + + if _BIND_TRANSLATE_RE.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = _BIND_TRANSLATE_RE.sub( lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], name, ) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 92641fcc6..dd0a1be0f 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1198,6 +1198,45 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): postgresql.ARRAY(Unicode(30), dimensions=3), "VARCHAR(30)[][][]" ) + def test_array_in_enum_psycopg2_cast(self): + expr = column( + "x", + postgresql.ARRAY( + postgresql.ENUM("one", "two", "three", name="myenum") + ), + ).in_([["one", "two"], ["three", "four"]]) + + self.assert_compile( + expr, + "x IN ([POSTCOMPILE_x_1~~~~REPL~~::myenum[]~~])", + dialect=postgresql.psycopg2.dialect(), + ) + + self.assert_compile( + expr, + "x IN (%(x_1_1)s::myenum[], %(x_1_2)s::myenum[])", + dialect=postgresql.psycopg2.dialect(), + render_postcompile=True, + ) + + def test_array_in_str_psycopg2_cast(self): + expr = column("x", postgresql.ARRAY(String(15))).in_( + [["one", "two"], ["three", "four"]] + ) + + self.assert_compile( + expr, + "x IN ([POSTCOMPILE_x_1~~~~REPL~~::VARCHAR(15)[]~~])", + dialect=postgresql.psycopg2.dialect(), + ) + + self.assert_compile( + expr, + "x IN (%(x_1_1)s::VARCHAR(15)[], %(x_1_2)s::VARCHAR(15)[])", + dialect=postgresql.psycopg2.dialect(), + render_postcompile=True, + ) + def test_array_type_render_str_collate_multidim(self): self.assert_compile( postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=2), @@ -1457,11 +1496,79 @@ class ArrayRoundTripTest(object): t = Table( "t", metadata, - Column("data", sqltypes.ARRAY(String(50, collation="en_US"))), + Column("data", self.ARRAY(String(50, collation="en_US"))), ) t.create(connection) + @testing.fixture + def array_in_fixture(self, connection): + arrtable = self.tables.arrtable + + connection.execute( + arrtable.insert(), + [ + { + "id": 1, + "intarr": [1, 2, 3], + "strarr": [u"one", u"two", u"three"], + }, + { + "id": 2, + "intarr": [4, 5, 6], + "strarr": [u"four", u"five", u"six"], + }, + {"id": 3, "intarr": [1, 5], "strarr": [u"one", u"five"]}, + {"id": 4, "intarr": [], "strarr": []}, + ], + ) + + def test_array_in_int(self, array_in_fixture, connection): + """test #7177""" + + arrtable = self.tables.arrtable + + stmt = ( + select(arrtable.c.intarr) + .where(arrtable.c.intarr.in_([[1, 5], [4, 5, 6], [9, 10]])) + .order_by(arrtable.c.id) + ) + + eq_( + connection.execute(stmt).all(), + [ + ([4, 5, 6],), + ([1, 5],), + ], + ) + + def test_array_in_str(self, array_in_fixture, connection): + """test #7177""" + + arrtable = self.tables.arrtable + + stmt = ( + select(arrtable.c.strarr) + .where( + arrtable.c.strarr.in_( + [ + [u"one", u"five"], + [u"four", u"five", u"six"], + [u"nine", u"ten"], + ] + ) + ) + .order_by(arrtable.c.id) + ) + + eq_( + connection.execute(stmt).all(), + [ + (["four", "five", "six"],), + (["one", "five"],), + ], + ) + def test_array_agg(self, metadata, connection): values_table = Table("values", metadata, Column("value", Integer)) metadata.create_all(connection) @@ -2151,6 +2258,9 @@ class _ArrayOfEnum(TypeDecorator): impl = postgresql.ARRAY cache_ok = True + # note expanding logic is checking _is_array here so that has to + # translate through the TypeDecorator + def bind_expression(self, bindvalue): return sa.cast(bindvalue, self) @@ -2207,56 +2317,93 @@ class ArrayEnum(fixtures.TestBase): connection, ) - @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") - @testing.combinations( - sqltypes.ARRAY, - postgresql.ARRAY, - (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")), - argnames="array_cls", - ) - def test_array_of_enums(self, array_cls, enum_cls, metadata, connection): - tbl = Table( - "enum_table", - self.metadata, - Column("id", Integer, primary_key=True), - Column( - "enum_col", - array_cls(enum_cls("foo", "bar", "baz", name="an_enum")), - ), - ) - - if util.py3k: - from enum import Enum - - class MyEnum(Enum): - a = "aaa" - b = "bbb" - c = "ccc" - - tbl.append_column( + @testing.fixture + def array_of_enum_fixture(self, metadata, connection): + def go(array_cls, enum_cls): + tbl = Table( + "enum_table", + metadata, + Column("id", Integer, primary_key=True), Column( - "pyenum_col", - array_cls(enum_cls(MyEnum)), + "enum_col", + array_cls(enum_cls("foo", "bar", "baz", name="an_enum")), ), ) + if util.py3k: + from enum import Enum + + class MyEnum(Enum): + a = "aaa" + b = "bbb" + c = "ccc" + + tbl.append_column( + Column( + "pyenum_col", + array_cls(enum_cls(MyEnum)), + ), + ) + else: + MyEnum = None - self.metadata.create_all(connection) + metadata.create_all(connection) + connection.execute( + tbl.insert(), + [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}], + ) + return tbl, MyEnum - connection.execute( - tbl.insert(), [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}] + yield go + + def _enum_combinations(fn): + return testing.combinations( + sqltypes.Enum, postgresql.ENUM, argnames="enum_cls" + )( + testing.combinations( + sqltypes.ARRAY, + postgresql.ARRAY, + (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")), + argnames="array_cls", + )(fn) ) + @_enum_combinations + def test_array_of_enums_roundtrip( + self, array_of_enum_fixture, connection, array_cls, enum_cls + ): + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) + + # test select back sel = select(tbl.c.enum_col).order_by(tbl.c.id) eq_( connection.execute(sel).fetchall(), [(["foo"],), (["foo", "bar"],)] ) - if util.py3k: - connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]}) - sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc()) - eq_(connection.scalar(sel), [MyEnum.a]) + @_enum_combinations + def test_array_of_enums_expanding_in( + self, array_of_enum_fixture, connection, array_cls, enum_cls + ): + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) + + # test select with WHERE using expanding IN against arrays + # #7177 + sel = ( + select(tbl.c.enum_col) + .where(tbl.c.enum_col.in_([["foo", "bar"], ["bar", "foo"]])) + .order_by(tbl.c.id) + ) + eq_(connection.execute(sel).fetchall(), [(["foo", "bar"],)]) + + @_enum_combinations + @testing.requires.python3 + def test_array_of_enums_native_roundtrip( + self, array_of_enum_fixture, connection, array_cls, enum_cls + ): + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) - self.metadata.drop_all(connection) + connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]}) + sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc()) + eq_(connection.scalar(sel), [MyEnum.a]) class ArrayJSON(fixtures.TestBase): diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 3d1b4fe85..0d43448d5 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -188,7 +188,10 @@ class TraversalTest( ("clone",), ("pickle",), ("conv_to_unique"), ("none"), argnames="meth" ) @testing.combinations( - ("name with space",), ("name with [brackets]",), argnames="name" + ("name with space",), + ("name with [brackets]",), + ("name with~~tildes~~",), + argnames="name", ) def test_bindparam_key_proc_for_copies(self, meth, name): r"""test :ticket:`6249`. @@ -199,7 +202,7 @@ class TraversalTest( Currently, the bind key reg is:: - re.sub(r"[%\(\) \$]+", "_", body).strip("_") + re.sub(r"[%\(\) \$\[\]]", "_", name) and the compiler postcompile reg is:: @@ -218,7 +221,8 @@ class TraversalTest( expr.right.unique = False expr.right._convert_to_unique() - token = re.sub(r"[%\(\) \$]+", "_", name).strip("_") + token = re.sub(r"[%\(\) \$\[\]]", "_", name) + self.assert_compile( expr, '"%(name)s" IN (:%(token)s_1_1, ' diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index 51ee0ae62..adcaef39c 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -182,6 +182,29 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL): "test_table WHERE test_table.y = lower(:y_1)", ) + def test_in_binds(self): + table = self._fixture() + + self.assert_compile( + select(table).where( + table.c.y.in_(["hi", "there", "some", "expr"]) + ), + "SELECT test_table.x, lower(test_table.y) AS y FROM " + "test_table WHERE test_table.y IN " + "([POSTCOMPILE_y_1~~lower(~~REPL~~)~~])", + render_postcompile=False, + ) + + self.assert_compile( + select(table).where( + table.c.y.in_(["hi", "there", "some", "expr"]) + ), + "SELECT test_table.x, lower(test_table.y) AS y FROM " + "test_table WHERE test_table.y IN " + "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))", + render_postcompile=True, + ) + def test_dialect(self): table = self._fixture() dialect = self._dialect_level_fixture() |