diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/pg8000.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pg8000.py | 109 |
1 files changed, 71 insertions, 38 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 80929b808..fef09e0eb 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -69,8 +69,15 @@ import decimal from ... import processors from ... import types as sqltypes from .base import ( - PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, - _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID) + PGDialect, + PGCompiler, + PGIdentifierPreparer, + PGExecutionContext, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) import re from sqlalchemy.dialects.postgresql.json import JSON from ...sql.elements import quoted_name @@ -86,13 +93,15 @@ class _PGNumeric(sqltypes.Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 @@ -101,7 +110,8 @@ class _PGNumeric(sqltypes.Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGNumericNoBind(_PGNumeric): @@ -110,7 +120,6 @@ class _PGNumericNoBind(_PGNumeric): class _PGJSON(JSON): - def result_processor(self, dialect, coltype): if dialect._dbapi_version > (1, 10, 1): return None # Has native JSON @@ -121,18 +130,22 @@ class _PGJSON(JSON): class _PGUUID(UUID): def bind_processor(self, dialect): if not self.as_uuid: + def process(value): if value is not None: value = _python_UUID(value) return value + return process def result_processor(self, dialect, coltype): if not self.as_uuid: + def process(value): if value is not None: value = str(value) return value + return process @@ -142,36 +155,41 @@ class PGExecutionContext_pg8000(PGExecutionContext): class PGCompiler_pg8000(PGCompiler): def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy postgresql dialect " - "now automatically escapes '%' in text() " - "expressions to '%%'.") - return text.replace('%', '%%') + if "%%" in text: + util.warn( + "The SQLAlchemy postgresql dialect " + "now automatically escapes '%' in text() " + "expressions to '%%'." + ) + return text.replace("%", "%%") class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + return value.replace("%", "%%") class PGDialect_pg8000(PGDialect): - driver = 'pg8000' + driver = "pg8000" supports_unicode_statements = True supports_unicode_binds = True - default_paramstyle = 'format' + default_paramstyle = "format" supports_sane_multi_rowcount = True execution_ctx_cls = PGExecutionContext_pg8000 statement_compiler = PGCompiler_pg8000 preparer = PGIdentifierPreparer_pg8000 - description_encoding = 'use_encoding' + description_encoding = "use_encoding" colspecs = util.update_copy( PGDialect.colspecs, @@ -180,8 +198,8 @@ class PGDialect_pg8000(PGDialect): sqltypes.Float: _PGNumeric, JSON: _PGJSON, sqltypes.JSON: _PGJSON, - UUID: _PGUUID - } + UUID: _PGUUID, + }, ) def __init__(self, client_encoding=None, **kwargs): @@ -194,22 +212,26 @@ class PGDialect_pg8000(PGDialect): @util.memoized_property def _dbapi_version(self): - if self.dbapi and hasattr(self.dbapi, '__version__'): + if self.dbapi and hasattr(self.dbapi, "__version__"): return tuple( [ - int(x) for x in re.findall( - r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)]) + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) else: return (99, 99, 99) @classmethod def dbapi(cls): - return __import__('pg8000') + return __import__("pg8000") def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) opts.update(url.query) return ([], opts) @@ -217,32 +239,33 @@ class PGDialect_pg8000(PGDialect): return "connection is closed" in str(e) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy possibly being present - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit = True elif level in self._isolation_lookup: connection.autocommit = False cursor = connection.cursor() cursor.execute( "SET SESSION CHARACTERISTICS AS TRANSACTION " - "ISOLATION LEVEL %s" % level) + "ISOLATION LEVEL %s" % level + ) cursor.execute("COMMIT") cursor.close() else: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s or AUTOCOMMIT" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s or AUTOCOMMIT" + % (level, self.name, ", ".join(self._isolation_lookup)) ) def set_client_encoding(self, connection, client_encoding): # adjust for ConnectionFairy possibly being present - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection cursor = connection.cursor() @@ -251,18 +274,20 @@ class PGDialect_pg8000(PGDialect): cursor.close() def do_begin_twophase(self, connection, xid): - connection.connection.tpc_begin((0, xid, '')) + connection.connection.tpc_begin((0, xid, "")) def do_prepare_twophase(self, connection, xid): connection.connection.tpc_prepare() def do_rollback_twophase( - self, connection, xid, is_prepared=True, recover=False): - connection.connection.tpc_rollback((0, xid, '')) + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_rollback((0, xid, "")) def do_commit_twophase( - self, connection, xid, is_prepared=True, recover=False): - connection.connection.tpc_commit((0, xid, '')) + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_commit((0, xid, "")) def do_recover_twophase(self, connection): return [row[1] for row in connection.connection.tpc_recover()] @@ -272,24 +297,32 @@ class PGDialect_pg8000(PGDialect): def on_connect(conn): conn.py_types[quoted_name] = conn.py_types[util.text_type] + fns.append(on_connect) if self.client_encoding is not None: + def on_connect(conn): self.set_client_encoding(conn, self.client_encoding) + fns.append(on_connect) if self.isolation_level is not None: + def on_connect(conn): self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) if len(fns) > 0: + def on_connect(conn): for fn in fns: fn(conn) + return on_connect else: return None + dialect = PGDialect_pg8000 |