diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-11-17 18:53:23 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-11-17 18:53:23 -0500 |
commit | 6369292dcf62561d23c084b3da5ca35c309af552 (patch) | |
tree | c06bf431a7e3912f0fdcd414658a6d90a435dbcf | |
parent | 40071dbda4c2467f10a1ef217ce1d6e64058fba3 (diff) | |
parent | 8b327807aefcb2df56902c94f249c4fe831fdfe1 (diff) | |
download | sqlalchemy-6369292dcf62561d23c084b3da5ca35c309af552.tar.gz |
Merged in audriusk/sqlalchemy_pg_hstore (pull request #26). Will adjust some aspects of it, including replacing UserDefinedType with TypeEngine and moving MutationDict to be part of sqlalchemy.ext.mutable.
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/__init__.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/hstore.py | 306 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 26 | ||||
-rw-r--r-- | test/dialect/test_postgresql.py | 192 |
4 files changed, 531 insertions, 6 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 3c273bd56..2a1a07cbd 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -9,12 +9,15 @@ from . import base, psycopg2, pg8000, pypostgresql, zxjdbc base.dialect = psycopg2.dialect from .base import \ - INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \ - CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\ + INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ + INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \ DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array +from .hstore import HSTORE, hstore, HStoreSyntaxError __all__ = ( -'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET', -'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', -'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array' + 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', + 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', + 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN', + 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE', 'hstore', + 'HStoreSyntaxError' ) diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py new file mode 100644 index 000000000..4797031fa --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -0,0 +1,306 @@ +# postgresql/hstore.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import re + +from .base import ARRAY +from ... 
import types as sqltypes +from ...sql import functions as sqlfunc +from ...sql.operators import custom_op +from ...exc import SQLAlchemyError +from ...ext.mutable import Mutable + +__all__ = ('HStoreSyntaxError', 'HSTORE', 'hstore') + +# My best guess at the parsing rules of hstore literals, since no formal +# grammar is given. This is mostly reverse engineered from PG's input parser +# behavior. +HSTORE_PAIR_RE = re.compile(r""" +( + "(?P<key> (\\ . | [^"])* )" # Quoted key +) +[ ]* => [ ]* # Pair operator, optional adjoining whitespace +( + (?P<value_null> NULL ) # NULL value + | "(?P<value> (\\ . | [^"])* )" # Quoted value +) +""", re.VERBOSE) + +HSTORE_DELIMITER_RE = re.compile(r""" +[ ]* , [ ]* +""", re.VERBOSE) + + +class HStoreSyntaxError(SQLAlchemyError): + """Indicates an error unmarshalling an hstore value.""" + + def __init__(self, hstore_str, pos): + self.hstore_str = hstore_str + self.pos = pos + + ctx = 20 + hslen = len(hstore_str) + + parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)] + residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)] + + if len(parsed_tail) > ctx: + parsed_tail = '[...]' + parsed_tail[1:] + if len(residual) > ctx: + residual = residual[:-1] + '[...]' + + super(HStoreSyntaxError, self).__init__( + "After %r, could not parse residual at position %d: %r" % + (parsed_tail, pos, residual) + ) + + +def _parse_hstore(hstore_str): + """Parse an hstore from it's literal string representation. + + Attempts to approximate PG's hstore input parsing rules as closely as + possible. Although currently this is not strictly necessary, since the + current implementation of hstore's output syntax is stricter than what it + accepts as input, the documentation makes no guarantees that will always + be the case. + + Throws HStoreSyntaxError if parsing fails. 
+ + """ + result = {} + pos = 0 + pair_match = HSTORE_PAIR_RE.match(hstore_str) + + while pair_match is not None: + key = pair_match.group('key') + if pair_match.group('value_null'): + value = None + else: + value = pair_match.group('value').replace(r'\"', '"') + result[key] = value + + pos += pair_match.end() + + delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:]) + if delim_match is not None: + pos += delim_match.end() + + pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:]) + + if pos != len(hstore_str): + raise HStoreSyntaxError(hstore_str, pos) + + return result + + +def _serialize_hstore(val): + """Serialize a dictionary into an hstore literal. Keys and values must + both be strings (except None for values). + + """ + def esc(s, position): + if position == 'value' and s is None: + return 'NULL' + elif isinstance(s, basestring): + return '"%s"' % s.replace('"', r'\"') + else: + raise ValueError("%r in %s position is not a string." % + (s, position)) + + return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value')) + for k, v in val.iteritems()) + + +class MutationDict(Mutable, dict): + def __setitem__(self, key, value): + """Detect dictionary set events and emit change events.""" + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key, value): + """Detect dictionary del events and emit change events.""" + dict.__delitem__(self, key, value) + self.changed() + + @classmethod + def coerce(cls, key, value): + """Convert plain dictionary to MutationDict.""" + if not isinstance(value, MutationDict): + if isinstance(value, dict): + return MutationDict(value) + return Mutable.coerce(key, value) + else: + return value + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) + + +class HSTORE(sqltypes.Concatenable, sqltypes.UserDefinedType): + """The column type for representing PostgreSQL's contrib/hstore type. This + type is a miniature key-value store in a column. 
It supports query + operators for all the usual operations on a map-like data structure. + + """ + class comparator_factory(sqltypes.UserDefinedType.Comparator): + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.expr.op('?')(other) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in the PG + array. + """ + return self.expr.op('?&')(other) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in the PG + array. + """ + return self.expr.op('?|')(other) + + def defined(self, key): + """Boolean expression. Test for presence of a non-NULL value for + the key. Note that the key may be a SQLA expression. + """ + return _HStoreDefinedFunction(self.expr, key) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys are a superset of the keys of + the argument hstore expression. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument hstore expression. + """ + return self.expr.op('<@')(other) + + def __getitem__(self, other): + """Text expression. Get the value at a given key. Note that the + key may be a SQLA expression. + """ + return self.expr.op('->', precedence=5)(other) + + def __add__(self, other): + """HStore expression. Merge the left and right hstore expressions, + with duplicate keys taking the value from the right expression. + """ + return self.expr.concat(other) + + def delete(self, key): + """HStore expression. Returns the contents of this hstore with the + given key deleted. Note that the key may be a SQLA expression. + """ + if isinstance(key, dict): + key = _serialize_hstore(key) + return _HStoreDeleteFunction(self.expr, key) + + def slice(self, array): + """HStore expression. Returns a subset of an hstore defined by + array of keys. 
+ """ + return _HStoreSliceFunction(self.expr, array) + + def keys(self): + """Text array expression. Returns array of keys.""" + return _HStoreKeysFunction(self.expr) + + def vals(self): + """Text array expression. Returns array of values.""" + return _HStoreValsFunction(self.expr) + + def array(self): + """Text array expression. Returns array of alternating keys and + values. + """ + return _HStoreArrayFunction(self.expr) + + def matrix(self): + """Text array expression. Returns array of [key, value] pairs.""" + return _HStoreMatrixFunction(self.expr) + + def _adapt_expression(self, op, other_comparator): + if isinstance(op, custom_op): + if op.opstring in ['?', '?&', '?|', '@>', '<@']: + return op, sqltypes.Boolean + elif op.opstring == '->': + return op, sqltypes.Text + return op, other_comparator.type + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, dict): + return _serialize_hstore(value) + else: + return value + return process + + def get_col_spec(self): + return 'HSTORE' + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return _parse_hstore(value) + else: + return value + return process + +MutationDict.associate_with(HSTORE) + + +class hstore(sqlfunc.GenericFunction): + """Construct an hstore on the server side using the hstore function. + + The single argument or a pair of arguments are evaluated as SQLAlchemy + expressions, so both may contain columns, function calls, or any other + valid SQL expressions which evaluate to text or array. 
+ + """ + type = HSTORE + name = 'hstore' + + +class _HStoreDefinedFunction(sqlfunc.GenericFunction): + type = sqltypes.Boolean + name = 'defined' + + +class _HStoreDeleteFunction(sqlfunc.GenericFunction): + type = HSTORE + name = 'delete' + + +class _HStoreSliceFunction(sqlfunc.GenericFunction): + type = HSTORE + name = 'slice' + + +class _HStoreKeysFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'akeys' + + +class _HStoreValsFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'avals' + + +class _HStoreArrayFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'hstore_to_array' + + +class _HStoreMatrixFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'hstore_to_matrix' diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 700f76793..05286ce20 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -147,6 +147,7 @@ from .base import PGDialect, PGCompiler, \ PGIdentifierPreparer, PGExecutionContext, \ ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ _INT_TYPES +from .hstore import HSTORE logger = logging.getLogger('sqlalchemy.dialects.postgresql') @@ -195,6 +196,13 @@ class _PGArray(ARRAY): self.item_type.convert_unicode = "force" # end Py2K +class _PGHStore(HSTORE): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + return None + # When we're handed literal SQL, ensure it's a SELECT-query. Since # 8.3, combining cursors and "FOR UPDATE" has been fine. 
SERVER_SIDE_CURSOR_RE = re.compile( @@ -282,6 +290,7 @@ class PGDialect_psycopg2(PGDialect): ENUM : _PGEnum, # needs force_unicode sqltypes.Enum : _PGEnum, # needs force_unicode ARRAY : _PGArray, # needs force_unicode + HSTORE : _PGHStore, } ) @@ -300,6 +309,16 @@ class PGDialect_psycopg2(PGDialect): int(x) for x in m.group(1, 2, 3) if x is not None) + self._hstore_oids = None + + def initialize(self, connection): + super(PGDialect_psycopg2, self).initialize(connection) + + if self.psycopg2_version >= (2, 4): + extras = __import__('psycopg2.extras').extras + oids = extras.HstoreAdapter.get_oids(connection.connection) + if oids is not None and oids[0]: + self._hstore_oids = oids[0], oids[1] @classmethod def dbapi(cls): @@ -346,6 +365,13 @@ class PGDialect_psycopg2(PGDialect): extensions.register_type(extensions.UNICODE, conn) fns.append(on_connect) + extras = __import__('psycopg2.extras').extras + def on_connect(conn): + if self._hstore_oids is not None: + oid, array_oid = self._hstore_oids + extras.register_hstore(conn, oid=oid, array_oid=array_oid) + fns.append(on_connect) + if fns: def on_connect(conn): for fn in fns: diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 3be005f36..33753b48f 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -13,14 +13,16 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \ PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \ func, literal_column, literal, bindparam, cast, extract, \ SmallInteger, Enum, REAL, update, insert, Index, delete, \ - and_, Date, TypeDecorator, Time, Unicode, Interval, or_ + and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text from sqlalchemy.orm import Session, mapper, aliased from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql +from sqlalchemy.dialects.postgresql import HSTORE, hstore from sqlalchemy.util.compat import decimal from 
sqlalchemy.testing.util import round_decimal from sqlalchemy.sql import table, column import logging +import re class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): @@ -2707,3 +2709,191 @@ class TupleTest(fixtures.TestBase): ).scalar(), exp ) + + +class HStoreTest(fixtures.TestBase): + def _assert_sql(self, construct, expected): + dialect = postgresql.dialect() + compiled = str(construct.compile(dialect=dialect)) + compiled = re.sub(r'\s+', ' ', compiled) + expected = re.sub(r'\s+', ' ', expected) + eq_(compiled, expected) + + def setup(self): + metadata = MetaData() + self.test_table = Table('test_table', metadata, + Column('id', Integer, primary_key=True), + Column('hash', HSTORE) + ) + self.hashcol = self.test_table.c.hash + + def _test_where(self, whereclause, expected): + stmt = select([self.test_table]).where(whereclause) + self._assert_sql( + stmt, + "SELECT test_table.id, test_table.hash FROM test_table " + "WHERE %s" % expected + ) + + def _test_cols(self, colclause, expected, from_=True): + stmt = select([colclause]) + self._assert_sql( + stmt, + ( + "SELECT %s" + + (" FROM test_table" if from_ else "") + ) % expected + ) + + def test_where_has_key(self): + self._test_where( + self.hashcol.has_key('foo'), + "test_table.hash ? 
%(hash_1)s" + ) + + def test_where_has_all(self): + self._test_where( + self.hashcol.has_all(postgresql.array(['1', '2'])), + "test_table.hash ?& ARRAY[%(param_1)s, %(param_2)s]" + ) + + def test_where_has_any(self): + self._test_where( + self.hashcol.has_any(postgresql.array(['1', '2'])), + "test_table.hash ?| ARRAY[%(param_1)s, %(param_2)s]" + ) + + def test_where_defined(self): + self._test_where( + self.hashcol.defined('foo'), + "defined(test_table.hash, %(param_1)s)" + ) + + def test_where_contains(self): + self._test_where( + self.hashcol.contains({'foo': '1'}), + "test_table.hash @> %(hash_1)s" + ) + + def test_where_contained_by(self): + self._test_where( + self.hashcol.contained_by({'foo': '1', 'bar': None}), + "test_table.hash <@ %(hash_1)s" + ) + + def test_where_getitem(self): + self._test_where( + self.hashcol['bar'] == None, + "(test_table.hash -> %(hash_1)s) IS NULL" + ) + + def test_cols_get(self): + self._test_cols( + self.hashcol['foo'], + "test_table.hash -> %(hash_1)s AS anon_1", + True + ) + + def test_cols_delete_single_key(self): + self._test_cols( + self.hashcol.delete('foo'), + "delete(test_table.hash, %(param_1)s) AS delete_1", + True + ) + + def test_cols_delete_array_of_keys(self): + self._test_cols( + self.hashcol.delete(postgresql.array(['foo', 'bar'])), + ("delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " + "AS delete_1"), + True + ) + + def test_cols_delete_matching_pairs(self): + self._test_cols( + self.hashcol.delete(hstore('1', '2')), + ("delete(test_table.hash, hstore(%(param_1)s, %(param_2)s)) " + "AS delete_1"), + True + ) + + def test_cols_slice(self): + self._test_cols( + self.hashcol.slice(postgresql.array(['1', '2'])), + ("slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " + "AS slice_1"), + True + ) + + def test_cols_hstore_pair_text(self): + self._test_cols( + hstore('foo', '3')['foo'], + "hstore(%(param_1)s, %(param_2)s) -> %(hstore_1)s AS anon_1", + False + ) + + def test_cols_hstore_pair_array(self): + 
self._test_cols( + hstore(postgresql.array(['1', '2']), + postgresql.array(['3', None]))['1'], + ("hstore(ARRAY[%(param_1)s, %(param_2)s], " + "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"), + False + ) + + def test_cols_hstore_single_array(self): + self._test_cols( + hstore(postgresql.array(['1', '2', '3', None]))['3'], + ("hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) " + "-> %(hstore_1)s AS anon_1"), + False + ) + + def test_cols_concat(self): + self._test_cols( + self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), '3')), + ("test_table.hash || hstore(CAST(test_table.id AS TEXT), " + "%(param_1)s) AS anon_1"), + True + ) + + def test_cols_concat_op(self): + self._test_cols( + self.hashcol + self.hashcol, + "test_table.hash || test_table.hash AS anon_1", + True + ) + + def test_cols_concat_get(self): + self._test_cols( + (self.hashcol + self.hashcol)['foo'], + "test_table.hash || test_table.hash -> %(param_1)s AS anon_1" + ) + + def test_cols_keys(self): + self._test_cols( + self.hashcol.keys(), + "akeys(test_table.hash) AS akeys_1", + True + ) + + def test_cols_vals(self): + self._test_cols( + self.hashcol.vals(), + "avals(test_table.hash) AS avals_1", + True + ) + + def test_cols_array(self): + self._test_cols( + self.hashcol.array(), + "hstore_to_array(test_table.hash) AS hstore_to_array_1", + True + ) + + def test_cols_matrix(self): + self._test_cols( + self.hashcol.matrix(), + "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1", + True + ) |