diff options
Diffstat (limited to 'test/dialect/postgresql')
| -rw-r--r-- | test/dialect/postgresql/test_compiler.py | 123 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_dialect.py | 39 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_reflection.py | 76 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 400 |
4 files changed, 575 insertions, 63 deletions
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 11661b11f..e64afb186 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.orm import mapper, aliased, Session from sqlalchemy.sql import table, column, operators +from sqlalchemy.util import u class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): @@ -106,6 +107,45 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'AS length_1', dialect=dialect) + def test_create_drop_enum(self): + # test escaping and unicode within CREATE TYPE for ENUM + typ = postgresql.ENUM( + "val1", "val2", "val's 3", u('méil'), name="myname") + self.assert_compile(postgresql.CreateEnumType(typ), + u("CREATE TYPE myname AS ENUM ('val1', 'val2', 'val''s 3', 'méil')") + ) + + typ = postgresql.ENUM( + "val1", "val2", "val's 3", name="PleaseQuoteMe") + self.assert_compile(postgresql.CreateEnumType(typ), + "CREATE TYPE \"PleaseQuoteMe\" AS ENUM " + "('val1', 'val2', 'val''s 3')" + ) + + def test_generic_enum(self): + e1 = Enum('x', 'y', 'z', name='somename') + e2 = Enum('x', 'y', 'z', name='somename', schema='someschema') + self.assert_compile(postgresql.CreateEnumType(e1), + "CREATE TYPE somename AS ENUM ('x', 'y', 'z')" + ) + self.assert_compile(postgresql.CreateEnumType(e2), + "CREATE TYPE someschema.somename AS ENUM " + "('x', 'y', 'z')") + self.assert_compile(postgresql.DropEnumType(e1), + 'DROP TYPE somename') + self.assert_compile(postgresql.DropEnumType(e2), + 'DROP TYPE someschema.somename') + t1 = Table('sometable', MetaData(), Column('somecolumn', e1)) + self.assert_compile(schema.CreateTable(t1), + 'CREATE TABLE sometable (somecolumn ' + 'somename)') + t1 = Table('sometable', MetaData(), Column('somecolumn', + Enum('x', 'y', 'z', native_enum=False))) + self.assert_compile(schema.CreateTable(t1), + "CREATE TABLE sometable (somecolumn " + "VARCHAR(1), CHECK (somecolumn IN ('x', " + "'y', 'z')))") + def test_create_partial_index(self): m = MetaData() tbl = Table('testtbl', m, Column('data', Integer)) @@ -173,6 +213,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'USING hash (data)', dialect=postgresql.dialect()) + + def test_create_index_expr_gets_parens(self): + m = MetaData() + tbl = Table('testtbl', m, Column('x', Integer), Column('y', Integer)) + + idx1 = Index('test_idx1', 5 / (tbl.c.x + tbl.c.y)) + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX test_idx1 ON testtbl ((5 / (x + y)))" + ) + + def test_create_index_literals(self): + m = MetaData() + tbl = Table('testtbl', m, Column('data', Integer)) + + idx1 = Index('test_idx1', tbl.c.data + 5) + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX test_idx1 ON testtbl ((data + 5))" + ) + def test_exclude_constraint_min(self): m = MetaData() tbl = Table('testtbl', m, @@ -228,6 +289,68 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'SUBSTRING(%(substring_1)s FROM %(substring_2)s)') + def test_for_update(self): + table1 = table('mytable', + column('myid'), column('name'), column('description')) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(nowait=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(read=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(of=table1.c.myid), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR UPDATE OF mytable") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True, of=table1), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR SHARE OF mytable NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True, of=table1.c.myid), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR SHARE OF mytable NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True, + of=[table1.c.myid, table1.c.name]), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR SHARE OF mytable NOWAIT") + + ta = table1.alias() + self.assert_compile( + ta.select(ta.c.myid == 7). + with_for_update(of=[ta.c.myid, ta.c.name]), + "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " + "FROM mytable AS mytable_1 " + "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1" + ) def test_reserved_words(self): diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 1fc239cb7..fd6df2c98 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import base as postgresql import logging import logging.handlers from sqlalchemy.testing.mock import Mock +from sqlalchemy.engine.reflection import Inspector class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): @@ -53,7 +54,11 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): 'compiled by GCC gcc (GCC) 4.4.2, 64-bit', (8, 5)), ('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, ' 'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), ' - '64-bit', (9, 1, 2))]: + '64-bit', (9, 1, 2)), + ('[PostgreSQL 9.2.4 ] VMware vFabric Postgres 9.2.4.0 ' + 'release build 1080137', (9, 2, 4)) + + ]: eq_(testing.db.dialect._get_server_version_info(mock_conn(string)), version) @@ -63,8 +68,10 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): assert testing.db.dialect.dbapi.__version__.\ startswith(".".join(str(x) for x in v)) + # currently not passing with pg 9.3 that does not seem to generate + # any notices here, woudl rather find a way to mock this @testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature') - def test_notice_logging(self): + def _test_notice_logging(self): log = logging.getLogger('sqlalchemy.dialects.postgresql') buf = logging.handlers.BufferingHandler(100) lev = log.level @@ -199,18 +206,32 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): assert_raises(exc.InvalidRequestError, testing.db.execute, stmt) def test_serial_integer(self): - for type_, expected in [ - (Integer, 'SERIAL'), - (BigInteger, 'BIGSERIAL'), - (SmallInteger, 'SMALLINT'), - (postgresql.INTEGER, 'SERIAL'), - (postgresql.BIGINT, 'BIGSERIAL'), + + for version, type_, expected in [ + (None, Integer, 'SERIAL'), + (None, BigInteger, 'BIGSERIAL'), + ((9, 1), SmallInteger, 'SMALLINT'), + ((9, 2), SmallInteger, 'SMALLSERIAL'), + (None, postgresql.INTEGER, 'SERIAL'), + (None, postgresql.BIGINT, 'BIGSERIAL'), ]: m = MetaData() t = Table('t', m, Column('c', type_, primary_key=True)) - ddl_compiler = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(t)) + + if version: + dialect = postgresql.dialect() + dialect._get_server_version_info = Mock(return_value=version) + dialect.initialize(testing.db.connect()) + else: + dialect = testing.db.dialect + + ddl_compiler = dialect.ddl_compiler( + dialect, + schema.CreateTable(t) + ) eq_( ddl_compiler.get_column_specification(t.c.c), "c %s NOT NULL" % expected ) + diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index fb399b546..58f34d5d0 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -5,6 +5,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \ AssertsCompiledSQL, ComparesTables from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing +from sqlalchemy import inspect from sqlalchemy import Table, Column, select, MetaData, text, Integer, \ String, Sequence, ForeignKey, join, Numeric, \ PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \ @@ -159,6 +160,17 @@ class ReflectionTest(fixtures.TestBase): subject.join(referer).onclause)) @testing.provide_metadata + def test_reflect_default_over_128_chars(self): + Table('t', self.metadata, + Column('x', String(200), server_default="abcd" * 40) + ).create(testing.db) + + m = MetaData() + t = Table('t', m, autoload=True, autoload_with=testing.db) + eq_( + t.c.x.server_default.arg.text, "'%s'::character varying" % ("abcd" * 40) + ) + @testing.provide_metadata def test_renamed_sequence_reflection(self): metadata = self.metadata t = Table('t', metadata, Column('id', Integer, primary_key=True)) @@ -416,6 +428,70 @@ class ReflectionTest(fixtures.TestBase): eq_(ind, [{'unique': False, 'column_names': ['y'], 'name': 'idx1'}]) conn.close() + @testing.provide_metadata + def test_foreign_key_option_inspection(self): + metadata = self.metadata + Table('person', metadata, + Column('id', String(length=32), nullable=False, primary_key=True), + Column('company_id', ForeignKey('company.id', + name='person_company_id_fkey', + match='FULL', onupdate='RESTRICT', ondelete='RESTRICT', + deferrable=True, initially='DEFERRED' + ) + ) + ) + Table('company', metadata, + Column('id', String(length=32), nullable=False, primary_key=True), + Column('name', String(length=255)), + Column('industry_id', ForeignKey('industry.id', + name='company_industry_id_fkey', + onupdate='CASCADE', ondelete='CASCADE', + deferrable=False, # PG default + initially='IMMEDIATE' # PG default + ) + ) + ) + Table('industry', metadata, + Column('id', Integer(), nullable=False, primary_key=True), + Column('name', String(length=255)) + ) + fk_ref = { + 'person_company_id_fkey': { + 'name': 'person_company_id_fkey', + 'constrained_columns': ['company_id'], + 'referred_columns': ['id'], + 'referred_table': 'company', + 'referred_schema': None, + 'options': { + 'onupdate': 'RESTRICT', + 'deferrable': True, + 'ondelete': 'RESTRICT', + 'initially': 'DEFERRED', + 'match': 'FULL' + } + }, + 'company_industry_id_fkey': { + 'name': 'company_industry_id_fkey', + 'constrained_columns': ['industry_id'], + 'referred_columns': ['id'], + 'referred_table': 'industry', + 'referred_schema': None, + 'options': { + 'onupdate': 'CASCADE', + 'deferrable': None, + 'ondelete': 'CASCADE', + 'initially': None, + 'match': None + } + } + } + metadata.create_all() + inspector = inspect(testing.db) + fks = inspector.get_foreign_keys('person') + \ + inspector.get_foreign_keys('company') + for fk in fks: + eq_(fk, fk_ref[fk['name']]) + class CustomTypeReflectionTest(fixtures.TestBase): class CustomType(object): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 784f8bcbf..ba4b63e1a 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -10,18 +10,22 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \ PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \ func, literal_column, literal, bindparam, cast, extract, \ SmallInteger, Enum, REAL, update, insert, Index, delete, \ - and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text + and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text, \ + type_coerce from sqlalchemy.orm import Session, mapper, aliased from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \ - INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE + INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \ + JSON import decimal from sqlalchemy import util from sqlalchemy.testing.util import round_decimal from sqlalchemy.sql import table, column, operators import logging import re +from sqlalchemy import inspect +from sqlalchemy import event class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): __only_on__ = 'postgresql' @@ -96,34 +100,10 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): ([5], [5], [6], [decimal.Decimal("6.4")]) ) -class EnumTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): +class EnumTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = 'postgresql' - __dialect__ = postgresql.dialect() - def test_compile(self): - e1 = Enum('x', 'y', 'z', name='somename') - e2 = Enum('x', 'y', 'z', name='somename', schema='someschema') - self.assert_compile(postgresql.CreateEnumType(e1), - "CREATE TYPE somename AS ENUM ('x','y','z')" - ) - self.assert_compile(postgresql.CreateEnumType(e2), - "CREATE TYPE someschema.somename AS ENUM " - "('x','y','z')") - self.assert_compile(postgresql.DropEnumType(e1), - 'DROP TYPE somename') - self.assert_compile(postgresql.DropEnumType(e2), - 'DROP TYPE someschema.somename') - t1 = Table('sometable', MetaData(), Column('somecolumn', e1)) - self.assert_compile(schema.CreateTable(t1), - 'CREATE TABLE sometable (somecolumn ' - 'somename)') - t1 = Table('sometable', MetaData(), Column('somecolumn', - Enum('x', 'y', 'z', native_enum=False))) - self.assert_compile(schema.CreateTable(t1), - "CREATE TABLE sometable (somecolumn " - "VARCHAR(1), CHECK (somecolumn IN ('x', " - "'y', 'z')))") @testing.fails_on('postgresql+zxjdbc', 'zxjdbc fails on ENUM: column "XXX" is of type ' @@ -860,7 +840,8 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column('plain_interval', postgresql.INTERVAL), Column('year_interval', y2m()), Column('month_interval', d2s()), - Column('precision_interval', postgresql.INTERVAL(precision=3)) + Column('precision_interval', postgresql.INTERVAL(precision=3)), + Column('tsvector_document', postgresql.TSVECTOR) ) metadata.create_all() @@ -893,6 +874,17 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): self.assert_compile(type_, expected) @testing.provide_metadata + def test_tsvector_round_trip(self): + t = Table('t1', self.metadata, Column('data', postgresql.TSVECTOR)) + t.create() + testing.db.execute(t.insert(), data="a fat cat sat") + eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'") + + testing.db.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'") + + eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'mat' 'sat'") + + @testing.provide_metadata def test_bit_reflection(self): metadata = self.metadata t1 = Table('t1', metadata, @@ -918,7 +910,6 @@ class UUIDTest(fixtures.TestBase): __only_on__ = 'postgresql' - @testing.requires.python25 @testing.fails_on('postgresql+zxjdbc', 'column "data" is of type uuid but expression is of type character varying') @testing.fails_on('postgresql+pg8000', 'No support for UUID type') @@ -932,7 +923,6 @@ class UUIDTest(fixtures.TestBase): str(uuid.uuid4()) ) - @testing.requires.python25 @testing.fails_on('postgresql+zxjdbc', 'column "data" is of type uuid but expression is of type character varying') @testing.fails_on('postgresql+pg8000', 'No support for UUID type') @@ -978,13 +968,8 @@ class UUIDTest(fixtures.TestBase): -class HStoreTest(fixtures.TestBase): - def _assert_sql(self, construct, expected): - dialect = postgresql.dialect() - compiled = str(construct.compile(dialect=dialect)) - compiled = re.sub(r'\s+', ' ', compiled) - expected = re.sub(r'\s+', ' ', expected) - eq_(compiled, expected) +class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): + __dialect__ = 'postgresql' def setup(self): metadata = MetaData() @@ -996,7 +981,7 @@ class HStoreTest(fixtures.TestBase): def _test_where(self, whereclause, expected): stmt = select([self.test_table]).where(whereclause) - self._assert_sql( + self.assert_compile( stmt, "SELECT test_table.id, test_table.hash FROM test_table " "WHERE %s" % expected @@ -1004,7 +989,7 @@ class HStoreTest(fixtures.TestBase): def _test_cols(self, colclause, expected, from_=True): stmt = select([colclause]) - self._assert_sql( + self.assert_compile( stmt, ( "SELECT %s" + @@ -1013,9 +998,8 @@ class HStoreTest(fixtures.TestBase): ) def test_bind_serialize_default(self): - from sqlalchemy.engine import default - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_bind_processor(dialect) eq_( proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])), @@ -1023,9 +1007,7 @@ class HStoreTest(fixtures.TestBase): ) def test_bind_serialize_with_slashes_and_quotes(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_bind_processor(dialect) eq_( proc({'\\"a': '\\"1'}), @@ -1033,9 +1015,7 @@ class HStoreTest(fixtures.TestBase): ) def test_parse_error(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( dialect, None) assert_raises_message( @@ -1048,9 +1028,7 @@ class HStoreTest(fixtures.TestBase): ) def test_result_deserialize_default(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( dialect, None) eq_( @@ -1059,9 +1037,7 @@ class HStoreTest(fixtures.TestBase): ) def test_result_deserialize_with_slashes_and_quotes(self): - from sqlalchemy.engine import default - - dialect = default.DefaultDialect() + dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( dialect, None) eq_( @@ -1305,7 +1281,6 @@ class HStoreRoundTripTest(fixtures.TablesTest): return engine def test_reflect(self): - from sqlalchemy import inspect insp = inspect(testing.db) cols = insp.get_columns('data_table') assert isinstance(cols[2]['type'], HSTORE) @@ -1677,3 +1652,320 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest): def _data_obj(self): return self.extras.DateTimeTZRange(*self.tstzs()) + + +class JSONTest(AssertsCompiledSQL, fixtures.TestBase): + __dialect__ = 'postgresql' + + def setup(self): + metadata = MetaData() + self.test_table = Table('test_table', metadata, + Column('id', Integer, primary_key=True), + Column('test_column', JSON) + ) + self.jsoncol = self.test_table.c.test_column + + def _test_where(self, whereclause, expected): + stmt = select([self.test_table]).where(whereclause) + self.assert_compile( + stmt, + "SELECT test_table.id, test_table.test_column FROM test_table " + "WHERE %s" % expected + ) + + def _test_cols(self, colclause, expected, from_=True): + stmt = select([colclause]) + self.assert_compile( + stmt, + ( + "SELECT %s" + + (" FROM test_table" if from_ else "") + ) % expected + ) + + def test_bind_serialize_default(self): + dialect = postgresql.dialect() + proc = self.test_table.c.test_column.type._cached_bind_processor(dialect) + eq_( + proc({"A": [1, 2, 3, True, False]}), + '{"A": [1, 2, 3, true, false]}' + ) + + def test_result_deserialize_default(self): + dialect = postgresql.dialect() + proc = self.test_table.c.test_column.type._cached_result_processor( + dialect, None) + eq_( + proc('{"A": [1, 2, 3, true, false]}'), + {"A": [1, 2, 3, True, False]} + ) + + # This test is a bit misleading -- in real life you will need to cast to do anything + def test_where_getitem(self): + self._test_where( + self.jsoncol['bar'] == None, + "(test_table.test_column -> %(test_column_1)s) IS NULL" + ) + + def test_where_path(self): + self._test_where( + self.jsoncol[("foo", 1)] == None, + "(test_table.test_column #> %(test_column_1)s) IS NULL" + ) + + def test_where_getitem_as_text(self): + self._test_where( + self.jsoncol['bar'].astext == None, + "(test_table.test_column ->> %(test_column_1)s) IS NULL" + ) + + def test_where_getitem_as_cast(self): + self._test_where( + self.jsoncol['bar'].cast(Integer) == 5, + "CAST(test_table.test_column ->> %(test_column_1)s AS INTEGER) " + "= %(param_1)s" + ) + + def test_where_path_as_text(self): + self._test_where( + self.jsoncol[("foo", 1)].astext == None, + "(test_table.test_column #>> %(test_column_1)s) IS NULL" + ) + + def test_cols_get(self): + self._test_cols( + self.jsoncol['foo'], + "test_table.test_column -> %(test_column_1)s AS anon_1", + True + ) + + +class JSONRoundTripTest(fixtures.TablesTest): + __only_on__ = ('postgresql >= 9.3',) + + @classmethod + def define_tables(cls, metadata): + Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30), nullable=False), + Column('data', JSON) + ) + + def _fixture_data(self, engine): + data_table = self.tables.data_table + engine.execute( + data_table.insert(), + {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}, + {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}}, + {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}}, + {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}}, + {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, + ) + + def _assert_data(self, compare): + data = testing.db.execute( + select([self.tables.data_table.c.data]). + order_by(self.tables.data_table.c.name) + ).fetchall() + eq_([d for d, in data], compare) + + def _test_insert(self, engine): + engine.execute( + self.tables.data_table.insert(), + {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}} + ) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) + + def _non_native_engine(self, json_serializer=None, json_deserializer=None): + if json_serializer is not None or json_deserializer is not None: + options = { + "json_serializer": json_serializer, + "json_deserializer": json_deserializer + } + else: + options = {} + + if testing.against("postgresql+psycopg2"): + from psycopg2.extras import register_default_json + engine = engines.testing_engine(options=options) + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + engine.dialect._has_native_json = False + def pass_(value): + return value + register_default_json(dbapi_connection, loads=pass_) + elif options: + engine = engines.testing_engine(options=options) + else: + engine = testing.db + engine.connect() + return engine + + def test_reflect(self): + insp = inspect(testing.db) + cols = insp.get_columns('data_table') + assert isinstance(cols[2]['type'], JSON) + + @testing.only_on("postgresql+psycopg2") + def test_insert_native(self): + engine = testing.db + self._test_insert(engine) + + def test_insert_python(self): + engine = self._non_native_engine() + self._test_insert(engine) + + + def _test_custom_serialize_deserialize(self, native): + import json + def loads(value): + value = json.loads(value) + value['x'] = value['x'] + '_loads' + return value + + def dumps(value): + value = dict(value) + value['x'] = 'dumps_y' + return json.dumps(value) + + if native: + engine = engines.testing_engine(options=dict( + json_serializer=dumps, + json_deserializer=loads + )) + else: + engine = self._non_native_engine( + json_serializer=dumps, + json_deserializer=loads + ) + + s = select([ + cast( + { + "key": "value", + "x": "q" + }, + JSON + ) + ]) + eq_( + engine.scalar(s), + { + "key": "value", + "x": "dumps_y_loads" + }, + ) + + @testing.only_on("postgresql+psycopg2") + def test_custom_native(self): + self._test_custom_serialize_deserialize(True) + + @testing.only_on("postgresql+psycopg2") + def test_custom_python(self): + self._test_custom_serialize_deserialize(False) + + + @testing.only_on("postgresql+psycopg2") + def test_criterion_native(self): + engine = testing.db + self._fixture_data(engine) + self._test_criterion(engine) + + def test_criterion_python(self): + engine = self._non_native_engine() + self._fixture_data(engine) + self._test_criterion(engine) + + def test_path_query(self): + engine = testing.db + self._fixture_data(engine) + data_table = self.tables.data_table + result = engine.execute( + select([data_table.c.data]).where( + data_table.c.data[('k1',)].astext == 'r3v1' + ) + ).first() + eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) + + def test_query_returned_as_text(self): + engine = testing.db + self._fixture_data(engine) + data_table = self.tables.data_table + result = engine.execute( + select([data_table.c.data['k1'].astext]) + ).first() + assert isinstance(result[0], util.text_type) + + def test_query_returned_as_int(self): + engine = testing.db + self._fixture_data(engine) + data_table = self.tables.data_table + result = engine.execute( + select([data_table.c.data['k3'].cast(Integer)]).where( + data_table.c.name == 'r5') + ).first() + assert isinstance(result[0], int) + + def _test_criterion(self, engine): + data_table = self.tables.data_table + result = engine.execute( + select([data_table.c.data]).where( + data_table.c.data['k1'].astext == 'r3v1' + ) + ).first() + eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) + + def _test_fixed_round_trip(self, engine): + s = select([ + cast( + { + "key": "value", + "key2": {"k1": "v1", "k2": "v2"} + }, + JSON + ) + ]) + eq_( + engine.scalar(s), + { + "key": "value", + "key2": {"k1": "v1", "k2": "v2"} + }, + ) + + def test_fixed_round_trip_python(self): + engine = self._non_native_engine() + self._test_fixed_round_trip(engine) + + @testing.only_on("postgresql+psycopg2") + def test_fixed_round_trip_native(self): + engine = testing.db + self._test_fixed_round_trip(engine) + + def _test_unicode_round_trip(self, engine): + s = select([ + cast( + { + util.u('réveillé'): util.u('réveillé'), + "data": {"k1": util.u('drôle')} + }, + JSON + ) + ]) + eq_( + engine.scalar(s), + { + util.u('réveillé'): util.u('réveillé'), + "data": {"k1": util.u('drôle')} + }, + ) + + + def test_unicode_round_trip_python(self): + engine = self._non_native_engine() + self._test_unicode_round_trip(engine) + + @testing.only_on("postgresql+psycopg2") + def test_unicode_round_trip_native(self): + engine = testing.db + self._test_unicode_round_trip(engine) |
