diff options
author | Christoph Zwerschke <cito@online.de> | 2016-04-11 23:15:25 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-04-15 12:00:27 -0400 |
commit | b59cbb5fd7d2ab56334774c6354fffa0b50448bc (patch) | |
tree | f08bb82751830440739a2854a783fa1763bb9a63 /lib/sqlalchemy/dialects/postgresql/pygresql.py | |
parent | bde46e33593805584c7c0dedb3a666909fb67888 (diff) | |
download | sqlalchemy-b59cbb5fd7d2ab56334774c6354fffa0b50448bc.tar.gz |
- Add support for PostgreSQL with PyGreSQL
Change-Id: I040b75ff3b4110e7e8b26442a4eb226ba8c26715
Pull-request: https://github.com/zzzeek/sqlalchemy/pull/234
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/pygresql.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pygresql.py | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py new file mode 100644 index 000000000..d30206613 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -0,0 +1,243 @@ +# postgresql/pygresql.py +# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: postgresql+pygresql + :name: pygresql + :dbapi: pgdb + :connectstring: postgresql+pygresql://user:password@host:port/dbname\ +[?key=value&key=value...] + :url: http://www.pygresql.org/ +""" + +import decimal +import re + +from ... import exc, processors, util +from ...types import Numeric, JSON as Json +from ...sql.elements import Null +from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \ + _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID +from .hstore import HSTORE +from .json import JSON, JSONB + + +class _PGNumeric(Numeric): + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not isinstance(coltype, int): + coltype = coltype.oid + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, + self._effective_decimal_return_scale) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # PyGreSQL returns Decimal natively for 1700 (numeric) + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + else: + if coltype in _FLOAT_TYPES: + # PyGreSQL returns float natively for 701 (float8) + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + + +class _PGHStore(HSTORE): + + def bind_processor(self, dialect): + if not dialect.has_native_hstore: + return super(_PGHStore, self).bind_processor(dialect) + hstore = dialect.dbapi.Hstore + def process(value): + if isinstance(value, dict): + return hstore(value) + return value + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_hstore: + return super(_PGHStore, self).result_processor(dialect, coltype) + + +class _PGJSON(JSON): + + def bind_processor(self, dialect): + if not dialect.has_native_json: + return super(_PGJSON, self).bind_processor(dialect) + json = dialect.dbapi.Json + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, Null) or ( + value is None and self.none_as_null): + return None + if value is None or isinstance(value, (dict, list)): + return json(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_json: + return super(_PGJSON, self).result_processor(dialect, coltype) + + +class _PGJSONB(JSONB): + + def bind_processor(self, dialect): + if not dialect.has_native_json: + return super(_PGJSONB, self).bind_processor(dialect) + json = dialect.dbapi.Json + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, Null) or ( + value is None and self.none_as_null): + return None + if value is None or isinstance(value, (dict, list)): + return json(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_json: + return super(_PGJSONB, self).result_processor(dialect, coltype) + + +class _PGUUID(UUID): + + def bind_processor(self, dialect): + if not dialect.has_native_uuid: + return super(_PGUUID, self).bind_processor(dialect) + uuid = dialect.dbapi.Uuid + + def process(value): + if value is None: + return None + if isinstance(value, (str, bytes)): + if len(value) == 16: + return uuid(bytes=value) + return uuid(value) + if isinstance(value, int): + return uuid(int=value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_uuid: + return super(_PGUUID, self).result_processor(dialect, coltype) + if not self.as_uuid: + def process(value): + if value is not None: + return str(value) + return process + + +class _PGCompiler(PGCompiler): + + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class _PGIdentifierPreparer(PGIdentifierPreparer): + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace('%', '%%') + + +class PGDialect_pygresql(PGDialect): + + driver = 'pygresql' + + statement_compiler = _PGCompiler + preparer = _PGIdentifierPreparer + + @classmethod + def dbapi(cls): + import pgdb + return pgdb + + colspecs = util.update_copy( + PGDialect.colspecs, + { + Numeric: _PGNumeric, + HSTORE: _PGHStore, + Json: _PGJSON, + JSON: _PGJSON, + JSONB: _PGJSONB, + UUID: _PGUUID, + } + ) + + def __init__(self, **kwargs): + super(PGDialect_pygresql, self).__init__(**kwargs) + try: + version = self.dbapi.version + m = re.match(r'(\d+)\.(\d+)', version) + version = (int(m.group(1)), int(m.group(2))) + except (AttributeError, ValueError, TypeError): + version = (0, 0) + self.dbapi_version = version + if version < (5, 0): + has_native_hstore = has_native_json = has_native_uuid = False + if version != (0, 0): + util.warn("PyGreSQL is only fully supported by SQLAlchemy" + " since version 5.0.") + else: + self.supports_unicode_statements = True + self.supports_unicode_binds = True + has_native_hstore = has_native_json = has_native_uuid = True + self.has_native_hstore = has_native_hstore + self.has_native_json = has_native_json + self.has_native_uuid = has_native_uuid + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['host'] = '%s:%s' % ( + opts.get('host', '').rsplit(':', 1)[0], opts.pop('port')) + opts.update(url.query) + return [], opts + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + if not connection: + return False + try: + connection = connection.connection + except AttributeError: + pass + else: + if not connection: + return False + try: + return connection.closed + except AttributeError: # PyGreSQL < 5.0 + return connection._cnx is None + return False + + +dialect = PGDialect_pygresql |