summaryrefslogtreecommitdiff
path: root/test/dialect/postgresql
diff options
context:
space:
mode:
Diffstat (limited to 'test/dialect/postgresql')
-rw-r--r--test/dialect/postgresql/test_compiler.py123
-rw-r--r--test/dialect/postgresql/test_dialect.py39
-rw-r--r--test/dialect/postgresql/test_reflection.py76
-rw-r--r--test/dialect/postgresql/test_types.py400
4 files changed, 575 insertions, 63 deletions
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 11661b11f..e64afb186 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.dialects.postgresql import TSRANGE
from sqlalchemy.orm import mapper, aliased, Session
from sqlalchemy.sql import table, column, operators
+from sqlalchemy.util import u
class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -106,6 +107,45 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
'AS length_1', dialect=dialect)
+ def test_create_drop_enum(self):
+ # test escaping and unicode within CREATE TYPE for ENUM
+ typ = postgresql.ENUM(
+ "val1", "val2", "val's 3", u('méil'), name="myname")
+ self.assert_compile(postgresql.CreateEnumType(typ),
+ u("CREATE TYPE myname AS ENUM ('val1', 'val2', 'val''s 3', 'méil')")
+ )
+
+ typ = postgresql.ENUM(
+ "val1", "val2", "val's 3", name="PleaseQuoteMe")
+ self.assert_compile(postgresql.CreateEnumType(typ),
+ "CREATE TYPE \"PleaseQuoteMe\" AS ENUM "
+ "('val1', 'val2', 'val''s 3')"
+ )
+
+ def test_generic_enum(self):
+ e1 = Enum('x', 'y', 'z', name='somename')
+ e2 = Enum('x', 'y', 'z', name='somename', schema='someschema')
+ self.assert_compile(postgresql.CreateEnumType(e1),
+ "CREATE TYPE somename AS ENUM ('x', 'y', 'z')"
+ )
+ self.assert_compile(postgresql.CreateEnumType(e2),
+ "CREATE TYPE someschema.somename AS ENUM "
+ "('x', 'y', 'z')")
+ self.assert_compile(postgresql.DropEnumType(e1),
+ 'DROP TYPE somename')
+ self.assert_compile(postgresql.DropEnumType(e2),
+ 'DROP TYPE someschema.somename')
+ t1 = Table('sometable', MetaData(), Column('somecolumn', e1))
+ self.assert_compile(schema.CreateTable(t1),
+ 'CREATE TABLE sometable (somecolumn '
+ 'somename)')
+ t1 = Table('sometable', MetaData(), Column('somecolumn',
+ Enum('x', 'y', 'z', native_enum=False)))
+ self.assert_compile(schema.CreateTable(t1),
+ "CREATE TABLE sometable (somecolumn "
+ "VARCHAR(1), CHECK (somecolumn IN ('x', "
+ "'y', 'z')))")
+
def test_create_partial_index(self):
m = MetaData()
tbl = Table('testtbl', m, Column('data', Integer))
@@ -173,6 +213,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
'USING hash (data)',
dialect=postgresql.dialect())
+
+ def test_create_index_expr_gets_parens(self):
+ m = MetaData()
+ tbl = Table('testtbl', m, Column('x', Integer), Column('y', Integer))
+
+ idx1 = Index('test_idx1', 5 / (tbl.c.x + tbl.c.y))
+ self.assert_compile(
+ schema.CreateIndex(idx1),
+ "CREATE INDEX test_idx1 ON testtbl ((5 / (x + y)))"
+ )
+
+ def test_create_index_literals(self):
+ m = MetaData()
+ tbl = Table('testtbl', m, Column('data', Integer))
+
+ idx1 = Index('test_idx1', tbl.c.data + 5)
+ self.assert_compile(
+ schema.CreateIndex(idx1),
+ "CREATE INDEX test_idx1 ON testtbl ((data + 5))"
+ )
+
def test_exclude_constraint_min(self):
m = MetaData()
tbl = Table('testtbl', m,
@@ -228,6 +289,68 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
'SUBSTRING(%(substring_1)s FROM %(substring_2)s)')
+ def test_for_update(self):
+ table1 = table('mytable',
+ column('myid'), column('name'), column('description'))
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(nowait=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(read=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(of=table1.c.myid),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR UPDATE OF mytable")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True, of=table1),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True, of=table1.c.myid),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True,
+ of=[table1.c.myid, table1.c.name]),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable NOWAIT")
+
+ ta = table1.alias()
+ self.assert_compile(
+ ta.select(ta.c.myid == 7).
+ with_for_update(of=[ta.c.myid, ta.c.name]),
+ "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM mytable AS mytable_1 "
+ "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1"
+ )
def test_reserved_words(self):
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index 1fc239cb7..fd6df2c98 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import base as postgresql
import logging
import logging.handlers
from sqlalchemy.testing.mock import Mock
+from sqlalchemy.engine.reflection import Inspector
class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
@@ -53,7 +54,11 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
'compiled by GCC gcc (GCC) 4.4.2, 64-bit', (8, 5)),
('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, '
'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), '
- '64-bit', (9, 1, 2))]:
+ '64-bit', (9, 1, 2)),
+ ('[PostgreSQL 9.2.4 ] VMware vFabric Postgres 9.2.4.0 '
+ 'release build 1080137', (9, 2, 4))
+
+ ]:
eq_(testing.db.dialect._get_server_version_info(mock_conn(string)),
version)
@@ -63,8 +68,10 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
assert testing.db.dialect.dbapi.__version__.\
startswith(".".join(str(x) for x in v))
+ # currently not passing with pg 9.3 that does not seem to generate
+    # any notices here, would rather find a way to mock this
@testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature')
- def test_notice_logging(self):
+ def _test_notice_logging(self):
log = logging.getLogger('sqlalchemy.dialects.postgresql')
buf = logging.handlers.BufferingHandler(100)
lev = log.level
@@ -199,18 +206,32 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
assert_raises(exc.InvalidRequestError, testing.db.execute, stmt)
def test_serial_integer(self):
- for type_, expected in [
- (Integer, 'SERIAL'),
- (BigInteger, 'BIGSERIAL'),
- (SmallInteger, 'SMALLINT'),
- (postgresql.INTEGER, 'SERIAL'),
- (postgresql.BIGINT, 'BIGSERIAL'),
+
+ for version, type_, expected in [
+ (None, Integer, 'SERIAL'),
+ (None, BigInteger, 'BIGSERIAL'),
+ ((9, 1), SmallInteger, 'SMALLINT'),
+ ((9, 2), SmallInteger, 'SMALLSERIAL'),
+ (None, postgresql.INTEGER, 'SERIAL'),
+ (None, postgresql.BIGINT, 'BIGSERIAL'),
]:
m = MetaData()
t = Table('t', m, Column('c', type_, primary_key=True))
- ddl_compiler = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(t))
+
+ if version:
+ dialect = postgresql.dialect()
+ dialect._get_server_version_info = Mock(return_value=version)
+ dialect.initialize(testing.db.connect())
+ else:
+ dialect = testing.db.dialect
+
+ ddl_compiler = dialect.ddl_compiler(
+ dialect,
+ schema.CreateTable(t)
+ )
eq_(
ddl_compiler.get_column_specification(t.c.c),
"c %s NOT NULL" % expected
)
+
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index fb399b546..58f34d5d0 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -5,6 +5,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \
AssertsCompiledSQL, ComparesTables
from sqlalchemy.testing import engines, fixtures
from sqlalchemy import testing
+from sqlalchemy import inspect
from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
String, Sequence, ForeignKey, join, Numeric, \
PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \
@@ -159,6 +160,17 @@ class ReflectionTest(fixtures.TestBase):
subject.join(referer).onclause))
@testing.provide_metadata
+ def test_reflect_default_over_128_chars(self):
+ Table('t', self.metadata,
+ Column('x', String(200), server_default="abcd" * 40)
+ ).create(testing.db)
+
+ m = MetaData()
+ t = Table('t', m, autoload=True, autoload_with=testing.db)
+ eq_(
+ t.c.x.server_default.arg.text, "'%s'::character varying" % ("abcd" * 40)
+ )
+ @testing.provide_metadata
def test_renamed_sequence_reflection(self):
metadata = self.metadata
t = Table('t', metadata, Column('id', Integer, primary_key=True))
@@ -416,6 +428,70 @@ class ReflectionTest(fixtures.TestBase):
eq_(ind, [{'unique': False, 'column_names': ['y'], 'name': 'idx1'}])
conn.close()
+ @testing.provide_metadata
+ def test_foreign_key_option_inspection(self):
+ metadata = self.metadata
+ Table('person', metadata,
+ Column('id', String(length=32), nullable=False, primary_key=True),
+ Column('company_id', ForeignKey('company.id',
+ name='person_company_id_fkey',
+ match='FULL', onupdate='RESTRICT', ondelete='RESTRICT',
+ deferrable=True, initially='DEFERRED'
+ )
+ )
+ )
+ Table('company', metadata,
+ Column('id', String(length=32), nullable=False, primary_key=True),
+ Column('name', String(length=255)),
+ Column('industry_id', ForeignKey('industry.id',
+ name='company_industry_id_fkey',
+ onupdate='CASCADE', ondelete='CASCADE',
+ deferrable=False, # PG default
+ initially='IMMEDIATE' # PG default
+ )
+ )
+ )
+ Table('industry', metadata,
+ Column('id', Integer(), nullable=False, primary_key=True),
+ Column('name', String(length=255))
+ )
+ fk_ref = {
+ 'person_company_id_fkey': {
+ 'name': 'person_company_id_fkey',
+ 'constrained_columns': ['company_id'],
+ 'referred_columns': ['id'],
+ 'referred_table': 'company',
+ 'referred_schema': None,
+ 'options': {
+ 'onupdate': 'RESTRICT',
+ 'deferrable': True,
+ 'ondelete': 'RESTRICT',
+ 'initially': 'DEFERRED',
+ 'match': 'FULL'
+ }
+ },
+ 'company_industry_id_fkey': {
+ 'name': 'company_industry_id_fkey',
+ 'constrained_columns': ['industry_id'],
+ 'referred_columns': ['id'],
+ 'referred_table': 'industry',
+ 'referred_schema': None,
+ 'options': {
+ 'onupdate': 'CASCADE',
+ 'deferrable': None,
+ 'ondelete': 'CASCADE',
+ 'initially': None,
+ 'match': None
+ }
+ }
+ }
+ metadata.create_all()
+ inspector = inspect(testing.db)
+ fks = inspector.get_foreign_keys('person') + \
+ inspector.get_foreign_keys('company')
+ for fk in fks:
+ eq_(fk, fk_ref[fk['name']])
+
class CustomTypeReflectionTest(fixtures.TestBase):
class CustomType(object):
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 784f8bcbf..ba4b63e1a 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -10,18 +10,22 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \
func, literal_column, literal, bindparam, cast, extract, \
SmallInteger, Enum, REAL, update, insert, Index, delete, \
- and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text
+ and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text, \
+ type_coerce
from sqlalchemy.orm import Session, mapper, aliased
from sqlalchemy import exc, schema, types
from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
- INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE
+ INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
+ JSON
import decimal
from sqlalchemy import util
from sqlalchemy.testing.util import round_decimal
from sqlalchemy.sql import table, column, operators
import logging
import re
+from sqlalchemy import inspect
+from sqlalchemy import event
class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
__only_on__ = 'postgresql'
@@ -96,34 +100,10 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
([5], [5], [6], [decimal.Decimal("6.4")])
)
-class EnumTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
+class EnumTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = 'postgresql'
- __dialect__ = postgresql.dialect()
- def test_compile(self):
- e1 = Enum('x', 'y', 'z', name='somename')
- e2 = Enum('x', 'y', 'z', name='somename', schema='someschema')
- self.assert_compile(postgresql.CreateEnumType(e1),
- "CREATE TYPE somename AS ENUM ('x','y','z')"
- )
- self.assert_compile(postgresql.CreateEnumType(e2),
- "CREATE TYPE someschema.somename AS ENUM "
- "('x','y','z')")
- self.assert_compile(postgresql.DropEnumType(e1),
- 'DROP TYPE somename')
- self.assert_compile(postgresql.DropEnumType(e2),
- 'DROP TYPE someschema.somename')
- t1 = Table('sometable', MetaData(), Column('somecolumn', e1))
- self.assert_compile(schema.CreateTable(t1),
- 'CREATE TABLE sometable (somecolumn '
- 'somename)')
- t1 = Table('sometable', MetaData(), Column('somecolumn',
- Enum('x', 'y', 'z', native_enum=False)))
- self.assert_compile(schema.CreateTable(t1),
- "CREATE TABLE sometable (somecolumn "
- "VARCHAR(1), CHECK (somecolumn IN ('x', "
- "'y', 'z')))")
@testing.fails_on('postgresql+zxjdbc',
'zxjdbc fails on ENUM: column "XXX" is of type '
@@ -860,7 +840,8 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
Column('plain_interval', postgresql.INTERVAL),
Column('year_interval', y2m()),
Column('month_interval', d2s()),
- Column('precision_interval', postgresql.INTERVAL(precision=3))
+ Column('precision_interval', postgresql.INTERVAL(precision=3)),
+ Column('tsvector_document', postgresql.TSVECTOR)
)
metadata.create_all()
@@ -893,6 +874,17 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
self.assert_compile(type_, expected)
@testing.provide_metadata
+ def test_tsvector_round_trip(self):
+ t = Table('t1', self.metadata, Column('data', postgresql.TSVECTOR))
+ t.create()
+ testing.db.execute(t.insert(), data="a fat cat sat")
+ eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'")
+
+ testing.db.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'")
+
+ eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'mat' 'sat'")
+
+ @testing.provide_metadata
def test_bit_reflection(self):
metadata = self.metadata
t1 = Table('t1', metadata,
@@ -918,7 +910,6 @@ class UUIDTest(fixtures.TestBase):
__only_on__ = 'postgresql'
- @testing.requires.python25
@testing.fails_on('postgresql+zxjdbc',
'column "data" is of type uuid but expression is of type character varying')
@testing.fails_on('postgresql+pg8000', 'No support for UUID type')
@@ -932,7 +923,6 @@ class UUIDTest(fixtures.TestBase):
str(uuid.uuid4())
)
- @testing.requires.python25
@testing.fails_on('postgresql+zxjdbc',
'column "data" is of type uuid but expression is of type character varying')
@testing.fails_on('postgresql+pg8000', 'No support for UUID type')
@@ -978,13 +968,8 @@ class UUIDTest(fixtures.TestBase):
-class HStoreTest(fixtures.TestBase):
- def _assert_sql(self, construct, expected):
- dialect = postgresql.dialect()
- compiled = str(construct.compile(dialect=dialect))
- compiled = re.sub(r'\s+', ' ', compiled)
- expected = re.sub(r'\s+', ' ', expected)
- eq_(compiled, expected)
+class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
+ __dialect__ = 'postgresql'
def setup(self):
metadata = MetaData()
@@ -996,7 +981,7 @@ class HStoreTest(fixtures.TestBase):
def _test_where(self, whereclause, expected):
stmt = select([self.test_table]).where(whereclause)
- self._assert_sql(
+ self.assert_compile(
stmt,
"SELECT test_table.id, test_table.hash FROM test_table "
"WHERE %s" % expected
@@ -1004,7 +989,7 @@ class HStoreTest(fixtures.TestBase):
def _test_cols(self, colclause, expected, from_=True):
stmt = select([colclause])
- self._assert_sql(
+ self.assert_compile(
stmt,
(
"SELECT %s" +
@@ -1013,9 +998,8 @@ class HStoreTest(fixtures.TestBase):
)
def test_bind_serialize_default(self):
- from sqlalchemy.engine import default
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
eq_(
proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
@@ -1023,9 +1007,7 @@ class HStoreTest(fixtures.TestBase):
)
def test_bind_serialize_with_slashes_and_quotes(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
eq_(
proc({'\\"a': '\\"1'}),
@@ -1033,9 +1015,7 @@ class HStoreTest(fixtures.TestBase):
)
def test_parse_error(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_result_processor(
dialect, None)
assert_raises_message(
@@ -1048,9 +1028,7 @@ class HStoreTest(fixtures.TestBase):
)
def test_result_deserialize_default(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_result_processor(
dialect, None)
eq_(
@@ -1059,9 +1037,7 @@ class HStoreTest(fixtures.TestBase):
)
def test_result_deserialize_with_slashes_and_quotes(self):
- from sqlalchemy.engine import default
-
- dialect = default.DefaultDialect()
+ dialect = postgresql.dialect()
proc = self.test_table.c.hash.type._cached_result_processor(
dialect, None)
eq_(
@@ -1305,7 +1281,6 @@ class HStoreRoundTripTest(fixtures.TablesTest):
return engine
def test_reflect(self):
- from sqlalchemy import inspect
insp = inspect(testing.db)
cols = insp.get_columns('data_table')
assert isinstance(cols[2]['type'], HSTORE)
@@ -1677,3 +1652,320 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
def _data_obj(self):
return self.extras.DateTimeTZRange(*self.tstzs())
+
+
+class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
+ __dialect__ = 'postgresql'
+
+ def setup(self):
+ metadata = MetaData()
+ self.test_table = Table('test_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('test_column', JSON)
+ )
+ self.jsoncol = self.test_table.c.test_column
+
+ def _test_where(self, whereclause, expected):
+ stmt = select([self.test_table]).where(whereclause)
+ self.assert_compile(
+ stmt,
+ "SELECT test_table.id, test_table.test_column FROM test_table "
+ "WHERE %s" % expected
+ )
+
+ def _test_cols(self, colclause, expected, from_=True):
+ stmt = select([colclause])
+ self.assert_compile(
+ stmt,
+ (
+ "SELECT %s" +
+ (" FROM test_table" if from_ else "")
+ ) % expected
+ )
+
+ def test_bind_serialize_default(self):
+ dialect = postgresql.dialect()
+ proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
+ eq_(
+ proc({"A": [1, 2, 3, True, False]}),
+ '{"A": [1, 2, 3, true, false]}'
+ )
+
+ def test_result_deserialize_default(self):
+ dialect = postgresql.dialect()
+ proc = self.test_table.c.test_column.type._cached_result_processor(
+ dialect, None)
+ eq_(
+ proc('{"A": [1, 2, 3, true, false]}'),
+ {"A": [1, 2, 3, True, False]}
+ )
+
+ # This test is a bit misleading -- in real life you will need to cast to do anything
+ def test_where_getitem(self):
+ self._test_where(
+ self.jsoncol['bar'] == None,
+ "(test_table.test_column -> %(test_column_1)s) IS NULL"
+ )
+
+ def test_where_path(self):
+ self._test_where(
+ self.jsoncol[("foo", 1)] == None,
+ "(test_table.test_column #> %(test_column_1)s) IS NULL"
+ )
+
+ def test_where_getitem_as_text(self):
+ self._test_where(
+ self.jsoncol['bar'].astext == None,
+ "(test_table.test_column ->> %(test_column_1)s) IS NULL"
+ )
+
+ def test_where_getitem_as_cast(self):
+ self._test_where(
+ self.jsoncol['bar'].cast(Integer) == 5,
+ "CAST(test_table.test_column ->> %(test_column_1)s AS INTEGER) "
+ "= %(param_1)s"
+ )
+
+ def test_where_path_as_text(self):
+ self._test_where(
+ self.jsoncol[("foo", 1)].astext == None,
+ "(test_table.test_column #>> %(test_column_1)s) IS NULL"
+ )
+
+ def test_cols_get(self):
+ self._test_cols(
+ self.jsoncol['foo'],
+ "test_table.test_column -> %(test_column_1)s AS anon_1",
+ True
+ )
+
+
+class JSONRoundTripTest(fixtures.TablesTest):
+ __only_on__ = ('postgresql >= 9.3',)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(30), nullable=False),
+ Column('data', JSON)
+ )
+
+ def _fixture_data(self, engine):
+ data_table = self.tables.data_table
+ engine.execute(
+ data_table.insert(),
+ {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
+ {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
+ {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
+ {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
+ {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
+ )
+
+ def _assert_data(self, compare):
+ data = testing.db.execute(
+ select([self.tables.data_table.c.data]).
+ order_by(self.tables.data_table.c.name)
+ ).fetchall()
+ eq_([d for d, in data], compare)
+
+ def _test_insert(self, engine):
+ engine.execute(
+ self.tables.data_table.insert(),
+ {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
+ )
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
+
+ def _non_native_engine(self, json_serializer=None, json_deserializer=None):
+ if json_serializer is not None or json_deserializer is not None:
+ options = {
+ "json_serializer": json_serializer,
+ "json_deserializer": json_deserializer
+ }
+ else:
+ options = {}
+
+ if testing.against("postgresql+psycopg2"):
+ from psycopg2.extras import register_default_json
+ engine = engines.testing_engine(options=options)
+ @event.listens_for(engine, "connect")
+ def connect(dbapi_connection, connection_record):
+ engine.dialect._has_native_json = False
+ def pass_(value):
+ return value
+ register_default_json(dbapi_connection, loads=pass_)
+ elif options:
+ engine = engines.testing_engine(options=options)
+ else:
+ engine = testing.db
+ engine.connect()
+ return engine
+
+ def test_reflect(self):
+ insp = inspect(testing.db)
+ cols = insp.get_columns('data_table')
+ assert isinstance(cols[2]['type'], JSON)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_insert_native(self):
+ engine = testing.db
+ self._test_insert(engine)
+
+ def test_insert_python(self):
+ engine = self._non_native_engine()
+ self._test_insert(engine)
+
+
+ def _test_custom_serialize_deserialize(self, native):
+ import json
+ def loads(value):
+ value = json.loads(value)
+ value['x'] = value['x'] + '_loads'
+ return value
+
+ def dumps(value):
+ value = dict(value)
+ value['x'] = 'dumps_y'
+ return json.dumps(value)
+
+ if native:
+ engine = engines.testing_engine(options=dict(
+ json_serializer=dumps,
+ json_deserializer=loads
+ ))
+ else:
+ engine = self._non_native_engine(
+ json_serializer=dumps,
+ json_deserializer=loads
+ )
+
+ s = select([
+ cast(
+ {
+ "key": "value",
+ "x": "q"
+ },
+ JSON
+ )
+ ])
+ eq_(
+ engine.scalar(s),
+ {
+ "key": "value",
+ "x": "dumps_y_loads"
+ },
+ )
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_custom_native(self):
+ self._test_custom_serialize_deserialize(True)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_custom_python(self):
+ self._test_custom_serialize_deserialize(False)
+
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_criterion_native(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ self._test_criterion(engine)
+
+ def test_criterion_python(self):
+ engine = self._non_native_engine()
+ self._fixture_data(engine)
+ self._test_criterion(engine)
+
+ def test_path_query(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data]).where(
+ data_table.c.data[('k1',)].astext == 'r3v1'
+ )
+ ).first()
+ eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+
+ def test_query_returned_as_text(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data['k1'].astext])
+ ).first()
+ assert isinstance(result[0], util.text_type)
+
+ def test_query_returned_as_int(self):
+ engine = testing.db
+ self._fixture_data(engine)
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data['k3'].cast(Integer)]).where(
+ data_table.c.name == 'r5')
+ ).first()
+ assert isinstance(result[0], int)
+
+ def _test_criterion(self, engine):
+ data_table = self.tables.data_table
+ result = engine.execute(
+ select([data_table.c.data]).where(
+ data_table.c.data['k1'].astext == 'r3v1'
+ )
+ ).first()
+ eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+
+ def _test_fixed_round_trip(self, engine):
+ s = select([
+ cast(
+ {
+ "key": "value",
+ "key2": {"k1": "v1", "k2": "v2"}
+ },
+ JSON
+ )
+ ])
+ eq_(
+ engine.scalar(s),
+ {
+ "key": "value",
+ "key2": {"k1": "v1", "k2": "v2"}
+ },
+ )
+
+ def test_fixed_round_trip_python(self):
+ engine = self._non_native_engine()
+ self._test_fixed_round_trip(engine)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_fixed_round_trip_native(self):
+ engine = testing.db
+ self._test_fixed_round_trip(engine)
+
+ def _test_unicode_round_trip(self, engine):
+ s = select([
+ cast(
+ {
+ util.u('réveillé'): util.u('réveillé'),
+ "data": {"k1": util.u('drôle')}
+ },
+ JSON
+ )
+ ])
+ eq_(
+ engine.scalar(s),
+ {
+ util.u('réveillé'): util.u('réveillé'),
+ "data": {"k1": util.u('drôle')}
+ },
+ )
+
+
+ def test_unicode_round_trip_python(self):
+ engine = self._non_native_engine()
+ self._test_unicode_round_trip(engine)
+
+ @testing.only_on("postgresql+psycopg2")
+ def test_unicode_round_trip_native(self):
+ engine = testing.db
+ self._test_unicode_round_trip(engine)