summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2021-03-10 23:54:52 +0100
committerMike Bayer <mike_mp@zzzcomputing.com>2021-03-15 20:11:20 -0400
commitdfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af (patch)
tree975a06018edcc9a9fa75b709f40698842a82e494
parent28b0b6515af26ee3ba09600a8212849b2dae0699 (diff)
downloadsqlalchemy-dfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af.tar.gz
CAST the elements in ARRAYs when using psycopg2
Adjusted the psycopg2 dialect to emit an explicit PostgreSQL-style cast for bound parameters that contain ARRAY elements. This allows the full range of datatypes to function correctly within arrays. The asyncpg dialect already generated these internal casts in the final statement. This also includes support for array slice updates as well as the PostgreSQL-specific :meth:`_postgresql.ARRAY.contains` method. Fixes: #6023 Change-Id: Ia7519ac4371a635f05ac69a3a4d0f4e6d2f04cad
-rw-r--r--doc/build/changelog/unreleased_13/6023.rst10
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py11
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py16
-rw-r--r--lib/sqlalchemy/sql/crud.py9
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py2
-rw-r--r--lib/sqlalchemy/sql/type_api.py1
-rw-r--r--lib/sqlalchemy/testing/__init__.py1
-rw-r--r--lib/sqlalchemy/testing/config.py5
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py4
-rw-r--r--lib/sqlalchemy/testing/schema.py41
-rw-r--r--test/dialect/postgresql/test_compiler.py60
-rw-r--r--test/dialect/postgresql/test_types.py271
-rw-r--r--test/sql/test_types.py20
13 files changed, 409 insertions, 42 deletions
diff --git a/doc/build/changelog/unreleased_13/6023.rst b/doc/build/changelog/unreleased_13/6023.rst
new file mode 100644
index 000000000..2cfe88567
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/6023.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, types, postgresql
+ :tickets: 6023
+
+ Adjusted the psycopg2 dialect to emit an explicit PostgreSQL-style cast for
+ bound parameters that contain ARRAY elements. This allows the full range of
+ datatypes to function correctly within arrays. The asyncpg dialect already
+ generated these internal casts in the final statement. This also includes
+ support for array slice updates as well as the PostgreSQL-specific
+ :meth:`_postgresql.ARRAY.contains` method. \ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 91bb89ea9..c2d99845f 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -331,12 +331,6 @@ class ARRAY(sqltypes.ARRAY):
)
@util.memoized_property
- def _require_cast(self):
- return self._against_native_enum or isinstance(
- self.item_type, sqltypes.JSON
- )
-
- @util.memoized_property
def _against_native_enum(self):
return (
isinstance(self.item_type, sqltypes.Enum)
@@ -344,10 +338,7 @@ class ARRAY(sqltypes.ARRAY):
)
def bind_expression(self, bindvalue):
- if self._require_cast:
- return expression.cast(bindvalue, self)
- else:
- return bindvalue
+ return bindvalue
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index a52eacd8b..1969eb844 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -450,6 +450,7 @@ from ... import processors
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
+from ...sql import elements
from ...util import collections_abc
@@ -597,7 +598,20 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
class PGCompiler_psycopg2(PGCompiler):
- pass
+ def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw):
+
+ text = super(PGCompiler_psycopg2, self).visit_bindparam(
+ bindparam, skip_bind_expression=skip_bind_expression, **kw
+ )
+ # note that if the type has a bind_expression(), we will get a
+ # double compile here
+ if not skip_bind_expression and bindparam.type._is_array:
+ text += "::%s" % (
+ elements.TypeClause(bindparam.type)._compiler_dispatch(
+ self, skip_bind_expression=skip_bind_expression, **kw
+ ),
+ )
+ return text
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 8b4950aa3..174a1c131 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -943,6 +943,9 @@ def _get_stmt_parameter_tuples_params(
# add it to values() in an "as-is" state,
# coercing right side to bound param
+ # note one of the main use cases for this is array slice
+ # updates on PostgreSQL, as the left side is also an expression.
+
col_expr = compiler.process(
k, include_table=compile_state.include_table_with_column_exprs
)
@@ -952,6 +955,12 @@ def _get_stmt_parameter_tuples_params(
elements.BindParameter(None, v, type_=k.type), **kw
)
else:
+ if v._is_bind_parameter and v.type._isnull:
+ # either unique parameter, or other bound parameters that
+ # were passed in directly
+ # set type to that of the column unconditionally
+ v = v._with_binary_element_type(k.type)
+
v = compiler.process(v.self_group(), **kw)
values.append((k, col_expr, v))
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index d075ef77d..816423d1b 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2675,6 +2675,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
__visit_name__ = "ARRAY"
+ _is_array = True
+
zero_indexes = False
"""If True, Python zero-based indexes should be interpreted as one-based
on the SQL expression side."""
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 46751cb22..9752750c5 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -47,6 +47,7 @@ class TypeEngine(Traversible):
_isnull = False
_is_tuple_type = False
_is_table_value = False
+ _is_array = False
class Comparator(operators.ColumnOperators):
"""Base class for custom comparison operations defined at the
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index adbb8f643..a3ce24226 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -42,6 +42,7 @@ from .assertions import startswith_
from .assertions import uses_deprecated
from .config import async_test
from .config import combinations
+from .config import combinations_list
from .config import db
from .config import fixture
from .config import requirements as requires
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 750671f9f..6589e5097 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -89,6 +89,11 @@ def combinations(*comb, **kw):
return _fixture_functions.combinations(*comb, **kw)
+def combinations_list(arg_iterable, **kw):
+ "As combination, but takes a single iterable"
+ return combinations(*arg_iterable, **kw)
+
+
def fixture(*arg, **kw):
return _fixture_functions.fixture(*arg, **kw)
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 4eaaecebb..388d71c73 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -578,7 +578,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
"i": lambda obj: obj,
"r": repr,
"s": str,
- "n": operator.attrgetter("__name__"),
+ "n": lambda obj: obj.__name__
+ if hasattr(obj, "__name__")
+ else type(obj).__name__,
}
def combinations(self, *arg_sets, **kw):
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index 22b1f7b77..fee021cff 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -5,11 +5,14 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import sys
+
from . import config
from . import exclusions
from .. import event
from .. import schema
from .. import types as sqltypes
+from ..util import OrderedDict
__all__ = ["Table", "Column"]
@@ -162,3 +165,41 @@ def _truncate_name(dialect, name):
)
else:
return name
+
+
+def pep435_enum(name):
+ # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+ __members__ = OrderedDict()
+
+ def __init__(self, name, value, alias=None):
+ self.name = name
+ self.value = value
+ self.__members__[name] = self
+ value_to_member[value] = self
+ setattr(self.__class__, name, self)
+ if alias:
+ self.__members__[alias] = self
+ setattr(self.__class__, alias, self)
+
+ value_to_member = {}
+
+ @classmethod
+ def get(cls, value):
+ return value_to_member[value]
+
+ someenum = type(
+ name,
+ (object,),
+ {"__members__": __members__, "__init__": __init__, "get": get},
+ )
+
+ # getframe() trick for pickling I don't understand courtesy
+ # Python namedtuple()
+ try:
+ module = sys._getframe(1).f_globals.get("__name__", "__main__")
+ except (AttributeError, ValueError):
+ pass
+ if module is not None:
+ someenum.__module__ = module
+
+ return someenum
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 44f8c9398..4b2004a5f 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1,5 +1,4 @@
# coding: utf-8
-
from sqlalchemy import and_
from sqlalchemy import BigInteger
from sqlalchemy import bindparam
@@ -36,6 +35,8 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.dialects.postgresql.base import PGDialect
+from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.orm import aliased
from sqlalchemy.orm import mapper
from sqlalchemy.orm import Session
@@ -1351,13 +1352,28 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
)
self.assert_compile(
- c.contains([1]), "x @> %(x_1)s", checkparams={"x_1": [1]}
+ c.contains([1]),
+ "x @> %(x_1)s::INTEGER[]",
+ checkparams={"x_1": [1]},
+ dialect=PGDialect_psycopg2(),
+ )
+ self.assert_compile(
+ c.contained_by([2]),
+ "x <@ %(x_1)s::INTEGER[]",
+ checkparams={"x_1": [2]},
+ dialect=PGDialect_psycopg2(),
)
self.assert_compile(
- c.contained_by([2]), "x <@ %(x_1)s", checkparams={"x_1": [2]}
+ c.contained_by([2]),
+ "x <@ %(x_1)s",
+ checkparams={"x_1": [2]},
+ dialect=PGDialect(),
)
self.assert_compile(
- c.overlap([3]), "x && %(x_1)s", checkparams={"x_1": [3]}
+ c.overlap([3]),
+ "x && %(x_1)s::INTEGER[]",
+ checkparams={"x_1": [3]},
+ dialect=PGDialect_psycopg2(),
)
self.assert_compile(
postgresql.Any(4, c),
@@ -1405,7 +1421,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
checkparams={"param_1": 7},
)
- def _test_array_zero_indexes(self, zero_indexes):
+ @testing.combinations((True,), (False,))
+ def test_array_zero_indexes(self, zero_indexes):
c = Column("x", postgresql.ARRAY(Integer, zero_indexes=zero_indexes))
add_one = 1 if zero_indexes else 0
@@ -1443,12 +1460,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
},
)
- def test_array_zero_indexes_true(self):
- self._test_array_zero_indexes(True)
-
- def test_array_zero_indexes_false(self):
- self._test_array_zero_indexes(False)
-
def test_array_literal_type(self):
isinstance(postgresql.array([1, 2]).type, postgresql.ARRAY)
is_(postgresql.array([1, 2]).type.item_type._type_affinity, Integer)
@@ -1576,6 +1587,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"%(param_2)s, %(param_3)s])",
)
+ def test_update_array(self):
+ m = MetaData()
+ t = Table("t", m, Column("data", postgresql.ARRAY(Integer)))
+ self.assert_compile(
+ t.update().values({t.c.data: [1, 3, 4]}),
+ "UPDATE t SET data=%(data)s::INTEGER[]",
+ checkparams={"data": [1, 3, 4]},
+ )
+
def test_update_array_element(self):
m = MetaData()
t = Table("t", m, Column("data", postgresql.ARRAY(Integer)))
@@ -1588,10 +1608,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_update_array_slice(self):
m = MetaData()
t = Table("t", m, Column("data", postgresql.ARRAY(Integer)))
+
+ # psycopg2-specific, has a cast
+ self.assert_compile(
+ t.update().values({t.c.data[2:5]: [2, 3, 4]}),
+ "UPDATE t SET data[%(data_1)s:%(data_2)s]="
+ "%(param_1)s::INTEGER[]",
+ checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2},
+ dialect=PGDialect_psycopg2(),
+ )
+
+ # default dialect does not, as DBAPIs may be doing this for us
self.assert_compile(
- t.update().values({t.c.data[2:5]: 2}),
- "UPDATE t SET data[%(data_1)s:%(data_2)s]=%(param_1)s",
- checkparams={"param_1": 2, "data_2": 5, "data_1": 2},
+ t.update().values({t.c.data[2:5]: [2, 3, 4]}),
+ "UPDATE t SET data[%s:%s]=" "%s",
+ checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2},
+ dialect=PGDialect(paramstyle="format"),
)
def test_from_only(self):
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index b24132f69..59a6c0c85 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -58,9 +58,15 @@ from sqlalchemy.testing.assertions import ComparesTables
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.assertions import is_
from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.suite import test_types as suite
from sqlalchemy.testing.util import round_decimal
+try:
+ import enum
+except ImportError:
+ enum = None
+
class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
__only_on__ = "postgresql"
@@ -1307,6 +1313,12 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
is_(expr.type.item_type.__class__, element_type)
+AnEnum = pep435_enum("AnEnum")
+AnEnum("Foo", 1)
+AnEnum("Bar", 2)
+AnEnum("Baz", 3)
+
+
class ArrayRoundTripTest(object):
__only_on__ = "postgresql"
@@ -1699,6 +1711,248 @@ class ArrayRoundTripTest(object):
t.drop(connection)
eq_(inspect(connection).get_enums(), [])
+ def _type_combinations(exclude_json=False):
+ def str_values(x):
+ return ["one", "two: %s" % x, "three", "four", "five"]
+
+ def unicode_values(x):
+ return [
+ util.u("réveillé"),
+ util.u("drôle"),
+ util.u("S’il %s" % x),
+ util.u("🐍 %s" % x),
+ util.u("« S’il vous"),
+ ]
+
+ def json_values(x):
+ return [
+ 1,
+ {"a": x},
+ {"b": [1, 2, 3]},
+ ["d", "e", "f"],
+ {"struct": True, "none": None},
+ ]
+
+ def binary_values(x):
+ return [v.encode("utf-8") for v in unicode_values(x)]
+
+ def enum_values(x):
+ return [
+ AnEnum.Foo,
+ AnEnum.Baz,
+ AnEnum.get(x),
+ AnEnum.Baz,
+ AnEnum.Foo,
+ ]
+
+ class inet_str(str):
+ def __eq__(self, other):
+ return str(self) == str(other)
+
+ def __ne__(self, other):
+ return str(self) != str(other)
+
+ class money_str(str):
+ def __eq__(self, other):
+ comp = re.sub(r"[^\d\.]", "", other)
+ return float(self) == float(comp)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ elements = [
+ (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]),
+ (sqltypes.Text, str_values),
+ (sqltypes.String, str_values),
+ (sqltypes.Unicode, unicode_values),
+ (postgresql.JSONB, json_values),
+ (sqltypes.Boolean, lambda x: [False] + [True] * x),
+ (
+ sqltypes.LargeBinary,
+ binary_values,
+ ),
+ (
+ postgresql.BYTEA,
+ binary_values,
+ ),
+ (
+ postgresql.INET,
+ lambda x: [
+ inet_str("1.1.1.1"),
+ inet_str("{0}.{0}.{0}.{0}".format(x)),
+ inet_str("192.168.1.1"),
+ inet_str("10.1.2.25"),
+ inet_str("192.168.22.5"),
+ ],
+ ),
+ (
+ postgresql.CIDR,
+ lambda x: [
+ inet_str("10.0.0.0/8"),
+ inet_str("%s.0.0.0/8" % x),
+ inet_str("192.168.1.0/24"),
+ inet_str("192.168.0.0/16"),
+ inet_str("192.168.1.25/32"),
+ ],
+ ),
+ (
+ sqltypes.Date,
+ lambda x: [
+ datetime.date(2020, 5, x),
+ datetime.date(2020, 7, 12),
+ datetime.date(2018, 12, 15),
+ datetime.date(2009, 1, 5),
+ datetime.date(2021, 3, 18),
+ ],
+ ),
+ (
+ sqltypes.DateTime,
+ lambda x: [
+ datetime.datetime(2020, 5, x, 2, 15, 0),
+ datetime.datetime(2020, 7, 12, 15, 30, x),
+ datetime.datetime(2018, 12, 15, 3, x, 25),
+ datetime.datetime(2009, 1, 5, 12, 45, x),
+ datetime.datetime(2021, 3, 18, 17, 1, 0),
+ ],
+ ),
+ (
+ sqltypes.Numeric,
+ lambda x: [
+ decimal.Decimal("45.10"),
+ decimal.Decimal(x),
+ decimal.Decimal(".03242"),
+ decimal.Decimal("532.3532"),
+ decimal.Decimal("95503.23"),
+ ],
+ ),
+ (
+ postgresql.MONEY,
+ lambda x: [
+ money_str("2"),
+ money_str("%s" % (5 + x)),
+ money_str("50.25"),
+ money_str("18.99"),
+ money_str("15.%s" % x),
+ ],
+ testing.skip_if(
+ "postgresql+psycopg2", "this is a psycopg2 bug"
+ ),
+ ),
+ (
+ postgresql.HSTORE,
+ lambda x: [
+ {"a": "1"},
+ {"b": "%s" % x},
+ {"c": "3"},
+ {"c": "c2"},
+ {"d": "e"},
+ ],
+ testing.requires.hstore,
+ ),
+ (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
+ (sqltypes.Enum(AnEnum, native_enum=False), enum_values),
+ ]
+
+ if not exclude_json:
+ elements.extend(
+ [
+ (sqltypes.JSON, json_values),
+ (postgresql.JSON, json_values),
+ ]
+ )
+
+ return testing.combinations_list(
+ elements, argnames="type_,gen", id_="na"
+ )
+
+ @classmethod
+ def _cls_type_combinations(cls, **kw):
+ return ArrayRoundTripTest.__dict__["_type_combinations"](**kw)
+
+ @testing.fixture
+ def type_specific_fixture(self, metadata, connection, type_):
+ meta = MetaData()
+ table = Table(
+ "foo",
+ meta,
+ Column("id", Integer),
+ Column("bar", self.ARRAY(type_)),
+ )
+
+ meta.create_all(connection)
+
+ def go(gen):
+ connection.execute(
+ table.insert(),
+ [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+ )
+ return table
+
+ return go
+
+ @_type_combinations()
+ def test_type_specific_value_select(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ rows = connection.execute(
+ select(table.c.bar).order_by(table.c.id)
+ ).all()
+
+ eq_(rows, [(gen(1),), (gen(2),)])
+
+ @_type_combinations()
+ def test_type_specific_value_update(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ new_gen = gen(3)
+ connection.execute(
+ table.update().where(table.c.id == 2).values(bar=new_gen)
+ )
+
+ eq_(
+ new_gen,
+ connection.scalar(select(table.c.bar).where(table.c.id == 2)),
+ )
+
+ @_type_combinations()
+ def test_type_specific_slice_update(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ new_gen = gen(3)
+
+ connection.execute(
+ table.update()
+ .where(table.c.id == 2)
+ .values({table.c.bar[1:3]: new_gen[1:4]})
+ )
+
+ rows = connection.execute(
+ select(table.c.bar).order_by(table.c.id)
+ ).all()
+
+ sliced_gen = gen(2)
+ sliced_gen[0:3] = new_gen[1:4]
+
+ eq_(rows, [(gen(1),), (sliced_gen,)])
+
+ @_type_combinations(exclude_json=True)
+ def test_type_specific_value_delete(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ new_gen = gen(2)
+
+ connection.execute(table.delete().where(table.c.bar == new_gen))
+
+ eq_(connection.scalar(select(func.count(table.c.id))), 1)
+
class CoreArrayRoundTripTest(
ArrayRoundTripTest, fixtures.TablesTest, AssertsExecutionResults
@@ -1712,6 +1966,23 @@ class PGArrayRoundTripTest(
):
ARRAY = postgresql.ARRAY
+ @ArrayRoundTripTest._cls_type_combinations(exclude_json=True)
+ def test_type_specific_contains(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ connection.execute(
+ table.insert(),
+ [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+ )
+
+ id_, value = connection.execute(
+ select(table).where(table.c.bar.contains(gen(1)))
+ ).first()
+ eq_(id_, 1)
+ eq_(value, gen(1))
+
@testing.combinations(
(set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
)
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 0a8541467..9a5b9d274 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -86,10 +86,10 @@ from sqlalchemy.testing import is_not
from sqlalchemy.testing import mock
from sqlalchemy.testing import pickleable
from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import picklers
from sqlalchemy.testing.util import round_decimal
-from sqlalchemy.util import OrderedDict
from sqlalchemy.util import u
@@ -1579,21 +1579,7 @@ class UnicodeTest(fixtures.TestBase):
class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
__backend__ = True
- class SomeEnum(object):
- # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
- __members__ = OrderedDict()
-
- def __init__(self, name, value, alias=None):
- self.name = name
- self.value = value
- self.__members__[name] = self
- setattr(self.__class__, name, self)
- if alias:
- self.__members__[alias] = self
- setattr(self.__class__, alias, self)
-
- class SomeOtherEnum(SomeEnum):
- __members__ = OrderedDict()
+ SomeEnum = pep435_enum("SomeEnum")
one = SomeEnum("one", 1)
two = SomeEnum("two", 2)
@@ -1601,6 +1587,8 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
a_member = SomeEnum("AMember", "a")
b_member = SomeEnum("BMember", "b")
+ SomeOtherEnum = pep435_enum("SomeOtherEnum")
+
other_one = SomeOtherEnum("one", 1)
other_two = SomeOtherEnum("two", 2)
other_three = SomeOtherEnum("three", 3)