diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/pygresql.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pygresql.py | 72 |
1 files changed, 44 insertions, 28 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py index 304afca44..c7edb8fc3 100644 --- a/lib/sqlalchemy/dialects/postgresql/pygresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -20,14 +20,20 @@ import re from ... import exc, processors, util from ...types import Numeric, JSON as Json from ...sql.elements import Null -from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \ - _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID +from .base import ( + PGDialect, + PGCompiler, + PGIdentifierPreparer, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) from .hstore import HSTORE from .json import JSON, JSONB class _PGNumeric(Numeric): - def bind_processor(self, dialect): return None @@ -37,14 +43,15 @@ class _PGNumeric(Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # PyGreSQL returns Decimal natively for 1700 (numeric) return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # PyGreSQL returns float natively for 701 (float8) @@ -53,19 +60,21 @@ class _PGNumeric(Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGHStore(HSTORE): - def bind_processor(self, dialect): if not dialect.has_native_hstore: return super(_PGHStore, self).bind_processor(dialect) hstore = dialect.dbapi.Hstore + def process(value): if isinstance(value, dict): return hstore(value) return value + return process def result_processor(self, dialect, coltype): @@ -74,7 +83,6 @@ class _PGHStore(HSTORE): class _PGJSON(JSON): - def bind_processor(self, dialect): if not dialect.has_native_json: return super(_PGJSON, self).bind_processor(dialect) @@ -84,7 +92,8 @@ class _PGJSON(JSON): if value is self.NULL: value = None elif isinstance(value, Null) or ( - value is None and self.none_as_null): + value is None and self.none_as_null + ): return None if value is None or isinstance(value, (dict, list)): return json(value) @@ -98,7 +107,6 @@ class _PGJSON(JSON): class _PGJSONB(JSONB): - def bind_processor(self, dialect): if not dialect.has_native_json: return super(_PGJSONB, self).bind_processor(dialect) @@ -108,7 +116,8 @@ class _PGJSONB(JSONB): if value is self.NULL: value = None elif isinstance(value, Null) or ( - value is None and self.none_as_null): + value is None and self.none_as_null + ): return None if value is None or isinstance(value, (dict, list)): return json(value) @@ -122,7 +131,6 @@ class _PGJSONB(JSONB): class _PGUUID(UUID): - def bind_processor(self, dialect): if not dialect.has_native_uuid: return super(_PGUUID, self).bind_processor(dialect) @@ -145,32 +153,35 @@ class _PGUUID(UUID): if not dialect.has_native_uuid: return super(_PGUUID, self).result_processor(dialect, coltype) if not self.as_uuid: + def process(value): if value is not None: return str(value) + return process class _PGCompiler(PGCompiler): - def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): - return text.replace('%', '%%') + return text.replace("%", "%%") class _PGIdentifierPreparer(PGIdentifierPreparer): - def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + return value.replace("%", "%%") class PGDialect_pygresql(PGDialect): - driver = 'pygresql' + driver = "pygresql" statement_compiler = _PGCompiler preparer = _PGIdentifierPreparer @@ -178,6 +189,7 @@ class PGDialect_pygresql(PGDialect): @classmethod def dbapi(cls): import pgdb + return pgdb colspecs = util.update_copy( @@ -189,14 +201,14 @@ class PGDialect_pygresql(PGDialect): JSON: _PGJSON, JSONB: _PGJSONB, UUID: _PGUUID, - } + }, ) def __init__(self, **kwargs): super(PGDialect_pygresql, self).__init__(**kwargs) try: version = self.dbapi.version - m = re.match(r'(\d+)\.(\d+)', version) + m = re.match(r"(\d+)\.(\d+)", version) version = (int(m.group(1)), int(m.group(2))) except (AttributeError, ValueError, TypeError): version = (0, 0) @@ -204,8 +216,10 @@ class PGDialect_pygresql(PGDialect): if version < (5, 0): has_native_hstore = has_native_json = has_native_uuid = False if version != (0, 0): - util.warn("PyGreSQL is only fully supported by SQLAlchemy" - " since version 5.0.") + util.warn( + "PyGreSQL is only fully supported by SQLAlchemy" + " since version 5.0." + ) else: self.supports_unicode_statements = True self.supports_unicode_binds = True @@ -215,10 +229,12 @@ class PGDialect_pygresql(PGDialect): self.has_native_uuid = has_native_uuid def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['host'] = '%s:%s' % ( - opts.get('host', '').rsplit(':', 1)[0], opts.pop('port')) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["host"] = "%s:%s" % ( + opts.get("host", "").rsplit(":", 1)[0], + opts.pop("port"), + ) opts.update(url.query) return [], opts |