summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2017-04-11 10:26:38 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2017-04-11 10:49:30 -0400
commit1b463058e3282c73d0fb361f78e96ecaa23ce9f4 (patch)
tree3dc225d6233db6c15c57a5f6941229f92bb101d6
parent5b81dbcfa3888de65fc33b247353b38488199b00 (diff)
downloadsqlalchemy-1b463058e3282c73d0fb361f78e96ecaa23ce9f4.tar.gz
Set up base ARRAY to be compatible with postgresql.ARRAY.
For some reason, when ARRAY was added to the base it was never linked to postgresql.ARRAY. Link the two types and also make base ARRAY the schema event target so that it supports the same features as postgresql.ARRAY. Change-Id: I82fa6c9d2b8c5028dba3a009715f7bc296b2bc0b Fixes: #3964
-rw-r--r--doc/build/changelog/changelog_12.rst7
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py17
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py19
-rw-r--r--test/dialect/postgresql/test_types.py231
4 files changed, 160 insertions, 114 deletions
diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst
index 815f587f8..7c0421019 100644
--- a/doc/build/changelog/changelog_12.rst
+++ b/doc/build/changelog/changelog_12.rst
@@ -13,6 +13,13 @@
.. changelog::
:version: 1.2.0b1
+ .. change:: 3964
+ :tags: bug, postgresql
+ :tickets: 3964
+
+ Fixed bug where the base :class:`.sqltypes.ARRAY` datatype would not
+ invoke the bind/result processors of :class:`.postgresql.ARRAY`.
+
.. change:: 3963
:tags: bug, orm
:tickets: 3963
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 98cab9562..009c83c0d 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -5,7 +5,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from .base import ischema_names
+from .base import ischema_names, colspecs
from ...sql import expression, operators
from ...sql.base import SchemaEventTarget
from ... import types as sqltypes
@@ -114,7 +114,7 @@ CONTAINED_BY = operators.custom_op("<@", precedence=5)
OVERLAP = operators.custom_op("&&", precedence=5)
-class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
+class ARRAY(sqltypes.ARRAY):
"""PostgreSQL ARRAY type.
@@ -248,18 +248,6 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
def compare_values(self, x, y):
return x == y
- def _set_parent(self, column):
- """Support SchemaEventTarget"""
-
- if isinstance(self.item_type, SchemaEventTarget):
- self.item_type._set_parent(column)
-
- def _set_parent_with_dispatch(self, parent):
- """Support SchemaEventTarget"""
-
- if isinstance(self.item_type, SchemaEventTarget):
- self.item_type._set_parent_with_dispatch(parent)
-
def _proc_array(self, arr, itemproc, dim, collection):
if dim is None:
arr = list(arr)
@@ -311,4 +299,5 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
tuple if self.as_tuple else list)
return process
+colspecs[sqltypes.ARRAY] = ARRAY
ischema_names['_array'] = ARRAY
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 8a114ece6..b8117e3ca 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2061,7 +2061,7 @@ class JSON(Indexable, TypeEngine):
return process
-class ARRAY(Indexable, Concatenable, TypeEngine):
+class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
"""Represent a SQL Array type.
.. note:: This type serves as the basis for all ARRAY operations.
@@ -2199,6 +2199,11 @@ class ARRAY(Indexable, Concatenable, TypeEngine):
return operators.getitem, index, return_type
+ def contains(self, *arg, **kw):
+ raise NotImplementedError(
+ "ARRAY.contains() not implemented for the base "
+ "ARRAY type; please use the dialect-specific ARRAY type")
+
@util.dependencies("sqlalchemy.sql.elements")
def any(self, elements, other, operator=None):
"""Return ``other operator ANY (array)`` clause.
@@ -2325,6 +2330,18 @@ class ARRAY(Indexable, Concatenable, TypeEngine):
def compare_values(self, x, y):
return x == y
+ def _set_parent(self, column):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent(column)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent_with_dispatch(parent)
+
class REAL(Float):
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 807eeb60c..d2e19a04a 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -4,6 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \
AssertsCompiledSQL, ComparesTables
from sqlalchemy.testing import engines, fixtures
from sqlalchemy import testing
+from sqlalchemy.sql import sqltypes
import datetime
from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \
func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \
@@ -85,7 +86,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
@testing.fails_on('postgresql+zxjdbc',
'zxjdbc has no support for PG arrays')
@testing.provide_metadata
- def test_arrays(self):
+ def test_arrays_pg(self):
metadata = self.metadata
t1 = Table('t', metadata,
Column('x', postgresql.ARRAY(Float)),
@@ -101,6 +102,25 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
([5], [5], [6], [decimal.Decimal("6.4")])
)
+ @testing.fails_on('postgresql+zxjdbc',
+ 'zxjdbc has no support for PG arrays')
+ @testing.provide_metadata
+ def test_arrays_base(self):
+ metadata = self.metadata
+ t1 = Table('t', metadata,
+ Column('x', sqltypes.ARRAY(Float)),
+ Column('y', sqltypes.ARRAY(REAL)),
+ Column('z', sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)),
+ Column('q', sqltypes.ARRAY(Numeric))
+ )
+ metadata.create_all()
+ t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")])
+ row = t1.select().execute().first()
+ eq_(
+ row,
+ ([5], [5], [6], [decimal.Decimal("6.4")])
+ )
+
class EnumTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@@ -987,17 +1007,19 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
is_(expr.type.item_type.__class__, Integer)
-class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
+class ArrayRoundTripTest(object):
__only_on__ = 'postgresql'
__backend__ = True
__unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc'
+ ARRAY = postgresql.ARRAY
+
@classmethod
def define_tables(cls, metadata):
class ProcValue(TypeDecorator):
- impl = postgresql.ARRAY(Integer, dimensions=2)
+ impl = cls.ARRAY(Integer, dimensions=2)
def process_bind_param(self, value, dialect):
if value is None:
@@ -1017,15 +1039,15 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
Table('arrtable', metadata,
Column('id', Integer, primary_key=True),
- Column('intarr', postgresql.ARRAY(Integer)),
- Column('strarr', postgresql.ARRAY(Unicode())),
+ Column('intarr', cls.ARRAY(Integer)),
+ Column('strarr', cls.ARRAY(Unicode())),
Column('dimarr', ProcValue)
)
Table('dim_arrtable', metadata,
Column('id', Integer, primary_key=True),
- Column('intarr', postgresql.ARRAY(Integer, dimensions=1)),
- Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)),
+ Column('intarr', cls.ARRAY(Integer, dimensions=1)),
+ Column('strarr', cls.ARRAY(Unicode(), dimensions=1)),
Column('dimarr', ProcValue)
)
@@ -1038,8 +1060,8 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
def test_reflect_array_column(self):
metadata2 = MetaData(testing.db)
tbl = Table('arrtable', metadata2, autoload=True)
- assert isinstance(tbl.c.intarr.type, postgresql.ARRAY)
- assert isinstance(tbl.c.strarr.type, postgresql.ARRAY)
+ assert isinstance(tbl.c.intarr.type, self.ARRAY)
+ assert isinstance(tbl.c.strarr.type, self.ARRAY)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
assert isinstance(tbl.c.strarr.type.item_type, String)
@@ -1107,19 +1129,19 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
func.array_cat(
array([1, 2, 3]),
array([4, 5, 6]),
- type_=postgresql.ARRAY(Integer)
+ type_=self.ARRAY(Integer)
)[2:5]
])
eq_(
testing.db.execute(stmt).scalar(), [2, 3, 4, 5]
)
- def test_any_all_exprs(self):
+ def test_any_all_exprs_array(self):
stmt = select([
3 == any_(func.array_cat(
array([1, 2, 3]),
array([4, 5, 6]),
- type_=postgresql.ARRAY(Integer)
+ type_=self.ARRAY(Integer)
))
])
eq_(
@@ -1225,17 +1247,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
7
)
- def test_undim_array_empty(self):
- arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- eq_(
- testing.db.scalar(
- select([arrtable.c.intarr]).
- where(arrtable.c.intarr.contains([]))
- ),
- [4, 5, 6]
- )
-
def test_array_getitem_slice_exec(self):
arrtable = self.tables.arrtable
testing.db.execute(
@@ -1255,49 +1266,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
[7, 8]
)
- def _test_undim_array_contains_typed_exec(self, struct):
- arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- eq_(
- testing.db.scalar(
- select([arrtable.c.intarr]).
- where(arrtable.c.intarr.contains(struct([4, 5])))
- ),
- [4, 5, 6]
- )
-
- def test_undim_array_contains_set_exec(self):
- self._test_undim_array_contains_typed_exec(set)
-
- def test_undim_array_contains_list_exec(self):
- self._test_undim_array_contains_typed_exec(list)
-
- def test_undim_array_contains_generator_exec(self):
- self._test_undim_array_contains_typed_exec(
- lambda elem: (x for x in elem))
-
- def _test_dim_array_contains_typed_exec(self, struct):
- dim_arrtable = self.tables.dim_arrtable
- self._fixture_456(dim_arrtable)
- eq_(
- testing.db.scalar(
- select([dim_arrtable.c.intarr]).
- where(dim_arrtable.c.intarr.contains(struct([4, 5])))
- ),
- [4, 5, 6]
- )
-
- def test_dim_array_contains_set_exec(self):
- self._test_dim_array_contains_typed_exec(set)
-
- def test_dim_array_contains_list_exec(self):
- self._test_dim_array_contains_typed_exec(list)
-
- def test_dim_array_contains_generator_exec(self):
- self._test_dim_array_contains_typed_exec(
- lambda elem: (
- x for x in elem))
-
def test_multi_dim_roundtrip(self):
arrtable = self.tables.arrtable
testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]])
@@ -1306,35 +1274,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
[[-1, 0, 1], [2, 3, 4]]
)
- def test_array_contained_by_exec(self):
- arrtable = self.tables.arrtable
- with testing.db.connect() as conn:
- conn.execute(
- arrtable.insert(),
- intarr=[6, 5, 4]
- )
- eq_(
- conn.scalar(
- select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
- ),
- True
- )
-
- def test_array_overlap_exec(self):
- arrtable = self.tables.arrtable
- with testing.db.connect() as conn:
- conn.execute(
- arrtable.insert(),
- intarr=[4, 5, 6]
- )
- eq_(
- conn.scalar(
- select([arrtable.c.intarr]).
- where(arrtable.c.intarr.overlap([7, 6]))
- ),
- [4, 5, 6]
- )
-
def test_array_any_exec(self):
arrtable = self.tables.arrtable
with testing.db.connect() as conn:
@@ -1372,10 +1311,10 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
t1 = Table(
't1', metadata,
Column('id', Integer, primary_key=True),
- Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
+ Column('data', self.ARRAY(String(5), as_tuple=True)),
Column(
'data2',
- postgresql.ARRAY(
+ self.ARRAY(
Numeric(asdecimal=False), as_tuple=True)
)
)
@@ -1416,13 +1355,13 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
't', m,
Column(
'data_1',
- postgresql.ARRAY(
+ self.ARRAY(
postgresql.ENUM('a', 'b', 'c', name='my_enum_1')
)
),
Column(
'data_2',
- postgresql.ARRAY(
+ self.ARRAY(
types.Enum('a', 'b', 'c', name='my_enum_2')
)
)
@@ -1437,6 +1376,100 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
eq_(inspect(testing.db).get_enums(), [])
+class CoreArrayRoundTripTest(ArrayRoundTripTest,
+ fixtures.TablesTest, AssertsExecutionResults):
+
+ ARRAY = sqltypes.ARRAY
+
+
+class PGArrayRoundTripTest(ArrayRoundTripTest,
+ fixtures.TablesTest, AssertsExecutionResults):
+ ARRAY = postgresql.ARRAY
+
+ def _test_undim_array_contains_typed_exec(self, struct):
+ arrtable = self.tables.arrtable
+ self._fixture_456(arrtable)
+ eq_(
+ testing.db.scalar(
+ select([arrtable.c.intarr]).
+ where(arrtable.c.intarr.contains(struct([4, 5])))
+ ),
+ [4, 5, 6]
+ )
+
+ def test_undim_array_contains_set_exec(self):
+ self._test_undim_array_contains_typed_exec(set)
+
+ def test_undim_array_contains_list_exec(self):
+ self._test_undim_array_contains_typed_exec(list)
+
+ def test_undim_array_contains_generator_exec(self):
+ self._test_undim_array_contains_typed_exec(
+ lambda elem: (x for x in elem))
+
+ def _test_dim_array_contains_typed_exec(self, struct):
+ dim_arrtable = self.tables.dim_arrtable
+ self._fixture_456(dim_arrtable)
+ eq_(
+ testing.db.scalar(
+ select([dim_arrtable.c.intarr]).
+ where(dim_arrtable.c.intarr.contains(struct([4, 5])))
+ ),
+ [4, 5, 6]
+ )
+
+ def test_dim_array_contains_set_exec(self):
+ self._test_dim_array_contains_typed_exec(set)
+
+ def test_dim_array_contains_list_exec(self):
+ self._test_dim_array_contains_typed_exec(list)
+
+ def test_dim_array_contains_generator_exec(self):
+ self._test_dim_array_contains_typed_exec(
+ lambda elem: (
+ x for x in elem))
+
+ def test_array_contained_by_exec(self):
+ arrtable = self.tables.arrtable
+ with testing.db.connect() as conn:
+ conn.execute(
+ arrtable.insert(),
+ intarr=[6, 5, 4]
+ )
+ eq_(
+ conn.scalar(
+ select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
+ ),
+ True
+ )
+
+ def test_undim_array_empty(self):
+ arrtable = self.tables.arrtable
+ self._fixture_456(arrtable)
+ eq_(
+ testing.db.scalar(
+ select([arrtable.c.intarr]).
+ where(arrtable.c.intarr.contains([]))
+ ),
+ [4, 5, 6]
+ )
+
+ def test_array_overlap_exec(self):
+ arrtable = self.tables.arrtable
+ with testing.db.connect() as conn:
+ conn.execute(
+ arrtable.insert(),
+ intarr=[4, 5, 6]
+ )
+ eq_(
+ conn.scalar(
+ select([arrtable.c.intarr]).
+ where(arrtable.c.intarr.overlap([7, 6]))
+ ),
+ [4, 5, 6]
+ )
+
+
class HashableFlagORMTest(fixtures.TestBase):
"""test the various 'collection' types that they flip the 'hashable' flag
appropriately. [ticket:3499]"""