diff options
author | Brian Jarrett <celttechie@gmail.com> | 2014-07-11 10:22:07 -0600 |
---|---|---|
committer | Brian Jarrett <celttechie@gmail.com> | 2014-07-12 11:45:42 -0600 |
commit | 551680d06e7a0913690414c78d6dfdb590f1588f (patch) | |
tree | 159837848a968935c3984185d948341b46076dd6 | |
parent | 600e6bfb3fa261b98b1ae7237080c8f5e757d09a (diff) | |
download | sqlalchemy-551680d06e7a0913690414c78d6dfdb590f1588f.tar.gz |
Style fixes for dialects package
40 files changed, 1036 insertions, 947 deletions
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 31afe1568..74c48820d 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -13,10 +13,11 @@ __all__ = ( 'postgresql', 'sqlite', 'sybase', - ) +) from .. import util + def _auto_fn(name): """default dialect importer. diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index c8f081b2d..c774358b2 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -119,7 +119,7 @@ RESERVED_WORDS = set([ "union", "unique", "update", "upper", "user", "using", "value", "values", "varchar", "variable", "varying", "view", "wait", "when", "where", "while", "with", "work", "write", "year", - ]) +]) class _StringType(sqltypes.String): @@ -160,20 +160,20 @@ colspecs = { } ischema_names = { - 'SHORT': SMALLINT, - 'LONG': INTEGER, - 'QUAD': FLOAT, - 'FLOAT': FLOAT, - 'DATE': DATE, - 'TIME': TIME, - 'TEXT': TEXT, - 'INT64': BIGINT, - 'DOUBLE': FLOAT, - 'TIMESTAMP': TIMESTAMP, + 'SHORT': SMALLINT, + 'LONG': INTEGER, + 'QUAD': FLOAT, + 'FLOAT': FLOAT, + 'DATE': DATE, + 'TIME': TIME, + 'TEXT': TEXT, + 'INT64': BIGINT, + 'DOUBLE': FLOAT, + 'TIMESTAMP': TIMESTAMP, 'VARYING': VARCHAR, 'CSTRING': CHAR, - 'BLOB': BLOB, - } + 'BLOB': BLOB, +} # TODO: date conversion types (should be implemented as _FBDateTime, @@ -193,7 +193,7 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): return "BLOB SUB_TYPE 0" def _extend_string(self, type_, basic): - charset = getattr(type_, 'charset', None) + charset = getattr(type_, 'charset', None) if charset is None: return basic else: @@ -206,8 +206,8 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_): if not type_.length: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % + self.dialect.name) basic = super(FBTypeCompiler, self).visit_VARCHAR(type_) return self._extend_string(type_, basic) @@ -217,46 +217,46 @@ class FBCompiler(sql.compiler.SQLCompiler): ansi_bind_rules = True - #def visit_contains_op_binary(self, binary, operator, **kw): - # cant use CONTAINING b.c. it's case insensitive. + # def visit_contains_op_binary(self, binary, operator, **kw): + # cant use CONTAINING b.c. it's case insensitive. - #def visit_notcontains_op_binary(self, binary, operator, **kw): - # cant use NOT CONTAINING b.c. it's case insensitive. + # def visit_notcontains_op_binary(self, binary, operator, **kw): + # cant use NOT CONTAINING b.c. it's case insensitive. def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" def visit_startswith_op_binary(self, binary, operator, **kw): return '%s STARTING WITH %s' % ( - binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) def visit_notstartswith_op_binary(self, binary, operator, **kw): return '%s NOT STARTING WITH %s' % ( - binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.left, **kw), + self.process(binary.right, **kw)) def visit_alias(self, alias, asfrom=False, **kwargs): if self.dialect._version_two: return super(FBCompiler, self).\ - visit_alias(alias, asfrom=asfrom, **kwargs) + visit_alias(alias, asfrom=asfrom, **kwargs) else: # Override to not use the AS keyword which FB 1.5 does not like if asfrom: alias_name = isinstance(alias.name, - expression._truncated_label) and \ - self._truncated_identifier("alias", - alias.name) or alias.name + expression._truncated_label) and \ + self._truncated_identifier("alias", + alias.name) or alias.name return self.process( - alias.original, asfrom=asfrom, **kwargs) + \ - " " + \ - self.preparer.format_alias(alias, alias_name) + alias.original, asfrom=asfrom, **kwargs) + \ + " " + \ + self.preparer.format_alias(alias, alias_name) else: return self.process(alias.original, **kwargs) @@ -315,9 +315,9 @@ class FBCompiler(sql.compiler.SQLCompiler): def returning_clause(self, stmt, returning_cols): columns = [ - self._label_select_column(None, c, True, False, {}) - for c in expression._select_iterables(returning_cols) - ] + self._label_select_column(None, c, True, False, {}) + for c in expression._select_iterables(returning_cols) + ] return 'RETURNING ' + ', '.join(columns) @@ -332,34 +332,35 @@ class FBDDLCompiler(sql.compiler.DDLCompiler): # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html if create.element.start is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support START WITH") + "Firebird SEQUENCE doesn't support START WITH") if create.element.increment is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support INCREMENT BY") + "Firebird SEQUENCE doesn't support INCREMENT BY") if self.dialect._version_two: return "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + self.preparer.format_sequence(create.element) else: return "CREATE GENERATOR %s" % \ - self.preparer.format_sequence(create.element) + self.preparer.format_sequence(create.element) def visit_drop_sequence(self, drop): """Generate a ``DROP GENERATOR`` statement for the sequence.""" if self.dialect._version_two: return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + self.preparer.format_sequence(drop.element) else: return "DROP GENERATOR %s" % \ - self.preparer.format_sequence(drop.element) + self.preparer.format_sequence(drop.element) class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): """Install Firebird specific reserved words.""" reserved_words = RESERVED_WORDS - illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(['_']) + illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union( + ['_']) def __init__(self, dialect): super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) @@ -370,10 +371,10 @@ class FBExecutionContext(default.DefaultExecutionContext): """Get the next value from the sequence using ``gen_id()``.""" return self._execute_scalar( - "SELECT gen_id(%s, 1) FROM rdb$database" % - self.dialect.identifier_preparer.format_sequence(seq), - type_ - ) + "SELECT gen_id(%s, 1) FROM rdb$database" % + self.dialect.identifier_preparer.format_sequence(seq), + type_ + ) class FBDialect(default.DefaultDialect): @@ -411,12 +412,12 @@ class FBDialect(default.DefaultDialect): def initialize(self, connection): super(FBDialect, self).initialize(connection) - self._version_two = ('firebird' in self.server_version_info and \ - self.server_version_info >= (2, ) - ) or \ - ('interbase' in self.server_version_info and \ + self._version_two = ('firebird' in self.server_version_info and + self.server_version_info >= (2, ) + ) or \ + ('interbase' in self.server_version_info and self.server_version_info >= (6, ) - ) + ) if not self._version_two: # TODO: whatever other pre < 2.0 stuff goes here @@ -427,7 +428,7 @@ class FBDialect(default.DefaultDialect): } self.implicit_returning = self._version_two and \ - self.__dict__.get('implicit_returning', True) + self.__dict__.get('implicit_returning', True) def normalize_name(self, name): # Remove trailing spaces: FB uses a CHAR() type, @@ -436,7 +437,7 @@ class FBDialect(default.DefaultDialect): if name is None: return None elif name.upper() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + not self.identifier_preparer._requires_quotes(name.lower()): return name.lower() else: return name @@ -445,7 +446,7 @@ class FBDialect(default.DefaultDialect): if name is None: return None elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + not self.identifier_preparer._requires_quotes(name.lower()): return name.upper() else: return name @@ -539,8 +540,8 @@ class FBDialect(default.DefaultDialect): @reflection.cache def get_column_sequence(self, connection, - table_name, column_name, - schema=None, **kw): + table_name, column_name, + schema=None, **kw): tablename = self.denormalize_name(table_name) colname = self.denormalize_name(column_name) # Heuristic-query to determine the generator associated to a PK field @@ -613,8 +614,8 @@ class FBDialect(default.DefaultDialect): coltype = sqltypes.NULLTYPE elif issubclass(coltype, Integer) and row['fprec'] != 0: coltype = NUMERIC( - precision=row['fprec'], - scale=row['fscale'] * -1) + precision=row['fprec'], + scale=row['fscale'] * -1) elif colspec in ('VARYING', 'CSTRING'): coltype = coltype(row['flen']) elif colspec == 'TEXT': @@ -636,8 +637,8 @@ class FBDialect(default.DefaultDialect): # (see also http://tracker.firebirdsql.org/browse/CORE-356) defexpr = row['fdefault'].lstrip() assert defexpr[:8].rstrip().upper() == \ - 'DEFAULT', "Unrecognized default value: %s" % \ - defexpr + 'DEFAULT', "Unrecognized default value: %s" % \ + defexpr defvalue = defexpr[8:].strip() if defvalue == 'NULL': # Redundant @@ -700,9 +701,9 @@ class FBDialect(default.DefaultDialect): fk['name'] = cname fk['referred_table'] = self.normalize_name(row['targetrname']) fk['constrained_columns'].append( - self.normalize_name(row['fname'])) + self.normalize_name(row['fname'])) fk['referred_columns'].append( - self.normalize_name(row['targetfname'])) + self.normalize_name(row['targetfname'])) return list(fks.values()) @reflection.cache @@ -732,7 +733,6 @@ class FBDialect(default.DefaultDialect): indexrec['unique'] = bool(row['unique_flag']) indexrec['column_names'].append( - self.normalize_name(row['field_name'])) + self.normalize_name(row['field_name'])) return list(indexes.values()) - diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py index a691adb53..aca5f7e41 100644 --- a/lib/sqlalchemy/dialects/firebird/fdb.py +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -73,14 +73,14 @@ from ... import util class FBDialect_fdb(FBDialect_kinterbasdb): def __init__(self, enable_rowcount=True, - retaining=False, **kwargs): + retaining=False, **kwargs): super(FBDialect_fdb, self).__init__( - enable_rowcount=enable_rowcount, - retaining=retaining, **kwargs) + enable_rowcount=enable_rowcount, + retaining=retaining, **kwargs) @classmethod def dbapi(cls): - return __import__('fdb') + return __import__('fdb') def create_connect_args(self, url): opts = url.translate_connect_args(username='user') diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py index cdd1f7e7b..256b902c6 100644 --- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -52,9 +52,11 @@ class _kinterbasdb_numeric(object): return value return process + class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric): pass + class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float): pass @@ -63,7 +65,7 @@ class FBExecutionContext_kinterbasdb(FBExecutionContext): @property def rowcount(self): if self.execution_options.get('enable_rowcount', - self.dialect.enable_rowcount): + self.dialect.enable_rowcount): return self.cursor.rowcount else: return -1 @@ -87,8 +89,8 @@ class FBDialect_kinterbasdb(FBDialect): ) def __init__(self, type_conv=200, concurrency_level=1, - enable_rowcount=True, - retaining=False, **kwargs): + enable_rowcount=True, + retaining=False, **kwargs): super(FBDialect_kinterbasdb, self).__init__(**kwargs) self.enable_rowcount = enable_rowcount self.type_conv = type_conv @@ -123,7 +125,7 @@ class FBDialect_kinterbasdb(FBDialect): type_conv = opts.pop('type_conv', self.type_conv) concurrency_level = opts.pop('concurrency_level', - self.concurrency_level) + self.concurrency_level) if self.dbapi is not None: initialized = getattr(self.dbapi, 'initialized', None) @@ -134,7 +136,7 @@ class FBDialect_kinterbasdb(FBDialect): initialized = getattr(self.dbapi, '_initialized', False) if not initialized: self.dbapi.init(type_conv=type_conv, - concurrency_level=concurrency_level) + concurrency_level=concurrency_level) return ([], opts) def _get_server_version_info(self, connection): @@ -156,10 +158,11 @@ class FBDialect_kinterbasdb(FBDialect): return self._parse_version_info(version) def _parse_version_info(self, version): - m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version) + m = match( + '\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version) if not m: raise AssertionError( - "Could not determine version from string '%s'" % version) + "Could not determine version from string '%s'" % version) if m.group(5) != None: return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird']) @@ -168,7 +171,7 @@ class FBDialect_kinterbasdb(FBDialect): def is_disconnect(self, e, connection, cursor): if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): + self.dbapi.ProgrammingError)): msg = str(e) return ('Unable to complete network request to host' in msg or 'Invalid connection state' in msg or diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index 4c059ae2f..d0047765e 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -6,7 +6,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \ - pymssql, zxjdbc, mxodbc + pymssql, zxjdbc, mxodbc base.dialect = pyodbc.dialect diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py index d94a4517d..e9927f8ed 100644 --- a/lib/sqlalchemy/dialects/mssql/adodbapi.py +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -61,7 +61,7 @@ class MSDialect_adodbapi(MSDialect): connectors = ["Provider=SQLOLEDB"] if 'port' in keys: connectors.append("Data Source=%s, %s" % - (keys.get("host"), keys.get("port"))) + (keys.get("host"), keys.get("port"))) else: connectors.append("Data Source=%s" % keys.get("host")) connectors.append("Initial Catalog=%s" % keys.get("database")) @@ -75,6 +75,6 @@ class MSDialect_adodbapi(MSDialect): def is_disconnect(self, e, connection, cursor): return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \ - "'connection failure'" in str(e) + "'connection failure'" in str(e) dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 547df8259..9ba427458 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -364,13 +364,13 @@ import re from ... import sql, schema as sa_schema, exc, util from ...sql import compiler, expression, \ - util as sql_util, cast + util as sql_util, cast from ... import engine from ...engine import reflection, default from ... import types as sqltypes from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ - FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ - VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR + FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ + VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR from ...util import update_wrapper @@ -409,7 +409,7 @@ RESERVED_WORDS = set( 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', 'writetext', - ]) + ]) class REAL(sqltypes.REAL): @@ -447,9 +447,9 @@ class _MSDate(sqltypes.Date): return value.date() elif isinstance(value, util.string_types): return datetime.date(*[ - int(x or 0) - for x in self._reg.match(value).groups() - ]) + int(x or 0) + for x in self._reg.match(value).groups() + ]) else: return value return process @@ -466,7 +466,7 @@ class TIME(sqltypes.TIME): def process(value): if isinstance(value, datetime.datetime): value = datetime.datetime.combine( - self.__zero_date, value.time()) + self.__zero_date, value.time()) elif isinstance(value, datetime.time): value = datetime.datetime.combine(self.__zero_date, value) return value @@ -480,8 +480,8 @@ class TIME(sqltypes.TIME): return value.time() elif isinstance(value, util.string_types): return datetime.time(*[ - int(x or 0) - for x in self._reg.match(value).groups()]) + int(x or 0) + for x in self._reg.match(value).groups()]) else: return value return process @@ -529,8 +529,6 @@ class _StringType(object): super(_StringType, self).__init__(collation=collation) - - class NTEXT(sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" @@ -538,7 +536,6 @@ class NTEXT(sqltypes.UnicodeText): __visit_name__ = 'NTEXT' - class IMAGE(sqltypes.LargeBinary): __visit_name__ = 'IMAGE' @@ -638,7 +635,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): spec = spec + "(%s)" % length return ' '.join([c for c in (spec, collation) - if c is not None]) + if c is not None]) def visit_FLOAT(self, type_): precision = getattr(type_, 'precision', None) @@ -717,9 +714,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): def visit_VARBINARY(self, type_): return self._extend( - "VARBINARY", - type_, - length=type_.length or 'max') + "VARBINARY", + type_, + length=type_.length or 'max') def visit_boolean(self, type_): return self.visit_BIT(type_) @@ -756,20 +753,21 @@ class MSExecutionContext(default.DefaultExecutionContext): if insert_has_sequence: self._enable_identity_insert = \ - seq_column.key in self.compiled_parameters[0] + seq_column.key in self.compiled_parameters[0] else: self._enable_identity_insert = False self._select_lastrowid = insert_has_sequence and \ - not self.compiled.returning and \ - not self._enable_identity_insert and \ - not self.executemany + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany if self._enable_identity_insert: self.root_connection._cursor_execute(self.cursor, - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl), - (), self) + "SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table( + tbl), + (), self) def post_exec(self): """Disable IDENTITY_INSERT if enabled.""" @@ -778,10 +776,10 @@ class MSExecutionContext(default.DefaultExecutionContext): if self._select_lastrowid: if self.dialect.use_scope_identity: conn._cursor_execute(self.cursor, - "SELECT scope_identity() AS lastrowid", (), self) + "SELECT scope_identity() AS lastrowid", (), self) else: conn._cursor_execute(self.cursor, - "SELECT @@identity AS lastrowid", (), self) + "SELECT @@identity AS lastrowid", (), self) # fetchall() ensures the cursor is consumed without closing it row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) @@ -792,10 +790,10 @@ class MSExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: conn._cursor_execute(self.cursor, - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. - format_table(self.compiled.statement.table), - (), self) + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table), + (), self) def get_lastrowid(self): return self._lastrowid @@ -804,10 +802,10 @@ class MSExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: try: self.cursor.execute( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer.\ - format_table(self.compiled.statement.table) - ) + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table) + ) except: pass @@ -824,11 +822,11 @@ class MSSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) def __init__(self, *args, **kwargs): self.tablealiases = {} @@ -848,8 +846,8 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_concat_op_binary(self, binary, operator, **kw): return "%s + %s" % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) def visit_true(self, expr, **kw): return '1' @@ -859,8 +857,8 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.left, **kw), + self.process(binary.right, **kw)) def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ @@ -896,20 +894,20 @@ class MSSQLCompiler(compiler.SQLCompiler): """ if ( - ( - not select._simple_int_limit and - select._limit_clause is not None - ) or ( - select._offset_clause is not None and - not select._simple_int_offset or select._offset - ) - ) and not getattr(select, '_mssql_visit', None): + ( + not select._simple_int_limit and + select._limit_clause is not None + ) or ( + select._offset_clause is not None and + not select._simple_int_offset or select._offset + ) + ) and not getattr(select, '_mssql_visit', None): # to use ROW_NUMBER(), an ORDER BY is required. if not select._order_by_clause.clauses: raise exc.CompileError('MSSQL requires an order_by when ' - 'using an OFFSET or a non-simple ' - 'LIMIT clause') + 'using an OFFSET or a non-simple ' + 'LIMIT clause') _order_by_clauses = select._order_by_clause.clauses limit_clause = select._limit_clause @@ -917,20 +915,20 @@ class MSSQLCompiler(compiler.SQLCompiler): select = select._generate() select._mssql_visit = True select = select.column( - sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) - .label("mssql_rn")).order_by(None).alias() + sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) + .label("mssql_rn")).order_by(None).alias() mssql_rn = sql.column('mssql_rn') limitselect = sql.select([c for c in select.c if - c.key != 'mssql_rn']) + c.key != 'mssql_rn']) if offset_clause is not None: limitselect.append_whereclause(mssql_rn > offset_clause) if limit_clause is not None: limitselect.append_whereclause( - mssql_rn <= (limit_clause + offset_clause)) + mssql_rn <= (limit_clause + offset_clause)) else: limitselect.append_whereclause( - mssql_rn <= (limit_clause)) + mssql_rn <= (limit_clause)) return self.process(limitselect, iswrapper=True, **kwargs) else: return compiler.SQLCompiler.visit_select(self, select, **kwargs) @@ -962,7 +960,7 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) return 'DATEPART("%s", %s)' % \ - (field, self.process(extract.expr, **kw)) + (field, self.process(extract.expr, **kw)) def visit_savepoint(self, savepoint_stmt): return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) @@ -973,25 +971,25 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_column(self, column, add_to_result_map=None, **kwargs): if column.table is not None and \ - (not self.isupdate and not self.isdelete) or self.is_subquery(): + (not self.isupdate and not self.isdelete) or self.is_subquery(): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: converted = expression._corresponding_column_or_error( - t, column) + t, column) if add_to_result_map is not None: add_to_result_map( - column.name, - column.name, - (column, column.name, column.key), - column.type + column.name, + column.name, + (column, column.name, column.key), + column.type ) return super(MSSQLCompiler, self).\ - visit_column(converted, **kwargs) + visit_column(converted, **kwargs) return super(MSSQLCompiler, self).visit_column( - column, add_to_result_map=add_to_result_map, **kwargs) + column, add_to_result_map=add_to_result_map, **kwargs) def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where @@ -1002,12 +1000,12 @@ class MSSQLCompiler(compiler.SQLCompiler): isinstance(binary.left, expression.BindParameter) and binary.operator == operator.eq and not isinstance(binary.right, expression.BindParameter) - ): + ): return self.process( - expression.BinaryExpression(binary.right, - binary.left, - binary.operator), - **kwargs) + expression.BinaryExpression(binary.right, + binary.left, + binary.operator), + **kwargs) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def returning_clause(self, stmt, returning_cols): @@ -1020,10 +1018,10 @@ class MSSQLCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(target) columns = [ - self._label_select_column(None, adapter.traverse(c), - True, False, {}) - for c in expression._select_iterables(returning_cols) - ] + self._label_select_column(None, adapter.traverse(c), + True, False, {}) + for c in expression._select_iterables(returning_cols) + ] return 'OUTPUT ' + ', '.join(columns) @@ -1039,7 +1037,7 @@ class MSSQLCompiler(compiler.SQLCompiler): return column.label(None) else: return super(MSSQLCompiler, self).\ - label_select_column(select, column, asfrom) + label_select_column(select, column, asfrom) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which @@ -1056,9 +1054,9 @@ class MSSQLCompiler(compiler.SQLCompiler): return "" def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + from_table, extra_froms, + from_hints, + **kw): """Render the UPDATE..FROM clause specific to MSSQL. In MSSQL, if the UPDATE statement involves an alias of the table to @@ -1067,9 +1065,9 @@ class MSSQLCompiler(compiler.SQLCompiler): """ return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in [from_table] + extra_froms) class MSSQLStrictCompiler(MSSQLCompiler): @@ -1085,16 +1083,16 @@ class MSSQLStrictCompiler(MSSQLCompiler): def visit_in_op_binary(self, binary, operator, **kw): kw['literal_binds'] = True return "%s IN %s" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw) - ) + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) def visit_notin_op_binary(self, binary, operator, **kw): kw['literal_binds'] = True return "%s NOT IN %s" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw) - ) + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) def render_literal_value(self, value, type_): """ @@ -1113,7 +1111,7 @@ class MSSQLStrictCompiler(MSSQLCompiler): return "'" + str(value) + "'" else: return super(MSSQLStrictCompiler, self).\ - render_literal_value(value, type_) + render_literal_value(value, type_) class MSDDLCompiler(compiler.DDLCompiler): @@ -1130,17 +1128,19 @@ class MSDDLCompiler(compiler.DDLCompiler): if column.table is None: raise exc.CompileError( - "mssql requires Table-bound columns " - "in order to generate DDL") + "mssql requires Table-bound columns " + "in order to generate DDL") - # install an IDENTITY Sequence if we either a sequence or an implicit IDENTITY column + # install an IDENTITY Sequence if we either a sequence or an implicit + # IDENTITY column if isinstance(column.default, sa_schema.Sequence): if column.default.start == 0: start = 0 else: start = column.default.start or 1 - colspec += " IDENTITY(%s,%s)" % (start, column.default.increment or 1) + colspec += " IDENTITY(%s,%s)" % (start, + column.default.increment or 1) elif column is column.table._autoincrement_column: colspec += " IDENTITY(1,1)" else: @@ -1163,20 +1163,20 @@ class MSDDLCompiler(compiler.DDLCompiler): text += "CLUSTERED " text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table), - ', '.join( - self.sql_compiler.process(expr, - include_table=False, literal_binds=True) for - expr in index.expressions) - ) + % ( + self._prepared_index_name(index, + include_schema=include_schema), + preparer.format_table(index.table), + ', '.join( + self.sql_compiler.process(expr, + include_table=False, literal_binds=True) for + expr in index.expressions) + ) # handle other included columns if index.dialect_options['mssql']['include']: inclusions = [index.table.c[col] - if isinstance(col, util.string_types) else col + if isinstance(col, util.string_types) else col for col in index.dialect_options['mssql']['include']] text += " INCLUDE (%s)" \ @@ -1189,7 +1189,7 @@ class MSDDLCompiler(compiler.DDLCompiler): return "\nDROP INDEX %s ON %s" % ( self._prepared_index_name(drop.element, include_schema=False), self.preparer.format_table(drop.element.table) - ) + ) def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: @@ -1225,6 +1225,7 @@ class MSDDLCompiler(compiler.DDLCompiler): text += self.define_constraint_deferrability(constraint) return text + class MSIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS @@ -1245,7 +1246,7 @@ def _db_plus_owner_listing(fn): def wrap(dialect, connection, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) return _switch_db(dbname, connection, fn, dialect, connection, - dbname, owner, schema, **kw) + dbname, owner, schema, **kw) return update_wrapper(wrap, fn) @@ -1253,7 +1254,7 @@ def _db_plus_owner(fn): def wrap(dialect, connection, tablename, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) return _switch_db(dbname, connection, fn, dialect, connection, - tablename, dbname, owner, schema, **kw) + tablename, dbname, owner, schema, **kw) return update_wrapper(wrap, fn) @@ -1328,7 +1329,7 @@ class MSDialect(default.DefaultDialect): self.use_scope_identity = use_scope_identity self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length + self.max_identifier_length super(MSDialect, self).__init__(**opts) def do_savepoint(self, connection, name): @@ -1353,7 +1354,7 @@ class MSDialect(default.DefaultDialect): "is configured in the FreeTDS configuration." % ".".join(str(x) for x in self.server_version_info)) if self.server_version_info >= MS_2005_VERSION and \ - 'implicit_returning' not in self.__dict__: + 'implicit_returning' not in self.__dict__: self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True @@ -1386,8 +1387,8 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_schema_names(self, connection, **kw): s = sql.select([ischema.schemata.c.schema_name], - order_by=[ischema.schemata.c.schema_name] - ) + order_by=[ischema.schemata.c.schema_name] + ) schema_names = [r[0] for r in connection.execute(s)] return schema_names @@ -1396,10 +1397,10 @@ class MSDialect(default.DefaultDialect): def get_table_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'BASE TABLE' - ), + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == 'BASE TABLE' + ), order_by=[tables.c.table_name] ) table_names = [r[0] for r in connection.execute(s)] @@ -1410,10 +1411,10 @@ class MSDialect(default.DefaultDialect): def get_view_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'VIEW' - ), + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == 'VIEW' + ), order_by=[tables.c.table_name] ) view_names = [r[0] for r in connection.execute(s)] @@ -1429,22 +1430,22 @@ class MSDialect(default.DefaultDialect): rp = connection.execute( sql.text("select ind.index_id, ind.is_unique, ind.name " - "from sys.indexes as ind join sys.tables as tab on " - "ind.object_id=tab.object_id " - "join sys.schemas as sch on sch.schema_id=tab.schema_id " - "where tab.name = :tabname " - "and sch.name=:schname " - "and ind.is_primary_key=0", - bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ], - typemap={ - 'name': sqltypes.Unicode() - } - ) + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0", + bindparams=[ + sql.bindparam('tabname', tablename, + sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', owner, + sqltypes.String(convert_unicode=True)) + ], + typemap={ + 'name': sqltypes.Unicode() + } + ) ) indexes = {} for row in rp: @@ -1464,15 +1465,15 @@ class MSDialect(default.DefaultDialect): "join sys.schemas as sch on sch.schema_id=tab.schema_id " "where tab.name=:tabname " "and sch.name=:schname", - bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ], - typemap={'name': sqltypes.Unicode()} - ), - ) + bindparams=[ + sql.bindparam('tabname', tablename, + sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', owner, + sqltypes.String(convert_unicode=True)) + ], + typemap={'name': sqltypes.Unicode()} + ), + ) for row in rp: if row['index_id'] in indexes: indexes[row['index_id']]['column_names'].append(row['name']) @@ -1493,9 +1494,9 @@ class MSDialect(default.DefaultDialect): "views.name=:viewname and sch.name=:schname", bindparams=[ sql.bindparam('viewname', viewname, - sqltypes.String(convert_unicode=True)), + sqltypes.String(convert_unicode=True)), sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) + sqltypes.String(convert_unicode=True)) ] ) ) @@ -1515,7 +1516,7 @@ class MSDialect(default.DefaultDialect): else: whereclause = columns.c.table_name == tablename s = sql.select([columns], whereclause, - order_by=[columns.c.ordinal_position]) + order_by=[columns.c.ordinal_position]) c = connection.execute(s) cols = [] @@ -1585,7 +1586,7 @@ class MSDialect(default.DefaultDialect): ic = col_name colmap[col_name]['autoincrement'] = True colmap[col_name]['sequence'] = dict( - name='%s_identity' % col_name) + name='%s_identity' % col_name) break cursor.close() @@ -1594,7 +1595,7 @@ class MSDialect(default.DefaultDialect): cursor = connection.execute( "select ident_seed('%s'), ident_incr('%s')" % (table_fullname, table_fullname) - ) + ) row = cursor.first() if row is not None and row[0] is not None: @@ -1613,11 +1614,11 @@ class MSDialect(default.DefaultDialect): # Primary key constraints s = sql.select([C.c.column_name, TC.c.constraint_type, C.c.constraint_name], - sql.and_(TC.c.constraint_name == C.c.constraint_name, - TC.c.table_schema == C.c.table_schema, - C.c.table_name == tablename, - C.c.table_schema == owner) - ) + sql.and_(TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner) + ) c = connection.execute(s) constraint_name = None for row in c: @@ -1644,11 +1645,11 @@ class MSDialect(default.DefaultDialect): C.c.table_schema == owner, C.c.constraint_name == RR.c.constraint_name, R.c.constraint_name == - RR.c.unique_constraint_name, + RR.c.unique_constraint_name, C.c.ordinal_position == R.c.ordinal_position ), order_by=[RR.c.constraint_name, R.c.ordinal_position] - ) + ) # group rows by constraint ID, to handle multi-column FKs fkeys = [] @@ -1678,8 +1679,8 @@ class MSDialect(default.DefaultDialect): rec['referred_schema'] = rschema local_cols, remote_cols = \ - rec['constrained_columns'],\ - rec['referred_columns'] + rec['constrained_columns'],\ + rec['referred_columns'] local_cols.append(scol) remote_cols.append(rcol) diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 77251e61a..19d59387d 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -5,7 +5,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -# TODO: should be using the sys. catalog with SQL Server, not information schema +# TODO: should be using the sys. catalog with SQL Server, not information +# schema from ... import Table, MetaData, Column from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator @@ -16,6 +17,7 @@ from ...ext.compiler import compiles ischema = MetaData() + class CoerceUnicode(TypeDecorator): impl = Unicode @@ -27,10 +29,12 @@ class CoerceUnicode(TypeDecorator): def bind_expression(self, bindvalue): return _cast_on_2005(bindvalue) + class _cast_on_2005(expression.ColumnElement): def __init__(self, bindvalue): self.bindvalue = bindvalue + @compiles(_cast_on_2005) def _compile(element, compiler, **kw): from . import base @@ -40,76 +44,91 @@ def _compile(element, compiler, **kw): return compiler.process(cast(element.bindvalue, Unicode), **kw) schemata = Table("SCHEMATA", ischema, - Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), - Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), - Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), - schema="INFORMATION_SCHEMA") + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA") tables = Table("TABLES", ischema, - Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"), - schema="INFORMATION_SCHEMA") + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column( + "TABLE_TYPE", String(convert_unicode=True), key="table_type"), + schema="INFORMATION_SCHEMA") columns = Table("COLUMNS", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, key="column_name"), - Column("IS_NULLABLE", Integer, key="is_nullable"), - Column("DATA_TYPE", String, key="data_type"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), - Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), - Column("NUMERIC_SCALE", Integer, key="numeric_scale"), - Column("COLUMN_DEFAULT", Integer, key="column_default"), - Column("COLLATION_NAME", String, key="collation_name"), - schema="INFORMATION_SCHEMA") + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column("CHARACTER_MAXIMUM_LENGTH", Integer, + key="character_maximum_length"), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA") constraints = Table("TABLE_CONSTRAINTS", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), - Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"), - schema="INFORMATION_SCHEMA") + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column( + "CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", String( + convert_unicode=True), key="constraint_type"), + schema="INFORMATION_SCHEMA") column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, key="column_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), - schema="INFORMATION_SCHEMA") + Column( + "TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column( + "TABLE_NAME", CoerceUnicode, key="table_name"), + Column( + "COLUMN_NAME", CoerceUnicode, key="column_name"), + Column( + "CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA") key_constraints = Table("KEY_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, key="column_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - schema="INFORMATION_SCHEMA") + Column( + "TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column( + "COLUMN_NAME", CoerceUnicode, key="column_name"), + Column( + "CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column( + "ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA") ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, - Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), - Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), - Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), - # TODO: is CATLOG misspelled ? - Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, - key="unique_constraint_catalog"), - - Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, - key="unique_constraint_schema"), - Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, - key="unique_constraint_name"), - Column("MATCH_OPTION", String, key="match_option"), - Column("UPDATE_RULE", String, key="update_rule"), - Column("DELETE_RULE", String, key="delete_rule"), - schema="INFORMATION_SCHEMA") + Column( + "CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column( + "CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column( + "CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, + key="unique_constraint_catalog"), + + Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, + key="unique_constraint_schema"), + Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, + key="unique_constraint_name"), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA") views = Table("VIEWS", ischema, - Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), - Column("CHECK_OPTION", String, key="check_option"), - Column("IS_UPDATABLE", String, key="is_updatable"), - schema="INFORMATION_SCHEMA") + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA") diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index ad9e9c2ba..b6749aa2a 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -47,8 +47,8 @@ from ... import types as sqltypes from ...connectors.mxodbc import MxODBCConnector from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc from .base import (MSDialect, - MSSQLStrictCompiler, - _MSDateTime, _MSDate, _MSTime) + MSSQLStrictCompiler, + _MSDateTime, _MSDate, _MSTime) class _MSNumeric_mxodbc(_MSNumeric_pyodbc): @@ -82,7 +82,7 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): SELECT SCOPE_IDENTITY in cases where OUTPUT clause does not work (tables with insert triggers). """ - #todo - investigate whether the pyodbc execution context + # todo - investigate whether the pyodbc execution context # is really only being used in cases where OUTPUT # won't work. diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 4b7be1ac4..5e50b96ac 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -52,7 +52,7 @@ class MSDialect_pymssql(MSDialect): client_ver = tuple(int(x) for x in module.__version__.split(".")) if client_ver < (1, ): util.warn("The pymssql dialect expects at least " - "the 1.0 series of the pymssql DBAPI.") + "the 1.0 series of the pymssql DBAPI.") return module def __init__(self, **params): diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index ab45fa25e..86d896f8b 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -117,6 +117,7 @@ from ...connectors.pyodbc import PyODBCConnector from ... import types as sqltypes, util import decimal + class _ms_numeric_pyodbc(object): """Turns Decimals with adjusted() < 0 or > 7 into strings. @@ -129,7 +130,7 @@ class _ms_numeric_pyodbc(object): def bind_processor(self, dialect): super_process = super(_ms_numeric_pyodbc, self).\ - bind_processor(dialect) + bind_processor(dialect) if not dialect._need_decimal_fix: return super_process @@ -155,38 +156,41 @@ class _ms_numeric_pyodbc(object): def _small_dec_to_string(self, value): return "%s0.%s%s" % ( - (value < 0 and '-' or ''), - '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value.as_tuple()[1]])) + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]])) def _large_dec_to_string(self, value): _int = value.as_tuple()[1] if 'E' in str(value): result = "%s%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in _int]), - "0" * (value.adjusted() - (len(_int) - 1))) + (value < 0 and '-' or ''), + "".join([str(s) for s in _int]), + "0" * (value.adjusted() - (len(_int) - 1))) else: if (len(_int) - 1) > value.adjusted(): result = "%s%s.%s" % ( - (value < 0 and '-' or ''), - "".join( - [str(s) for s in _int][0:value.adjusted() + 1]), - "".join( - [str(s) for s in _int][value.adjusted() + 1:])) + (value < 0 and '-' or ''), + "".join( + [str(s) for s in _int][0:value.adjusted() + 1]), + "".join( + [str(s) for s in _int][value.adjusted() + 1:])) else: result = "%s%s" % ( - (value < 0 and '-' or ''), - "".join( - [str(s) for s in _int][0:value.adjusted() + 1])) + (value < 0 and '-' or ''), + "".join( + [str(s) for s in _int][0:value.adjusted() + 1])) return result + class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric): pass + class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float): pass + class MSExecutionContext_pyodbc(MSExecutionContext): _embedded_scope_identity = False @@ -253,9 +257,9 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): super(MSDialect_pyodbc, self).__init__(**params) self.description_encoding = description_encoding self.use_scope_identity = self.use_scope_identity and \ - self.dbapi and \ - hasattr(self.dbapi.Cursor, 'nextset') + self.dbapi and \ + hasattr(self.dbapi.Cursor, 'nextset') self._need_decimal_fix = self.dbapi and \ - self._dbapi_version() < (2, 1, 8) + self._dbapi_version() < (2, 1, 8) dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mssql/zxjdbc.py b/lib/sqlalchemy/dialects/mssql/zxjdbc.py index 5377be1ce..c14ebb70a 100644 --- a/lib/sqlalchemy/dialects/mssql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mssql/zxjdbc.py @@ -42,12 +42,12 @@ class MSExecutionContext_zxjdbc(MSExecutionContext): self._lastrowid = int(row[0]) if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: table = self.dialect.identifier_preparer.format_table( - self.compiled.statement.table) + self.compiled.statement.table) self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table) @@ -59,8 +59,8 @@ class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect): def _get_server_version_info(self, connection): return tuple( - int(x) - for x in connection.connection.dbversion.split('.') - ) + int(x) + for x in connection.connection.dbversion.split('.') + ) dialect = MSDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index a9dbd819e..32699abbd 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -6,8 +6,8 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from . import base, mysqldb, oursql, \ - pyodbc, zxjdbc, mysqlconnector, pymysql,\ - gaerdbms, cymysql + pyodbc, zxjdbc, mysqlconnector, pymysql,\ + gaerdbms, cymysql # default dialect base.dialect = mysqldb.dialect @@ -22,8 +22,8 @@ from .base import \ VARBINARY, VARCHAR, YEAR, dialect __all__ = ( -'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE', -'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', -'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP', -'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect' + 'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE', + 'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', + 'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP', + 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect' ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index ee5747e39..e2d411957 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -350,7 +350,7 @@ from ...engine import default from ... import types as sqltypes from ...util import topological from ...types import DATE, BOOLEAN, \ - BLOB, BINARY, VARBINARY + BLOB, BINARY, VARBINARY RESERVED_WORDS = set( ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc', @@ -397,9 +397,9 @@ RESERVED_WORDS = set( 'read_only', 'read_write', # 5.1 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue', - 'resignal', 'signal', 'slow', # 5.5 + 'resignal', 'signal', 'slow', # 5.5 - 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', + 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6 ]) @@ -427,7 +427,8 @@ class _NumericType(object): def __repr__(self): return util.generic_repr(self, - to_inspect=[_NumericType, sqltypes.Numeric]) + to_inspect=[_NumericType, sqltypes.Numeric]) + class _FloatType(_NumericType, sqltypes.Float): def __init__(self, precision=None, scale=None, asdecimal=True, **kw): @@ -435,16 +436,18 @@ class _FloatType(_NumericType, sqltypes.Float): ( (precision is None and scale is not None) or (precision is not None and scale is None) - ): + ): raise exc.ArgumentError( "You must specify both precision and scale or omit " "both altogether.") - super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) + super(_FloatType, self).__init__( + precision=precision, asdecimal=asdecimal, **kw) self.scale = scale def __repr__(self): return util.generic_repr(self, - to_inspect=[_FloatType, _NumericType, sqltypes.Float]) + to_inspect=[_FloatType, _NumericType, sqltypes.Float]) + class _IntegerType(_NumericType, sqltypes.Integer): def __init__(self, display_width=None, **kw): @@ -453,7 +456,8 @@ class _IntegerType(_NumericType, sqltypes.Integer): def __repr__(self): return util.generic_repr(self, - to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]) + to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]) + class _StringType(sqltypes.String): """Base for MySQL string types.""" @@ -474,7 +478,8 @@ class _StringType(sqltypes.String): def __repr__(self): return util.generic_repr(self, - to_inspect=[_StringType, sqltypes.String]) + to_inspect=[_StringType, sqltypes.String]) + class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" @@ -498,7 +503,7 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC): """ super(NUMERIC, self).__init__(precision=precision, - scale=scale, asdecimal=asdecimal, **kw) + scale=scale, asdecimal=asdecimal, **kw) class DECIMAL(_NumericType, sqltypes.DECIMAL): @@ -1075,11 +1080,12 @@ class CHAR(_StringType, sqltypes.CHAR): ascii=type_.ascii, binary=type_.binary, unicode=type_.unicode, - national=False # not supported in CAST + national=False # not supported in CAST ) else: return CHAR(length=type_.length) + class NVARCHAR(_StringType, sqltypes.NVARCHAR): """MySQL NVARCHAR type. @@ -1149,6 +1155,7 @@ class LONGBLOB(sqltypes._Binary): __visit_name__ = 'LONGBLOB' + class _EnumeratedValues(_StringType): def _init_values(self, values, kw): self.quoting = kw.pop('quoting', 'auto') @@ -1191,6 +1198,7 @@ class _EnumeratedValues(_StringType): strip_values.append(a) return strip_values + class ENUM(sqltypes.Enum, _EnumeratedValues): """MySQL ENUM type.""" @@ -1258,7 +1266,7 @@ class ENUM(sqltypes.Enum, _EnumeratedValues): def __repr__(self): return util.generic_repr(self, - to_inspect=[ENUM, _StringType, sqltypes.Enum]) + to_inspect=[ENUM, _StringType, sqltypes.Enum]) def bind_processor(self, dialect): super_convert = super(ENUM, self).bind_processor(dialect) @@ -1266,7 +1274,7 @@ class ENUM(sqltypes.Enum, _EnumeratedValues): def process(value): if self.strict and value is not None and value not in self.enums: raise exc.InvalidRequestError('"%s" not a valid value for ' - 'this enum' % value) + 'this enum' % value) if super_convert: return super_convert(value) else: @@ -1480,11 +1488,11 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % (self.process(binary.left), - self.process(binary.right)) + self.process(binary.right)) def visit_match_op_binary(self, binary, operator, **kw): return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \ - (self.process(binary.left), self.process(binary.right)) + (self.process(binary.left), self.process(binary.right)) def get_from_hint_text(self, table, text): return text @@ -1499,7 +1507,7 @@ class MySQLCompiler(compiler.SQLCompiler): elif isinstance(type_, sqltypes.TIMESTAMP): return 'DATETIME' elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, - sqltypes.Date, sqltypes.Time)): + sqltypes.Date, sqltypes.Time)): return self.dialect.type_compiler.process(type_) elif isinstance(type_, sqltypes.String) and not isinstance(type_, (ENUM, SET)): adapted = CHAR._adapt_string_for_cast(type_) @@ -1508,7 +1516,7 @@ class MySQLCompiler(compiler.SQLCompiler): return 'BINARY' elif isinstance(type_, sqltypes.NUMERIC): return self.dialect.type_compiler.process( - type_).replace('NUMERIC', 'DECIMAL') + type_).replace('NUMERIC', 'DECIMAL') else: return None @@ -1585,12 +1593,12 @@ class MySQLCompiler(compiler.SQLCompiler): # bound as part of MySQL's "syntax" for OFFSET with # no LIMIT return ' \n LIMIT %s, %s' % ( - self.process(offset_clause), - "18446744073709551615") + self.process(offset_clause), + "18446744073709551615") else: return ' \n LIMIT %s, %s' % ( - self.process(offset_clause), - self.process(limit_clause)) + self.process(offset_clause), + self.process(limit_clause)) else: # No offset provided, so just use the limit return ' \n LIMIT %s' % (self.process(limit_clause),) @@ -1604,10 +1612,10 @@ class MySQLCompiler(compiler.SQLCompiler): def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) - for t in [from_table] + list(extra_froms)) + for t in [from_table] + list(extra_froms)) def update_from_clause(self, update_stmt, from_table, - extra_froms, from_hints, **kw): + extra_froms, from_hints, **kw): return None @@ -1620,11 +1628,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler): def create_table_constraints(self, table): """Get table constraints.""" constraint_string = super( - MySQLDDLCompiler, self).create_table_constraints(table) + MySQLDDLCompiler, self).create_table_constraints(table) # why self.dialect.name and not 'mysql'? because of drizzle is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \ - table.dialect_options[self.dialect.name]['engine'].lower() == 'innodb' + table.dialect_options[self.dialect.name][ + 'engine'].lower() == 'innodb' auto_inc_column = table._autoincrement_column @@ -1634,11 +1643,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): if constraint_string: constraint_string += ", \n\t" constraint_string += "KEY %s (%s)" % ( - self.preparer.quote( - "idx_autoinc_%s" % auto_inc_column.name - ), - self.preparer.format_column(auto_inc_column) - ) + self.preparer.quote( + "idx_autoinc_%s" % auto_inc_column.name + ), + self.preparer.format_column(auto_inc_column) + ) return constraint_string @@ -1646,7 +1655,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): """Builds column DDL.""" colspec = [self.preparer.format_column(column), - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process(column.type) ] default = self.get_column_default_string(column) @@ -1661,7 +1670,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec.append('NULL') if column is column.table._autoincrement_column and \ - column.server_default is None: + column.server_default is None: colspec.append('AUTO_INCREMENT') return ' '.join(colspec) @@ -1709,8 +1718,8 @@ class MySQLDDLCompiler(compiler.DDLCompiler): preparer = self.preparer table = preparer.format_table(index.table) columns = [self.sql_compiler.process(expr, include_table=False, - literal_binds=True) - for expr in index.expressions] + literal_binds=True) + for expr in index.expressions] name = self._prepared_index_name(index) @@ -1763,9 +1772,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s ON %s" % ( - self._prepared_index_name(index, - include_schema=False), - self.preparer.format_table(index.table)) + self._prepared_index_name(index, + include_schema=False), + self.preparer.format_table(index.table)) def visit_drop_constraint(self, drop): constraint = drop.element @@ -1782,16 +1791,17 @@ class MySQLDDLCompiler(compiler.DDLCompiler): qual = "" const = self.preparer.format_constraint(constraint) return "ALTER TABLE %s DROP %s%s" % \ - (self.preparer.format_table(constraint.table), - qual, const) + (self.preparer.format_table(constraint.table), + qual, const) def define_constraint_match(self, constraint): if constraint.match is not None: raise exc.CompileError( - "MySQL ignores the 'MATCH' keyword while at the same time " - "causes ON UPDATE/ON DELETE clauses to be ignored.") + "MySQL ignores the 'MATCH' keyword while at the same time " + "causes ON UPDATE/ON DELETE clauses to be ignored.") return "" + class MySQLTypeCompiler(compiler.GenericTypeCompiler): def _extend_numeric(self, type_, spec): "Extend a numeric-type declaration with MySQL specific extensions." @@ -1845,78 +1855,78 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: return self._extend_numeric(type_, - "NUMERIC(%(precision)s)" % - {'precision': type_.precision}) + "NUMERIC(%(precision)s)" % + {'precision': type_.precision}) else: return self._extend_numeric(type_, - "NUMERIC(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + "NUMERIC(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) def visit_DECIMAL(self, type_): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: return self._extend_numeric(type_, - "DECIMAL(%(precision)s)" % - {'precision': type_.precision}) + "DECIMAL(%(precision)s)" % + {'precision': type_.precision}) else: return self._extend_numeric(type_, - "DECIMAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + "DECIMAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) def visit_DOUBLE(self, type_): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, - "DOUBLE(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + "DOUBLE(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) else: return self._extend_numeric(type_, 'DOUBLE') def visit_REAL(self, type_): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, - "REAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + "REAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) else: return self._extend_numeric(type_, 'REAL') def visit_FLOAT(self, type_): if self._mysql_type(type_) and \ - type_.scale is not None and \ - type_.precision is not None: + type_.scale is not None and \ + type_.precision is not None: return self._extend_numeric(type_, - "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + "FLOAT(%s, %s)" % (type_.precision, type_.scale)) elif type_.precision is not None: return self._extend_numeric(type_, - "FLOAT(%s)" % (type_.precision,)) + "FLOAT(%s)" % (type_.precision,)) else: return self._extend_numeric(type_, "FLOAT") def visit_INTEGER(self, type_): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, - "INTEGER(%(display_width)s)" % - {'display_width': type_.display_width}) + "INTEGER(%(display_width)s)" % + {'display_width': type_.display_width}) else: return self._extend_numeric(type_, "INTEGER") def visit_BIGINT(self, type_): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, - "BIGINT(%(display_width)s)" % - {'display_width': type_.display_width}) + "BIGINT(%(display_width)s)" % + {'display_width': type_.display_width}) else: return self._extend_numeric(type_, "BIGINT") def visit_MEDIUMINT(self, type_): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, - "MEDIUMINT(%(display_width)s)" % - {'display_width': type_.display_width}) + "MEDIUMINT(%(display_width)s)" % + {'display_width': type_.display_width}) else: return self._extend_numeric(type_, "MEDIUMINT") @@ -1930,9 +1940,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_SMALLINT(self, type_): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, - "SMALLINT(%(display_width)s)" % - {'display_width': type_.display_width} - ) + "SMALLINT(%(display_width)s)" % + {'display_width': type_.display_width} + ) else: return self._extend_numeric(type_, "SMALLINT") @@ -1989,13 +1999,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % + self.dialect.name) def visit_CHAR(self, type_): if type_.length: return self._extend_string(type_, {}, "CHAR(%(length)s)" % - {'length': type_.length}) + {'length': type_.length}) else: return self._extend_string(type_, {}, "CHAR") @@ -2004,18 +2014,18 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): # of "NVARCHAR". if type_.length: return self._extend_string(type_, {'national': True}, - "VARCHAR(%(length)s)" % {'length': type_.length}) + "VARCHAR(%(length)s)" % {'length': type_.length}) else: raise exc.CompileError( - "NVARCHAR requires a length on dialect %s" % - self.dialect.name) + "NVARCHAR requires a length on dialect %s" % + self.dialect.name) def visit_NCHAR(self, type_): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: return self._extend_string(type_, {'national': True}, - "CHAR(%(length)s)" % {'length': type_.length}) + "CHAR(%(length)s)" % {'length': type_.length}) else: return self._extend_string(type_, {'national': True}, "CHAR") @@ -2051,16 +2061,16 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): for e in enumerated_values: quoted_enums.append("'%s'" % e.replace("'", "''")) return self._extend_string(type_, {}, "%s(%s)" % ( - name, ",".join(quoted_enums)) - ) + name, ",".join(quoted_enums)) + ) def visit_ENUM(self, type_): return self._visit_enumerated_values("ENUM", type_, - type_._enumerated_values) + type_._enumerated_values) def visit_SET(self, type_): return self._visit_enumerated_values("SET", type_, - type_._enumerated_values) + type_._enumerated_values) def visit_BOOLEAN(self, type): return "BOOL" @@ -2077,9 +2087,9 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): quote = '"' super(MySQLIdentifierPreparer, self).__init__( - dialect, - initial_quote=quote, - escape_quote=quote) + dialect, + initial_quote=quote, + escape_quote=quote) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" @@ -2149,7 +2159,7 @@ class MySQLDialect(default.DefaultDialect): return None _isolation_lookup = set(['SERIALIZABLE', - 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) + 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) def set_isolation_level(self, connection, level): level = level.replace('_', ' ') @@ -2158,7 +2168,7 @@ class MySQLDialect(default.DefaultDialect): "Invalid value '%s' for isolation_level. " "Valid isolation levels for %s are %s" % (level, self.name, ", ".join(self._isolation_lookup)) - ) + ) cursor = connection.cursor() cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) cursor.execute("COMMIT") @@ -2228,7 +2238,7 @@ class MySQLDialect(default.DefaultDialect): def is_disconnect(self, e, connection, cursor): if isinstance(e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)): return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + (2006, 2013, 2014, 2045, 2055) elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, # this is the error you get @@ -2297,14 +2307,14 @@ class MySQLDialect(default.DefaultDialect): # if ansiquotes == True, build a new IdentifierPreparer # with the new setting self.identifier_preparer = self.preparer(self, - server_ansiquotes=self._server_ansiquotes) + server_ansiquotes=self._server_ansiquotes) default.DefaultDialect.initialize(self, connection) @property def _supports_cast(self): return self.server_version_info is None or \ - self.server_version_info >= (4, 0, 2) + self.server_version_info >= (4, 0, 2) @reflection.cache def get_schema_names(self, connection, **kw): @@ -2322,16 +2332,16 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset if self.server_version_info < (5, 0, 2): rp = connection.execute("SHOW TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) + self.identifier_preparer.quote_identifier(current_schema)) return [row[0] for - row in self._compat_fetchall(rp, charset=charset)] + row in self._compat_fetchall(rp, charset=charset)] else: rp = connection.execute("SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) + self.identifier_preparer.quote_identifier(current_schema)) return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] == 'BASE TABLE'] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == 'BASE TABLE'] @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -2343,28 +2353,28 @@ class MySQLDialect(default.DefaultDialect): return self.get_table_names(connection, schema) charset = self._connection_charset rp = connection.execute("SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(schema)) + self.identifier_preparer.quote_identifier(schema)) return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] in ('VIEW', 'SYSTEM VIEW')] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ('VIEW', 'SYSTEM VIEW')] @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw) return parsed_state.table_options @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw) return parsed_state.columns @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw) for key in parsed_state.keys: if key['type'] == 'PRIMARY': # There can be only one. @@ -2376,7 +2386,7 @@ class MySQLDialect(default.DefaultDialect): def get_foreign_keys(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw) default_schema = None fkeys = [] @@ -2416,7 +2426,7 @@ class MySQLDialect(default.DefaultDialect): def get_indexes(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw) indexes = [] for spec in parsed_state.keys: @@ -2466,13 +2476,13 @@ class MySQLDialect(default.DefaultDialect): return sql def _parsed_state_or_create(self, connection, table_name, - schema=None, **kw): + schema=None, **kw): return self._setup_parser( - connection, - table_name, - schema, - info_cache=kw.get('info_cache', None) - ) + connection, + table_name, + schema, + info_cache=kw.get('info_cache', None) + ) @util.memoized_property def _tabledef_parser(self): @@ -2519,7 +2529,7 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset row = self._compat_first(connection.execute( "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) + charset=charset) if not row: cs = 0 else: @@ -2554,7 +2564,7 @@ class MySQLDialect(default.DefaultDialect): row = self._compat_first( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=self._connection_charset) + charset=self._connection_charset) if not row: mode = '' @@ -2570,7 +2580,6 @@ class MySQLDialect(default.DefaultDialect): # as of MySQL 5.0.1 self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode - def _show_create_table(self, connection, table, charset=None, full_name=None): """Run SHOW CREATE TABLE for a ``Table``.""" @@ -2595,7 +2604,7 @@ class MySQLDialect(default.DefaultDialect): return sql def _describe_table(self, connection, table, charset=None, - full_name=None): + full_name=None): """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: @@ -2687,7 +2696,7 @@ class MySQLTableDefinitionParser(object): if m: spec = m.groupdict() spec['table'] = \ - self.preparer.unformat_identifiers(spec['table']) + self.preparer.unformat_identifiers(spec['table']) spec['local'] = [c[0] for c in self._parse_keyexprs(spec['local'])] spec['foreign'] = [c[0] @@ -2768,7 +2777,7 @@ class MySQLTableDefinitionParser(object): util.warn("Incomplete reflection of column definition %r" % line) name, type_, args, notnull = \ - spec['name'], spec['coltype'], spec['arg'], spec['notnull'] + spec['name'], spec['coltype'], spec['arg'], spec['notnull'] try: col_type = self.dialect.ischema_names[type_] @@ -2838,7 +2847,7 @@ class MySQLTableDefinitionParser(object): buffer = [] for row in columns: (name, col_type, nullable, default, extra) = \ - [row[i] for i in (0, 1, 2, 4, 5)] + [row[i] for i in (0, 1, 2, 4, 5)] line = [' '] line.append(self.preparer.quote_identifier(name)) @@ -2917,15 +2926,15 @@ class MySQLTableDefinitionParser(object): r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' r'(?P<coltype>\w+)' r'(?:\((?P<arg>(?:\d+|\d+,\d+|' - r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?' + r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?' r'(?: +(?P<unsigned>UNSIGNED))?' r'(?: +(?P<zerofill>ZEROFILL))?' r'(?: +CHARACTER SET +(?P<charset>[\w_]+))?' r'(?: +COLLATE +(?P<collate>[\w_]+))?' r'(?: +(?P<notnull>NOT NULL))?' r'(?: +DEFAULT +(?P<default>' - r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+' - r'(?: +ON UPDATE \w+)?)' + r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+' + r'(?: +ON UPDATE \w+)?)' r'))?' r'(?: +(?P<autoincr>AUTO_INCREMENT))?' r'(?: +COMMENT +(P<comment>(?:\x27\x27|[^\x27])+))?' @@ -2934,7 +2943,7 @@ class MySQLTableDefinitionParser(object): r'(?: +(?P<extra>.*))?' r',?$' % quotes - ) + ) # Fallback, try to parse as little as possible self._re_column_loose = _re_compile( @@ -2944,7 +2953,7 @@ class MySQLTableDefinitionParser(object): r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' r'.*?(?P<notnull>NOT NULL)?' % quotes - ) + ) # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) @@ -2960,7 +2969,7 @@ class MySQLTableDefinitionParser(object): r'(?: +WITH PARSER +(?P<parser>\S+))?' r',?$' % quotes - ) + ) # CONSTRAINT `name` FOREIGN KEY (`local_col`) # REFERENCES `remote` (`remote_col`) @@ -2982,7 +2991,7 @@ class MySQLTableDefinitionParser(object): r'(?: +ON DELETE (?P<ondelete>%(on)s))?' r'(?: +ON UPDATE (?P<onupdate>%(on)s))?' % kw - ) + ) # PARTITION # @@ -3006,7 +3015,7 @@ class MySQLTableDefinitionParser(object): self._add_option_regex('UNION', r'\([^\)]+\)') self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK') self._add_option_regex('RAID_TYPE', - r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') _optional_equals = r'(?:\s*(?:=\s*)|\s+)' @@ -3015,7 +3024,7 @@ class MySQLTableDefinitionParser(object): r"'(?P<val>(?:[^']|'')*?)'(?!')" % (re.escape(directive), self._optional_equals)) self._pr_options.append(_pr_compile(regex, lambda v: - v.replace("\\\\", "\\").replace("''", "'"))) + v.replace("\\\\", "\\").replace("''", "'"))) def _add_option_word(self, directive): regex = (r'(?P<directive>%s)%s' @@ -3033,7 +3042,6 @@ _options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY', 'PASSWORD', 'CONNECTION') - class _DecodingRowProxy(object): """Return unicode-decoded values based on type inspection. diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index c9f82a0bd..7bf8fac66 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -20,6 +20,7 @@ from .mysqldb import MySQLDialect_mysqldb from .base import (BIT, MySQLDialect) from ... import util + class _cymysqlBIT(BIT): def result_processor(self, dialect, coltype): """Convert a MySQL's 64 bit, variable length binary string to a long. @@ -74,7 +75,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.OperationalError): return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + (2006, 2013, 2014, 2045, 2055) elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, # this is the error you get diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py index 6f231198d..56a4af205 100644 --- a/lib/sqlalchemy/dialects/mysql/gaerdbms.py +++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py @@ -45,7 +45,7 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): # from django: # http://code.google.com/p/googleappengine/source/ # browse/trunk/python/google/storage/speckle/ - # python/django/backend/base.py#118 + # python/django/backend/base.py#118 # see also [ticket:2649] # see also http://stackoverflow.com/q/14224679/34549 from google.appengine.api import apiproxy_stub_map diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 91223e270..8d38ef4a0 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -16,8 +16,8 @@ """ from .base import (MySQLDialect, - MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer, - BIT) + MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer, + BIT) from ... import util @@ -31,7 +31,7 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): class MySQLCompiler_mysqlconnector(MySQLCompiler): def visit_mod_binary(self, binary, operator, **kw): return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + self.process(binary.right, **kw) def post_process_text(self, text): return text.replace('%', '%%') @@ -98,7 +98,8 @@ class MySQLDialect_mysqlconnector(MySQLDialect): if self.dbapi is not None: try: from mysql.connector.constants import ClientFlag - client_flags = opts.get('client_flags', ClientFlag.get_default()) + client_flags = opts.get( + 'client_flags', ClientFlag.get_default()) client_flags |= ClientFlag.FOUND_ROWS opts['client_flags'] = client_flags except: diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 8ee367a07..06f38fa93 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -42,7 +42,7 @@ It is strongly advised to use the latest version of MySQL-Python. """ from .base import (MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer) + MySQLCompiler, MySQLIdentifierPreparer) from .base import TEXT from ... import sql from ... import util @@ -58,14 +58,16 @@ class MySQLExecutionContext_mysqldb(MySQLExecutionContext): else: return self.cursor.rowcount + class MySQLCompiler_mysqldb(MySQLCompiler): def visit_mod_binary(self, binary, operator, **kw): return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + self.process(binary.right, **kw) def post_process_text(self, text): return text.replace('%', '%%') + class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): def _escape_identifier(self, value): @@ -86,7 +88,6 @@ class MySQLDialect_mysqldb(MySQLDialect): statement_compiler = MySQLCompiler_mysqldb preparer = MySQLIdentifierPreparer_mysqldb - @classmethod def dbapi(cls): return __import__('MySQLdb') @@ -102,23 +103,22 @@ class MySQLDialect_mysqldb(MySQLDialect): # specific issue w/ the utf8_bin collation and unicode returns has_utf8_bin = connection.scalar( - "show collation where %s = 'utf8' and %s = 'utf8_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation") - )) + "show collation where %s = 'utf8' and %s = 'utf8_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation") + )) if has_utf8_bin: additional_tests = [ sql.collate(sql.cast( - sql.literal_column( + sql.literal_column( "'test collated returns'"), - TEXT(charset='utf8')), "utf8_bin") + TEXT(charset='utf8')), "utf8_bin") ] else: additional_tests = [] return super(MySQLDialect_mysqldb, self)._check_unicode_returns( - connection, additional_tests) - + connection, additional_tests) def create_connect_args(self, url): opts = url.translate_connect_args(database='db', username='user', @@ -155,8 +155,8 @@ class MySQLDialect_mysqldb(MySQLDialect): if self.dbapi is not None: try: CLIENT_FLAGS = __import__( - self.dbapi.__name__ + '.constants.CLIENT' - ).constants.CLIENT + self.dbapi.__name__ + '.constants.CLIENT' + ).constants.CLIENT client_flag |= CLIENT_FLAGS.FOUND_ROWS except (AttributeError, ImportError): self.supports_sane_rowcount = False diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 12136514c..eba117c5f 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -95,9 +95,11 @@ class MySQLDialect_oursql(MySQLDialect): arg = connection.connection._escape_string(xid) else: charset = self._connection_charset - arg = connection.connection._escape_string(xid.encode(charset)).decode(charset) + arg = connection.connection._escape_string( + xid.encode(charset)).decode(charset) arg = "'%s'" % arg - connection.execution_options(_oursql_plain_query=True).execute(query % arg) + connection.execution_options( + _oursql_plain_query=True).execute(query % arg) # Because mysql is bad, these methods have to be # reimplemented to use _PlainQuery. Basically, some queries @@ -127,10 +129,10 @@ class MySQLDialect_oursql(MySQLDialect): # am i on a newer/older version of OurSQL ? def has_table(self, connection, table_name, schema=None): return MySQLDialect.has_table( - self, - connection.connect().execution_options(_oursql_plain_query=True), - table_name, - schema + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema ) def get_table_options(self, connection, table_name, schema=None, **kw): @@ -218,7 +220,7 @@ class MySQLDialect_oursql(MySQLDialect): ssl = {} for key in ['ssl_ca', 'ssl_key', 'ssl_cert', - 'ssl_capath', 'ssl_cipher']: + 'ssl_capath', 'ssl_cipher']: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index b05c22295..e3fdea753 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -25,6 +25,7 @@ the pymysql driver as well. from .mysqldb import MySQLDialect_mysqldb from ...util import py3k + class MySQLDialect_pymysql(MySQLDialect_mysqldb): driver = 'pymysql' @@ -32,7 +33,6 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): if py3k: supports_unicode_statements = True - @classmethod def dbapi(cls): return __import__('pymysql') diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 8b6821643..c45c673a0 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -67,7 +67,8 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): if opts.get(key, None): return opts[key] - util.warn("Could not detect the connection character set. Assuming latin1.") + util.warn( + "Could not detect the connection character set. Assuming latin1.") return 'latin1' def _extract_error_code(self, exception): diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py index 17e062770..81a4c1e13 100644 --- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -83,7 +83,8 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): if opts.get(key, None): return opts[key] - util.warn("Could not detect the connection character set. Assuming latin1.") + util.warn( + "Could not detect the connection character set. Assuming latin1.") return 'latin1' def _driver_kwargs(self): diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index 4e57e3cee..fd32f2235 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -17,8 +17,8 @@ from sqlalchemy.dialects.oracle.base import \ __all__ = ( -'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER', -'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', -'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL', -'VARCHAR2', 'NVARCHAR2', 'ROWID' + 'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER', + 'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', + 'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL', + 'VARCHAR2', 'NVARCHAR2', 'ROWID' ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 781fc601f..e872a3f9a 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -205,21 +205,21 @@ from sqlalchemy.sql import compiler, visitors, expression from sqlalchemy.sql import operators as sql_operators, functions as sql_functions from sqlalchemy import types as sqltypes, schema as sa_schema from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ - BLOB, CLOB, TIMESTAMP, FLOAT + BLOB, CLOB, TIMESTAMP, FLOAT RESERVED_WORDS = \ - set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN '\ - 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED '\ - 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE '\ - 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE '\ - 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES '\ - 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '\ - 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '\ - 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '\ + set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN ' + 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ' + 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ' + 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE ' + 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES ' + 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS ' + 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER ' + 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR ' 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' - 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) + 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) class RAW(sqltypes._Binary): @@ -244,7 +244,8 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): if asdecimal is None: asdecimal = bool(scale and scale > 0) - super(NUMBER, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal) + super(NUMBER, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal) def adapt(self, impltype): ret = super(NUMBER, self).adapt(impltype) @@ -267,7 +268,8 @@ class DOUBLE_PRECISION(sqltypes.Numeric): if asdecimal is None: asdecimal = False - super(DOUBLE_PRECISION, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal) + super(DOUBLE_PRECISION, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal) class BFILE(sqltypes.LargeBinary): @@ -277,6 +279,7 @@ class BFILE(sqltypes.LargeBinary): class LONG(sqltypes.Text): __visit_name__ = 'LONG' + class DATE(sqltypes.DateTime): """Provide the oracle DATE type. @@ -289,7 +292,6 @@ class DATE(sqltypes.DateTime): """ __visit_name__ = 'DATE' - def _compare_type_affinity(self, other): return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) @@ -298,8 +300,8 @@ class INTERVAL(sqltypes.TypeEngine): __visit_name__ = 'INTERVAL' def __init__(self, - day_precision=None, - second_precision=None): + day_precision=None, + second_precision=None): """Construct an INTERVAL. Note that only DAY TO SECOND intervals are currently supported. @@ -385,11 +387,11 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_INTERVAL(self, type_): return "INTERVAL DAY%s TO SECOND%s" % ( type_.day_precision is not None and - "(%d)" % type_.day_precision or - "", + "(%d)" % type_.day_precision or + "", type_.second_precision is not None and - "(%d)" % type_.second_precision or - "", + "(%d)" % type_.second_precision or + "", ) def visit_LONG(self, type_): @@ -483,7 +485,7 @@ class OracleCompiler(compiler.SQLCompiler): compound_keywords = util.update_copy( compiler.SQLCompiler.compound_keywords, { - expression.CompoundSelect.EXCEPT: 'MINUS' + expression.CompoundSelect.EXCEPT: 'MINUS' } ) @@ -504,7 +506,7 @@ class OracleCompiler(compiler.SQLCompiler): def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % (self.process(binary.left), - self.process(binary.right)) + self.process(binary.right)) def visit_true(self, expr, **kw): return '1' @@ -542,8 +544,7 @@ class OracleCompiler(compiler.SQLCompiler): else: right = join.right return self.process(join.left, **kwargs) + \ - ", " + self.process(right, **kwargs) - + ", " + self.process(right, **kwargs) def _get_nonansi_join_whereclause(self, froms): clauses = [] @@ -557,7 +558,7 @@ class OracleCompiler(compiler.SQLCompiler): elif join.right.is_derived_from(binary.right.table): binary.right = _OuterJoinColumn(binary.right) clauses.append(visitors.cloned_traverse(join.onclause, {}, - {'binary': visit_binary})) + {'binary': visit_binary})) else: clauses.append(join.onclause) @@ -587,13 +588,13 @@ class OracleCompiler(compiler.SQLCompiler): if asfrom or ashint: alias_name = isinstance(alias.name, expression._truncated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name + self._truncated_identifier("alias", alias.name) or alias.name if ashint: return alias_name elif asfrom: return self.process(alias.original, asfrom=asfrom, **kwargs) + \ - " " + self.preparer.format_alias(alias, alias_name) + " " + self.preparer.format_alias(alias, alias_name) else: return self.process(alias.original, **kwargs) @@ -607,12 +608,13 @@ class OracleCompiler(compiler.SQLCompiler): col_expr = column outparam = sql.outparam("ret_%d" % i, type_=column.type) self.binds[outparam.key] = outparam - binds.append(self.bindparam_string(self._truncate_bindparam(outparam))) + binds.append( + self.bindparam_string(self._truncate_bindparam(outparam))) columns.append(self.process(col_expr, within_columns_clause=False)) self.result_map[outparam.key] = ( outparam.key, (column, getattr(column, 'name', None), - getattr(column, 'key', None)), + getattr(column, 'key', None)), column.type ) @@ -630,7 +632,7 @@ class OracleCompiler(compiler.SQLCompiler): if not getattr(select, '_oracle_visit', None): if not self.dialect.use_ansi: froms = self._display_froms_for_select( - select, kwargs.get('asfrom', False)) + select, kwargs.get('asfrom', False)) whereclause = self._get_nonansi_join_whereclause(froms) if whereclause is not None: select = select.where(whereclause) @@ -659,8 +661,8 @@ class OracleCompiler(compiler.SQLCompiler): self.dialect.optimize_limits and \ select._simple_int_limit: limitselect = limitselect.prefix_with( - "/*+ FIRST_ROWS(%d) */" % - select._limit) + "/*+ FIRST_ROWS(%d) */" % + select._limit) limitselect._oracle_visit = True limitselect._is_wrapper = True @@ -680,7 +682,7 @@ class OracleCompiler(compiler.SQLCompiler): if offset_clause is not None: max_row = max_row + offset_clause limitselect.append_whereclause( - sql.literal_column("ROWNUM") <= max_row) + sql.literal_column("ROWNUM") <= max_row) # If needed, add the ora_rn, and wrap again with offset. if offset_clause is None: @@ -688,20 +690,20 @@ class OracleCompiler(compiler.SQLCompiler): select = limitselect else: limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) + sql.literal_column("ROWNUM").label("ora_rn")) limitselect._oracle_visit = True limitselect._is_wrapper = True offsetselect = sql.select( - [c for c in limitselect.c if c.key != 'ora_rn']) + [c for c in limitselect.c if c.key != 'ora_rn']) offsetselect._oracle_visit = True offsetselect._is_wrapper = True if not self.dialect.use_binds_for_limits: offset_clause = sql.literal_column( - "%d" % select._offset) + "%d" % select._offset) offsetselect.append_whereclause( - sql.literal_column("ora_rn") > offset_clause) + sql.literal_column("ora_rn") > offset_clause) offsetselect._for_update_arg = select._for_update_arg select = offsetselect @@ -720,9 +722,9 @@ class OracleCompiler(compiler.SQLCompiler): if select._for_update_arg.of: tmp += ' OF ' + ', '.join( - self.process(elem) for elem in - select._for_update_arg.of - ) + self.process(elem) for elem in + select._for_update_arg.of + ) if select._for_update_arg.nowait: tmp += " NOWAIT" @@ -738,18 +740,19 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += " ON DELETE %s" % constraint.ondelete # oracle has no ON UPDATE CASCADE - - # its only available via triggers http://asktom.oracle.com/tkyte/update_cascade/index.html + # its only available via triggers + # http://asktom.oracle.com/tkyte/update_cascade/index.html if constraint.onupdate is not None: util.warn( "Oracle does not contain native UPDATE CASCADE " - "functionality - onupdates will not be rendered for foreign keys. " - "Consider using deferrable=True, initially='deferred' or triggers.") + "functionality - onupdates will not be rendered for foreign keys. " + "Consider using deferrable=True, initially='deferred' or triggers.") return text def visit_create_index(self, create, **kw): return super(OracleDDLCompiler, self).\ - visit_create_index(create, include_schema=True) + visit_create_index(create, include_schema=True) class OracleIdentifierPreparer(compiler.IdentifierPreparer): @@ -773,8 +776,8 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): class OracleExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar("SELECT " + - self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval FROM DUAL", type_) + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", type_) class OracleDialect(default.DefaultDialect): @@ -811,10 +814,10 @@ class OracleDialect(default.DefaultDialect): ] def __init__(self, - use_ansi=True, - optimize_limits=False, - use_binds_for_limits=True, - **kwargs): + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=True, + **kwargs): default.DefaultDialect.__init__(self, **kwargs) self.use_ansi = use_ansi self.optimize_limits = optimize_limits @@ -823,9 +826,9 @@ class OracleDialect(default.DefaultDialect): def initialize(self, connection): super(OracleDialect, self).initialize(connection) self.implicit_returning = self.__dict__.get( - 'implicit_returning', - self.server_version_info > (10, ) - ) + 'implicit_returning', + self.server_version_info > (10, ) + ) if self._is_oracle_8: self.colspecs = self.colspecs.copy() @@ -835,7 +838,7 @@ class OracleDialect(default.DefaultDialect): @property def _is_oracle_8(self): return self.server_version_info and \ - self.server_version_info < (9, ) + self.server_version_info < (9, ) @property def _supports_char_length(self): @@ -874,7 +877,7 @@ class OracleDialect(default.DefaultDialect): if isinstance(name, str): name = name.decode(self.encoding) if name.upper() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + not self.identifier_preparer._requires_quotes(name.lower()): return name.lower() else: return name @@ -903,7 +906,7 @@ class OracleDialect(default.DefaultDialect): """ q = "SELECT owner, table_owner, table_name, db_link, "\ - "synonym_name FROM all_synonyms WHERE " + "synonym_name FROM all_synonyms WHERE " clauses = [] params = {} if desired_synonym: @@ -928,7 +931,8 @@ class OracleDialect(default.DefaultDialect): else: rows = result.fetchall() if len(rows) > 1: - raise AssertionError("There are multiple tables visible to the schema, you must specify owner") + raise AssertionError( + "There are multiple tables visible to the schema, you must specify owner") elif len(rows) == 1: row = rows[0] return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] @@ -941,10 +945,10 @@ class OracleDialect(default.DefaultDialect): if resolve_synonyms: actual_name, owner, dblink, synonym = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name) - ) + connection, + desired_owner=self.denormalize_name(schema), + desired_synonym=self.denormalize_name(table_name) + ) else: actual_name, owner, dblink, synonym = None, None, None, None if not actual_name: @@ -957,8 +961,8 @@ class OracleDialect(default.DefaultDialect): # will need to hear from more users if we are doing # the right thing here. See [ticket:2619] owner = connection.scalar( - sql.text("SELECT username FROM user_db_links " - "WHERE db_link=:link"), link=dblink) + sql.text("SELECT username FROM user_db_links " + "WHERE db_link=:link"), link=dblink) dblink = "@" + dblink elif not owner: owner = self.denormalize_name(schema or self.default_schema_name) @@ -1021,9 +1025,9 @@ class OracleDialect(default.DefaultDialect): params = {"table_name": table_name} text = "SELECT column_name, data_type, %(char_length_col)s, "\ - "data_precision, data_scale, "\ - "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "\ - "WHERE table_name = :table_name" + "data_precision, data_scale, "\ + "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "\ + "WHERE table_name = :table_name" if schema is not None: params['owner'] = schema text += " AND owner = :owner " @@ -1034,7 +1038,8 @@ class OracleDialect(default.DefaultDialect): for row in c: (colname, orig_colname, coltype, length, precision, scale, nullable, default) = \ - (self.normalize_name(row[0]), row[0], row[1], row[2], row[3], row[4], row[5] == 'Y', row[6]) + (self.normalize_name(row[0]), row[0], row[1], row[ + 2], row[3], row[4], row[5] == 'Y', row[6]) if coltype == 'NUMBER': coltype = NUMBER(precision, scale) @@ -1121,21 +1126,23 @@ class OracleDialect(default.DefaultDialect): for rset in rp: if rset.index_name != last_index_name: remove_if_primary_key(index) - index = dict(name=self.normalize_name(rset.index_name), column_names=[]) + index = dict( + name=self.normalize_name(rset.index_name), column_names=[]) indexes.append(index) index['unique'] = uniqueness.get(rset.uniqueness, False) # filter out Oracle SYS_NC names. could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): - index['column_names'].append(self.normalize_name(rset.column_name)) + index['column_names'].append( + self.normalize_name(rset.column_name)) last_index_name = rset.index_name remove_if_primary_key(index) return indexes @reflection.cache def _get_constraint_data(self, connection, table_name, schema=None, - dblink='', **kw): + dblink='', **kw): params = {'table_name': table_name} @@ -1185,8 +1192,8 @@ class OracleDialect(default.DefaultDialect): pkeys = [] constraint_name = None constraint_data = self._get_constraint_data(connection, table_name, - schema, dblink, - info_cache=kw.get('info_cache')) + schema, dblink, + info_cache=kw.get('info_cache')) for row in constraint_data: (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ @@ -1220,8 +1227,8 @@ class OracleDialect(default.DefaultDialect): info_cache=info_cache) constraint_data = self._get_constraint_data(connection, table_name, - schema, dblink, - info_cache=kw.get('info_cache')) + schema, dblink, + info_cache=kw.get('info_cache')) def fkey_rec(): return { @@ -1236,7 +1243,7 @@ class OracleDialect(default.DefaultDialect): for row in constraint_data: (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) if cons_type == 'R': if remote_table is None: @@ -1249,19 +1256,23 @@ class OracleDialect(default.DefaultDialect): rec = fkeys[cons_name] rec['name'] = cons_name - local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + local_cols, remote_cols = rec[ + 'constrained_columns'], rec['referred_columns'] if not rec['referred_table']: if resolve_synonyms: ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ - self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(remote_owner), - desired_table=self.denormalize_name(remote_table) - ) + self._resolve_synonym( + connection, + desired_owner=self.denormalize_name( + remote_owner), + desired_table=self.denormalize_name( + remote_table) + ) if ref_synonym: remote_table = self.normalize_name(ref_synonym) - remote_owner = self.normalize_name(ref_remote_owner) + remote_owner = self.normalize_name( + ref_remote_owner) rec['referred_table'] = remote_table diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index bb3c837cc..e9f5780d2 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -314,7 +314,7 @@ class _OracleNumeric(sqltypes.Numeric): if self.precision is None and self.scale is None: return processors.to_float elif not getattr(self, '_is_oracle_number', False) \ - and self.scale is not None: + and self.scale is not None: return processors.to_float else: return None @@ -322,7 +322,7 @@ class _OracleNumeric(sqltypes.Numeric): # cx_oracle 4 behavior, will assume # floats return super(_OracleNumeric, self).\ - result_processor(dialect, coltype) + result_processor(dialect, coltype) class _OracleDate(sqltypes.Date): @@ -392,6 +392,7 @@ class _OracleLong(oracle.LONG): def get_dbapi_type(self, dbapi): return dbapi.LONG_STRING + class _OracleString(_NativeUnicodeMixin, sqltypes.String): pass @@ -405,7 +406,8 @@ class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText): if lob_processor is None: return None - string_processor = sqltypes.UnicodeText.result_processor(self, dialect, coltype) + string_processor = sqltypes.UnicodeText.result_processor( + self, dialect, coltype) if string_processor is None: return lob_processor @@ -450,7 +452,7 @@ class OracleCompiler_cx_oracle(OracleCompiler): def bindparam_string(self, name, **kw): quote = getattr(name, 'quote', None) if quote is True or quote is not False and \ - self.preparer._bindparam_requires_quotes(name): + self.preparer._bindparam_requires_quotes(name): quoted_name = '"%s"' % name self._quoted_bind_names[name] = quoted_name return OracleCompiler.bindparam_string(self, quoted_name, **kw) @@ -470,12 +472,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): # here. so convert names in quoted_bind_names # to encoded as well. quoted_bind_names = \ - dict( - (fromname.encode(self.dialect.encoding), - toname.encode(self.dialect.encoding)) - for fromname, toname in - quoted_bind_names.items() - ) + dict( + (fromname.encode(self.dialect.encoding), + toname.encode(self.dialect.encoding)) + for fromname, toname in + quoted_bind_names.items() + ) for param in self.parameters: for fromname, toname in quoted_bind_names.items(): param[toname] = param[fromname] @@ -487,27 +489,27 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): # breaks for varchars self.set_input_sizes(quoted_bind_names, exclude_types=self.dialect.exclude_setinputsizes - ) + ) # if a single execute, check for outparams if len(self.compiled_parameters) == 1: for bindparam in self.compiled.binds.values(): if bindparam.isoutparam: dbtype = bindparam.type.dialect_impl(self.dialect).\ - get_dbapi_type(self.dialect.dbapi) + get_dbapi_type(self.dialect.dbapi) if not hasattr(self, 'out_parameters'): self.out_parameters = {} if dbtype is None: raise exc.InvalidRequestError( - "Cannot create out parameter for parameter " - "%r - its type %r is not supported by" - " cx_oracle" % - (bindparam.key, bindparam.type) - ) + "Cannot create out parameter for parameter " + "%r - its type %r is not supported by" + " cx_oracle" % + (bindparam.key, bindparam.type) + ) name = self.compiled.bind_names[bindparam] self.out_parameters[name] = self.cursor.var(dbtype) self.parameters[0][quoted_bind_names.get(name, name)] = \ - self.out_parameters[name] + self.out_parameters[name] def create_cursor(self): c = self._dbapi_connection.cursor() @@ -519,9 +521,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): def get_result_proxy(self): if hasattr(self, 'out_parameters') and self.compiled.returning: returning_params = dict( - (k, v.getvalue()) - for k, v in self.out_parameters.items() - ) + (k, v.getvalue()) + for k, v in self.out_parameters.items() + ) return ReturningResultProxy(self, returning_params) result = None @@ -543,20 +545,23 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if name in self.out_parameters: type = bind.type impl_type = type.dialect_impl(self.dialect) - dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi) + dbapi_type = impl_type.get_dbapi_type( + self.dialect.dbapi) result_processor = impl_type.\ - result_processor(self.dialect, - dbapi_type) + result_processor(self.dialect, + dbapi_type) if result_processor is not None: out_parameters[name] = \ - result_processor(self.out_parameters[name].getvalue()) + result_processor( + self.out_parameters[name].getvalue()) else: - out_parameters[name] = self.out_parameters[name].getvalue() + out_parameters[name] = self.out_parameters[ + name].getvalue() else: result.out_parameters = dict( (k, v.getvalue()) - for k, v in self.out_parameters.items() - ) + for k, v in self.out_parameters.items() + ) return result @@ -574,13 +579,14 @@ class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_or passed as Python unicode objects. """ + def __init__(self, *arg, **kw): OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw) self.statement = util.text_type(self.statement) def _execute_scalar(self, stmt): return super(OracleExecutionContext_cx_oracle_with_unicode, self).\ - _execute_scalar(util.text_type(stmt)) + _execute_scalar(util.text_type(stmt)) class ReturningResultProxy(_result.FullyBufferedResultProxy): @@ -599,7 +605,7 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy): def _buffer_rows(self): return collections.deque([tuple(self._returning_params["ret_%d" % i] - for i, c in enumerate(self._returning_params))]) + for i, c in enumerate(self._returning_params))]) class OracleDialect_cx_oracle(OracleDialect): @@ -610,7 +616,8 @@ class OracleDialect_cx_oracle(OracleDialect): colspecs = colspecs = { sqltypes.Numeric: _OracleNumeric, - sqltypes.Date: _OracleDate, # generic type, assume datetime.date is desired + # generic type, assume datetime.date is desired + sqltypes.Date: _OracleDate, sqltypes.LargeBinary: _OracleBinary, sqltypes.Boolean: oracle._OracleBoolean, sqltypes.Interval: _OracleInterval, @@ -637,50 +644,50 @@ class OracleDialect_cx_oracle(OracleDialect): execute_sequence_format = list def __init__(self, - auto_setinputsizes=True, - exclude_setinputsizes=("STRING", "UNICODE"), - auto_convert_lobs=True, - threaded=True, - allow_twophase=True, - coerce_to_decimal=True, - coerce_to_unicode=False, - arraysize=50, **kwargs): + auto_setinputsizes=True, + exclude_setinputsizes=("STRING", "UNICODE"), + auto_convert_lobs=True, + threaded=True, + allow_twophase=True, + coerce_to_decimal=True, + coerce_to_unicode=False, + arraysize=50, **kwargs): OracleDialect.__init__(self, **kwargs) self.threaded = threaded self.arraysize = arraysize self.allow_twophase = allow_twophase self.supports_timestamp = self.dbapi is None or \ - hasattr(self.dbapi, 'TIMESTAMP') + hasattr(self.dbapi, 'TIMESTAMP') self.auto_setinputsizes = auto_setinputsizes self.auto_convert_lobs = auto_convert_lobs if hasattr(self.dbapi, 'version'): self.cx_oracle_ver = tuple([int(x) for x in - self.dbapi.version.split('.')]) + self.dbapi.version.split('.')]) else: self.cx_oracle_ver = (0, 0, 0) def types(*names): return set( - getattr(self.dbapi, name, None) for name in names - ).difference([None]) + getattr(self.dbapi, name, None) for name in names + ).difference([None]) self.exclude_setinputsizes = types(*(exclude_setinputsizes or ())) self._cx_oracle_string_types = types("STRING", "UNICODE", - "NCLOB", "CLOB") + "NCLOB", "CLOB") self._cx_oracle_unicode_types = types("UNICODE", "NCLOB") self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB") self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0) self.coerce_to_unicode = ( - self.cx_oracle_ver >= (5, 0) and - coerce_to_unicode - ) + self.cx_oracle_ver >= (5, 0) and + coerce_to_unicode + ) self.supports_native_decimal = ( - self.cx_oracle_ver >= (5, 0) and - coerce_to_decimal - ) + self.cx_oracle_ver >= (5, 0) and + coerce_to_decimal + ) self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0) @@ -708,15 +715,15 @@ class OracleDialect_cx_oracle(OracleDialect): "or otherwise be passed as Python unicode. " "Plain Python strings passed as bind parameters will be " "silently corrupted by cx_Oracle." - ) + ) self.execution_ctx_cls = \ - OracleExecutionContext_cx_oracle_with_unicode + OracleExecutionContext_cx_oracle_with_unicode else: self._cx_oracle_with_unicode = False if self.cx_oracle_ver is None or \ - not self.auto_convert_lobs or \ - not hasattr(self.dbapi, 'CLOB'): + not self.auto_convert_lobs or \ + not hasattr(self.dbapi, 'CLOB'): self.dbapi_type_map = {} else: # only use this for LOB objects. using it for strings, dates @@ -764,8 +771,8 @@ class OracleDialect_cx_oracle(OracleDialect): def output_type_handler(cursor, name, defaultType, size, precision, scale): return cursor.var( - cx_Oracle.STRING, - 255, arraysize=cursor.arraysize) + cx_Oracle.STRING, + 255, arraysize=cursor.arraysize) cursor = conn.cursor() cursor.outputtypehandler = output_type_handler @@ -796,17 +803,17 @@ class OracleDialect_cx_oracle(OracleDialect): cx_Oracle = self.dbapi def output_type_handler(cursor, name, defaultType, - size, precision, scale): + size, precision, scale): # convert all NUMBER with precision + positive scale to Decimal # this almost allows "native decimal" mode. if self.supports_native_decimal and \ defaultType == cx_Oracle.NUMBER and \ precision and scale > 0: return cursor.var( - cx_Oracle.STRING, - 255, - outconverter=self._to_decimal, - arraysize=cursor.arraysize) + cx_Oracle.STRING, + 255, + outconverter=self._to_decimal, + arraysize=cursor.arraysize) # if NUMBER with zero precision and 0 or neg scale, this appears # to indicate "ambiguous". Use a slower converter that will # make a decision based on each value received - the type @@ -816,10 +823,10 @@ class OracleDialect_cx_oracle(OracleDialect): defaultType == cx_Oracle.NUMBER \ and not precision and scale <= 0: return cursor.var( - cx_Oracle.STRING, - 255, - outconverter=self._detect_decimal, - arraysize=cursor.arraysize) + cx_Oracle.STRING, + 255, + outconverter=self._detect_decimal, + arraysize=cursor.arraysize) # allow all strings to come back natively as Unicode elif self.coerce_to_unicode and \ defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR): @@ -856,7 +863,7 @@ class OracleDialect_cx_oracle(OracleDialect): dsn=dsn, threaded=self.threaded, twophase=self.allow_twophase, - ) + ) if util.py2k: if self._cx_oracle_with_unicode: @@ -882,9 +889,9 @@ class OracleDialect_cx_oracle(OracleDialect): def _get_server_version_info(self, connection): return tuple( - int(x) - for x in connection.connection.version.split('.') - ) + int(x) + for x in connection.connection.version.split('.') + ) def is_disconnect(self, e, connection, cursor): error, = e.args @@ -924,11 +931,11 @@ class OracleDialect_cx_oracle(OracleDialect): connection.info['cx_oracle_prepared'] = result def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + recover=False): self.do_rollback(connection.connection) def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + recover=False): if not is_prepared: self.do_commit(connection.connection) else: diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index 19a668a3e..b3bae1ca0 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -40,7 +40,7 @@ class _ZxJDBCDate(sqltypes.Date): class _ZxJDBCNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): - #XXX: does the dialect return Decimal or not??? + # XXX: does the dialect return Decimal or not??? # if it does (in all cases), we could use a None processor as well as # the to_float generic processor if self.asdecimal: @@ -61,7 +61,8 @@ class _ZxJDBCNumeric(sqltypes.Numeric): class OracleCompiler_zxjdbc(OracleCompiler): def returning_clause(self, stmt, returning_cols): - self.returning_cols = list(expression._select_iterables(returning_cols)) + self.returning_cols = list( + expression._select_iterables(returning_cols)) # within_columns_clause=False so that labels (foo AS bar) don't render columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) @@ -72,12 +73,15 @@ class OracleCompiler_zxjdbc(OracleCompiler): binds = [] for i, col in enumerate(self.returning_cols): - dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + dbtype = col.type.dialect_impl( + self.dialect).get_dbapi_type(self.dialect.dbapi) self.returning_parameters.append((i + 1, dbtype)) - bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype)) + bindparam = sql.bindparam( + "ret_%d" % i, value=ReturningParam(dbtype)) self.binds[bindparam.key] = bindparam - binds.append(self.bindparam_string(self._truncate_bindparam(bindparam))) + binds.append( + self.bindparam_string(self._truncate_bindparam(bindparam))) return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) @@ -98,7 +102,8 @@ class OracleExecutionContext_zxjdbc(OracleExecutionContext): rrs = self.statement.__statement__.getReturnResultSet() next(rrs) except SQLException as sqle: - msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode()) + msg = '%s [SQLCode: %d]' % ( + sqle.getMessage(), sqle.getErrorCode()) if sqle.getSQLState() is not None: msg += ' [SQLState: %s]' % sqle.getSQLState() raise zxJDBC.Error(msg) @@ -213,7 +218,8 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database) def _get_server_version_info(self, connection): - version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1) + version = re.search( + r'Release ([\d\.]+)', connection.connection.dbversion).group(1) return tuple(int(x) for x in version.split('.')) dialect = OracleDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/postgres.py b/lib/sqlalchemy/dialects/postgres.py index 046be760d..f9725b2a5 100644 --- a/lib/sqlalchemy/dialects/postgres.py +++ b/lib/sqlalchemy/dialects/postgres.py @@ -11,7 +11,7 @@ from sqlalchemy.util import warn_deprecated warn_deprecated( "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. " "The new URL format is postgresql[+driver]://<user>:<pass>@<host>/<dbname>" - ) +) from sqlalchemy.dialects.postgresql import * from sqlalchemy.dialects.postgresql import base diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 2ae71c2a7..c033a792d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -402,26 +402,26 @@ except ImportError: _python_UUID = None from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ - CHAR, TEXT, FLOAT, NUMERIC, \ - DATE, BOOLEAN, REAL + CHAR, TEXT, FLOAT, NUMERIC, \ + DATE, BOOLEAN, REAL RESERVED_WORDS = set( ["all", "analyse", "analyze", "and", "any", "array", "as", "asc", - "asymmetric", "both", "case", "cast", "check", "collate", "column", - "constraint", "create", "current_catalog", "current_date", - "current_role", "current_time", "current_timestamp", "current_user", - "default", "deferrable", "desc", "distinct", "do", "else", "end", - "except", "false", "fetch", "for", "foreign", "from", "grant", "group", - "having", "in", "initially", "intersect", "into", "leading", "limit", - "localtime", "localtimestamp", "new", "not", "null", "of", "off", "offset", - "old", "on", "only", "or", "order", "placing", "primary", "references", - "returning", "select", "session_user", "some", "symmetric", "table", - "then", "to", "trailing", "true", "union", "unique", "user", "using", - "variadic", "when", "where", "window", "with", "authorization", - "between", "binary", "cross", "current_schema", "freeze", "full", - "ilike", "inner", "is", "isnull", "join", "left", "like", "natural", - "notnull", "outer", "over", "overlaps", "right", "similar", "verbose" - ]) + "asymmetric", "both", "case", "cast", "check", "collate", "column", + "constraint", "create", "current_catalog", "current_date", + "current_role", "current_time", "current_timestamp", "current_user", + "default", "deferrable", "desc", "distinct", "do", "else", "end", + "except", "false", "fetch", "for", "foreign", "from", "grant", "group", + "having", "in", "initially", "intersect", "into", "leading", "limit", + "localtime", "localtimestamp", "new", "not", "null", "of", "off", "offset", + "old", "on", "only", "or", "order", "placing", "primary", "references", + "returning", "select", "session_user", "some", "symmetric", "table", + "then", "to", "trailing", "true", "union", "unique", "user", "using", + "variadic", "when", "where", "window", "with", "authorization", + "between", "binary", "cross", "current_schema", "freeze", "full", + "ilike", "inner", "is", "isnull", "join", "left", "like", "natural", + "notnull", "outer", "over", "overlaps", "right", "similar", "verbose" + ]) _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) @@ -560,6 +560,7 @@ class UUID(sqltypes.TypeEngine): PGUuid = UUID + class TSVECTOR(sqltypes.TypeEngine): """The :class:`.postgresql.TSVECTOR` type implements the Postgresql text search type TSVECTOR. @@ -577,18 +578,17 @@ class TSVECTOR(sqltypes.TypeEngine): __visit_name__ = 'TSVECTOR' - class _Slice(expression.ColumnElement): __visit_name__ = 'slice' type = sqltypes.NULLTYPE def __init__(self, slice_, source_comparator): self.start = source_comparator._check_literal( - source_comparator.expr, - operators.getitem, slice_.start) + source_comparator.expr, + operators.getitem, slice_.start) self.stop = source_comparator._check_literal( - source_comparator.expr, - operators.getitem, slice_.stop) + source_comparator.expr, + operators.getitem, slice_.stop) class Any(expression.ColumnElement): @@ -673,7 +673,7 @@ class array(expression.Tuple): def _bind_param(self, operator, obj): return array(*[ expression.BindParameter(None, o, _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) + _compared_to_type=self.type, unique=True) for o in obj ]) @@ -775,7 +775,7 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): return_type = self.type.item_type return self._binary_operate(self.expr, operators.getitem, index, - result_type=return_type) + result_type=return_type) def any(self, other, operator=operators.eq): """Return ``other operator ANY (array)`` clause. @@ -902,7 +902,7 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): """ if isinstance(item_type, ARRAY): raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype") + "handles multi-dimensional arrays of basetype") if isinstance(item_type, type): item_type = item_type() self.item_type = item_type @@ -921,53 +921,53 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): if dim is None: arr = list(arr) if dim == 1 or dim is None and ( - # this has to be (list, tuple), or at least - # not hasattr('__iter__'), since Py3K strings - # etc. have __iter__ - not arr or not isinstance(arr[0], (list, tuple))): + # this has to be (list, tuple), or at least + # not hasattr('__iter__'), since Py3K strings + # etc. have __iter__ + not arr or not isinstance(arr[0], (list, tuple))): if itemproc: return collection(itemproc(x) for x in arr) else: return collection(arr) else: return collection( - self._proc_array( - x, itemproc, - dim - 1 if dim is not None else None, - collection) - for x in arr - ) + self._proc_array( + x, itemproc, + dim - 1 if dim is not None else None, + collection) + for x in arr + ) def bind_processor(self, dialect): item_proc = self.item_type.\ - dialect_impl(dialect).\ - bind_processor(dialect) + dialect_impl(dialect).\ + bind_processor(dialect) def process(value): if value is None: return value else: return self._proc_array( - value, - item_proc, - self.dimensions, - list) + value, + item_proc, + self.dimensions, + list) return process def result_processor(self, dialect, coltype): item_proc = self.item_type.\ - dialect_impl(dialect).\ - result_processor(dialect, coltype) + dialect_impl(dialect).\ + result_processor(dialect, coltype) def process(value): if value is None: return value else: return self._proc_array( - value, - item_proc, - self.dimensions, - tuple if self.as_tuple else list) + value, + item_proc, + self.dimensions, + tuple if self.as_tuple else list) return process PGArray = ARRAY @@ -1047,7 +1047,7 @@ class ENUM(sqltypes.Enum): return if not checkfirst or \ - not bind.dialect.has_type(bind, self.name, schema=self.schema): + not bind.dialect.has_type(bind, self.name, schema=self.schema): bind.execute(CreateEnumType(self)) def drop(self, bind=None, checkfirst=True): @@ -1069,7 +1069,7 @@ class ENUM(sqltypes.Enum): return if not checkfirst or \ - bind.dialect.has_type(bind, self.name, schema=self.schema): + bind.dialect.has_type(bind, self.name, schema=self.schema): bind.execute(DropEnumType(self)) def _check_for_name_in_memos(self, checkfirst, kw): @@ -1101,7 +1101,7 @@ class ENUM(sqltypes.Enum): def _on_metadata_create(self, target, bind, checkfirst, **kw): if self.metadata is not None and \ - not self._check_for_name_in_memos(checkfirst, kw): + not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) def _on_metadata_drop(self, target, bind, checkfirst, **kw): @@ -1145,7 +1145,7 @@ ischema_names = { 'interval': INTERVAL, 'interval year to month': INTERVAL, 'interval day to second': INTERVAL, - 'tsvector' : TSVECTOR + 'tsvector': TSVECTOR } @@ -1156,9 +1156,9 @@ class PGCompiler(compiler.SQLCompiler): def visit_slice(self, element, **kw): return "%s:%s" % ( - self.process(element.start, **kw), - self.process(element.stop, **kw), - ) + self.process(element.start, **kw), + self.process(element.stop, **kw), + ) def visit_any(self, element, **kw): return "%s%sANY (%s)" % ( @@ -1182,7 +1182,7 @@ class PGCompiler(compiler.SQLCompiler): def visit_match_op_binary(self, binary, operator, **kw): if "postgresql_regconfig" in binary.modifiers: - regconfig = self.render_literal_value(\ + regconfig = self.render_literal_value( binary.modifiers['postgresql_regconfig'], sqltypes.STRINGTYPE) if regconfig: @@ -1200,8 +1200,8 @@ class PGCompiler(compiler.SQLCompiler): escape = binary.modifiers.get("escape", None) return '%s ILIKE %s' % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + ( ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) @@ -1211,8 +1211,8 @@ class PGCompiler(compiler.SQLCompiler): def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT ILIKE %s' % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + ( ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) @@ -1266,12 +1266,12 @@ class PGCompiler(compiler.SQLCompiler): if select._for_update_arg.of: tables = util.OrderedSet( - c.table if isinstance(c, expression.ColumnClause) - else c for c in select._for_update_arg.of) + c.table if isinstance(c, expression.ColumnClause) + else c for c in select._for_update_arg.of) tmp += " OF " + ", ".join( - self.process(table, ashint=True) - for table in tables - ) + self.process(table, ashint=True) + for table in tables + ) if select._for_update_arg.nowait: tmp += " NOWAIT" @@ -1281,13 +1281,12 @@ class PGCompiler(compiler.SQLCompiler): def returning_clause(self, stmt, returning_cols): columns = [ - self._label_select_column(None, c, True, False, {}) - for c in expression._select_iterables(returning_cols) - ] + self._label_select_column(None, c, True, False, {}) + for c in expression._select_iterables(returning_cols) + ] return 'RETURNING ' + ', '.join(columns) - def visit_substring_func(self, func, **kw): s = self.process(func.clauses.clauses[0], **kw) start = self.process(func.clauses.clauses[1], **kw) @@ -1297,6 +1296,7 @@ class PGCompiler(compiler.SQLCompiler): else: return "SUBSTRING(%s FROM %s)" % (s, start) + class PGDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): @@ -1336,7 +1336,7 @@ class PGDDLCompiler(compiler.DDLCompiler): self.preparer.format_type(type_), ", ".join( self.sql_compiler.process(sql.literal(e), literal_binds=True) - for e in type_.enums) + for e in type_.enums) ) def visit_drop_enum_type(self, drop): @@ -1354,10 +1354,10 @@ class PGDDLCompiler(compiler.DDLCompiler): if index.unique: text += "UNIQUE " text += "INDEX %s ON %s " % ( - self._prepared_index_name(index, - include_schema=False), - preparer.format_table(index.table) - ) + self._prepared_index_name(index, + include_schema=False), + preparer.format_table(index.table) + ) using = index.dialect_options['postgresql']['using'] if using: @@ -1368,20 +1368,20 @@ class PGDDLCompiler(compiler.DDLCompiler): % ( ', '.join([ self.sql_compiler.process( - expr.self_group() - if not isinstance(expr, expression.ColumnClause) - else expr, - include_table=False, literal_binds=True) + + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, literal_binds=True) + (c.key in ops and (' ' + ops[c.key]) or '') for expr, c in zip(index.expressions, index.columns)]) - ) + ) whereclause = index.dialect_options["postgresql"]["where"] if whereclause is not None: where_compiled = self.sql_compiler.process( - whereclause, include_table=False, - literal_binds=True) + whereclause, include_table=False, + literal_binds=True) text += " WHERE " + where_compiled return text @@ -1393,12 +1393,13 @@ class PGDDLCompiler(compiler.DDLCompiler): elements = [] for c in constraint.columns: op = constraint.operators[c.name] - elements.append(self.preparer.quote(c.name) + ' WITH '+op) - text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements)) + elements.append(self.preparer.quote(c.name) + ' WITH ' + op) + text += "EXCLUDE USING %s (%s)" % (constraint.using, + ', '.join(elements)) if constraint.where is not None: text += ' WHERE (%s)' % self.sql_compiler.process( - constraint.where, - literal_binds=True) + constraint.where, + literal_binds=True) text += self.define_constraint_deferrability(constraint) return text @@ -1510,8 +1511,8 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_ARRAY(self, type_): return self.process(type_.item_type) + ('[]' * (type_.dimensions - if type_.dimensions - is not None else 1)) + if type_.dimensions + is not None else 1)) class PGIdentifierPreparer(compiler.IdentifierPreparer): @@ -1521,7 +1522,7 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): def _unquote_identifier(self, value): if value[0] == self.initial_quote: value = value[1:-1].\ - replace(self.escape_to_quote, self.escape_quote) + replace(self.escape_to_quote, self.escape_quote) return value def format_type(self, type_, use_schema=True): @@ -1556,8 +1557,8 @@ class DropEnumType(schema._CreateDropBase): class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): - return self._execute_scalar(("select nextval('%s')" % \ - self.dialect.identifier_preparer.format_sequence(seq)), type_) + return self._execute_scalar(("select nextval('%s')" % + self.dialect.identifier_preparer.format_sequence(seq)), type_) def get_insert_default(self, column): if column.primary_key and column is column.table._autoincrement_column: @@ -1565,11 +1566,11 @@ class PGExecutionContext(default.DefaultExecutionContext): # pre-execute passive defaults on primary key columns return self._execute_scalar("select %s" % - column.server_default.arg, column.type) + column.server_default.arg, column.type) elif (column.default is None or - (column.default.is_sequence and - column.default.optional)): + (column.default.is_sequence and + column.default.optional)): # execute the sequence associated with a SERIAL primary # key column. for non-primary-key SERIAL, the ID just @@ -1588,10 +1589,10 @@ class PGExecutionContext(default.DefaultExecutionContext): sch = column.table.schema if sch is not None: exc = "select nextval('\"%s\".\"%s\"')" % \ - (sch, seq_name) + (sch, seq_name) else: exc = "select nextval('\"%s\"')" % \ - (seq_name, ) + (seq_name, ) return self._execute_scalar(exc, column.type) @@ -1644,7 +1645,7 @@ class PGDialect(default.DefaultDialect): _backslash_escapes = True def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): + json_deserializer=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_deserializer = json_deserializer @@ -1653,7 +1654,7 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) self.implicit_returning = self.server_version_info > (8, 2) and \ - self.__dict__.get('implicit_returning', True) + self.__dict__.get('implicit_returning', True) self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: self.colspecs = self.colspecs.copy() @@ -1666,9 +1667,9 @@ class PGDialect(default.DefaultDialect): self.supports_smallserial = self.server_version_info >= (9, 2) self._backslash_escapes = self.server_version_info < (8, 2) or \ - connection.scalar( - "show standard_conforming_strings" - ) == 'off' + connection.scalar( + "show standard_conforming_strings" + ) == 'off' def on_connect(self): if self.isolation_level is not None: @@ -1679,7 +1680,7 @@ class PGDialect(default.DefaultDialect): return None _isolation_lookup = set(['SERIALIZABLE', - 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) + 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) def set_isolation_level(self, connection, level): level = level.replace('_', ' ') @@ -1688,7 +1689,7 @@ class PGDialect(default.DefaultDialect): "Invalid value '%s' for isolation_level. " "Valid isolation levels for %s are %s" % (level, self.name, ", ".join(self._isolation_lookup)) - ) + ) cursor = connection.cursor() cursor.execute( "SET SESSION CHARACTERISTICS AS TRANSACTION " @@ -1710,10 +1711,10 @@ class PGDialect(default.DefaultDialect): connection.execute("PREPARE TRANSACTION '%s'" % xid) def do_rollback_twophase(self, connection, xid, - is_prepared=True, recover=False): + is_prepared=True, recover=False): if is_prepared: if recover: - #FIXME: ugly hack to get out of transaction + # FIXME: ugly hack to get out of transaction # context when committing recoverable transactions # Must find out a way how to make the dbapi not # open a transaction. @@ -1725,7 +1726,7 @@ class PGDialect(default.DefaultDialect): self.do_rollback(connection.connection) def do_commit_twophase(self, connection, xid, - is_prepared=True, recover=False): + is_prepared=True, recover=False): if is_prepared: if recover: connection.execute("ROLLBACK") @@ -1737,7 +1738,7 @@ class PGDialect(default.DefaultDialect): def do_recover_twophase(self, connection): resultset = connection.execute( - sql.text("SELECT gid FROM pg_prepared_xacts")) + sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] def _get_default_schema_name(self, connection): @@ -1762,25 +1763,25 @@ class PGDialect(default.DefaultDialect): if schema is None: cursor = connection.execute( sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=current_schema() and " - "relname=:name", - bindparams=[ + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=current_schema() and " + "relname=:name", + bindparams=[ sql.bindparam('name', util.text_type(table_name), - type_=sqltypes.Unicode)] + type_=sqltypes.Unicode)] ) ) else: cursor = connection.execute( sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=:schema and " - "relname=:name", + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and " + "relname=:name", bindparams=[ sql.bindparam('name', - util.text_type(table_name), type_=sqltypes.Unicode), + util.text_type(table_name), type_=sqltypes.Unicode), sql.bindparam('schema', - util.text_type(schema), type_=sqltypes.Unicode)] + util.text_type(schema), type_=sqltypes.Unicode)] ) ) return bool(cursor.first()) @@ -1795,23 +1796,23 @@ class PGDialect(default.DefaultDialect): "and relname=:name", bindparams=[ sql.bindparam('name', util.text_type(sequence_name), - type_=sqltypes.Unicode) + type_=sqltypes.Unicode) ] ) ) else: cursor = connection.execute( sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema and relname=:name", - bindparams=[ - sql.bindparam('name', util.text_type(sequence_name), - type_=sqltypes.Unicode), - sql.bindparam('schema', - util.text_type(schema), type_=sqltypes.Unicode) - ] - ) + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), type_=sqltypes.Unicode) + ] + ) ) return bool(cursor.first()) @@ -1837,14 +1838,14 @@ class PGDialect(default.DefaultDialect): """ query = sql.text(query) query = query.bindparams( - sql.bindparam('typname', - util.text_type(type_name), type_=sqltypes.Unicode), - ) + sql.bindparam('typname', + util.text_type(type_name), type_=sqltypes.Unicode), + ) if schema is not None: query = query.bindparams( - sql.bindparam('nspname', - util.text_type(schema), type_=sqltypes.Unicode), - ) + sql.bindparam('nspname', + util.text_type(schema), type_=sqltypes.Unicode), + ) cursor = connection.execute(query) return bool(cursor.scalar()) @@ -1856,7 +1857,7 @@ class PGDialect(default.DefaultDialect): v) if not m: raise AssertionError( - "Could not determine version from string '%s'" % v) + "Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) @reflection.cache @@ -1906,11 +1907,11 @@ class PGDialect(default.DefaultDialect): # what about system tables? if util.py2k: - schema_names = [row[0].decode(self.encoding) for row in rp \ - if not row[0].startswith('pg_')] + schema_names = [row[0].decode(self.encoding) for row in rp + if not row[0].startswith('pg_')] else: - schema_names = [row[0] for row in rp \ - if not row[0].startswith('pg_')] + schema_names = [row[0] for row in rp + if not row[0].startswith('pg_')] return schema_names @reflection.cache @@ -1922,12 +1923,12 @@ class PGDialect(default.DefaultDialect): result = connection.execute( sql.text("SELECT relname FROM pg_class c " - "WHERE relkind = 'r' " - "AND '%s' = (select nspname from pg_namespace n " - "where n.oid = c.relnamespace) " % - current_schema, - typemap={'relname': sqltypes.Unicode} - ) + "WHERE relkind = 'r' " + "AND '%s' = (select nspname from pg_namespace n " + "where n.oid = c.relnamespace) " % + current_schema, + typemap={'relname': sqltypes.Unicode} + ) ) return [row[0] for row in result] @@ -1947,7 +1948,7 @@ class PGDialect(default.DefaultDialect): if util.py2k: view_names = [row[0].decode(self.encoding) - for row in connection.execute(s)] + for row in connection.execute(s)] else: view_names = [row[0] for row in connection.execute(s)] return view_names @@ -1992,9 +1993,11 @@ class PGDialect(default.DefaultDialect): ORDER BY a.attnum """ s = sql.text(SQL_COLS, - bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], - typemap={'attname': sqltypes.Unicode, 'default': sqltypes.Unicode} - ) + bindparams=[ + sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={ + 'attname': sqltypes.Unicode, 'default': sqltypes.Unicode} + ) c = connection.execute(s, table_oid=table_oid) rows = c.fetchall() domains = self._load_domains(connection) @@ -2010,7 +2013,7 @@ class PGDialect(default.DefaultDialect): def _get_column_info(self, name, format_type, default, notnull, domains, enums, schema): - ## strip (*) from character varying(5), timestamp(5) + # strip (*) from character varying(5), timestamp(5) # with time zone, geometry(POLYGON), etc. attype = re.sub(r'\(.*\)', '', format_type) @@ -2058,7 +2061,7 @@ class PGDialect(default.DefaultDialect): else: args = () elif attype in ('interval', 'interval year to month', - 'interval day to second'): + 'interval day to second'): if charlen: kwargs['precision'] = int(charlen) args = () @@ -2113,8 +2116,8 @@ class PGDialect(default.DefaultDialect): # later be enhanced to obey quoting rules / # "quote schema" default = match.group(1) + \ - ('"%s"' % sch) + '.' + \ - match.group(2) + match.group(3) + ('"%s"' % sch) + '.' + \ + match.group(2) + match.group(3) column_info = dict(name=name, type=coltype, nullable=nullable, default=default, autoincrement=autoincrement) @@ -2170,7 +2173,7 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, - postgresql_ignore_search_path=False, **kw): + postgresql_ignore_search_path=False, **kw): preparer = self.identifier_preparer table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) @@ -2200,22 +2203,22 @@ class PGDialect(default.DefaultDialect): ) t = sql.text(FK_SQL, typemap={ - 'conname': sqltypes.Unicode, - 'condef': sqltypes.Unicode}) + 'conname': sqltypes.Unicode, + 'condef': sqltypes.Unicode}) c = connection.execute(t, table=table_oid) fkeys = [] for conname, condef, conschema in c.fetchall(): m = re.search(FK_REGEX, condef).groups() constrained_columns, referred_schema, \ - referred_table, referred_columns, \ - _, match, _, onupdate, _, ondelete, \ - deferrable, _, initially = m + referred_table, referred_columns, \ + _, match, _, onupdate, _, ondelete, \ + deferrable, _, initially = m if deferrable is not None: deferrable = True if deferrable == 'DEFERRABLE' else False constrained_columns = [preparer._unquote_identifier(x) - for x in re.split(r'\s*,\s*', constrained_columns)] + for x in re.split(r'\s*,\s*', constrained_columns)] if postgresql_ignore_search_path: # when ignoring search path, we use the actual schema @@ -2229,7 +2232,7 @@ class PGDialect(default.DefaultDialect): # pg_get_constraintdef(). If the schema is in the search # path, pg_get_constraintdef() will give us None. referred_schema = \ - preparer._unquote_identifier(referred_schema) + preparer._unquote_identifier(referred_schema) elif schema is not None and schema == conschema: # If the actual schema matches the schema of the table # we're reflecting, then we will use that. @@ -2237,7 +2240,7 @@ class PGDialect(default.DefaultDialect): referred_table = preparer._unquote_identifier(referred_table) referred_columns = [preparer._unquote_identifier(x) - for x in re.split(r'\s*,\s', referred_columns)] + for x in re.split(r'\s*,\s', referred_columns)] fkey_d = { 'name': conname, 'constrained_columns': constrained_columns, @@ -2264,9 +2267,9 @@ class PGDialect(default.DefaultDialect): # for now. # regards, tom lane" return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) - for ind in range(0, 10) - ) + "%s[%d] = %s" % (compare_to, ind, col) + for ind in range(0, 10) + ) else: return "%s = ANY(%s)" % (col, compare_to) @@ -2298,12 +2301,12 @@ class PGDialect(default.DefaultDialect): t.relname, i.relname """ % ( - # version 8.3 here was based on observing the - # cast does not work in PG 8.2.4, does work in 8.3.0. - # nothing in PG changelogs regarding this. - "::varchar" if self.server_version_info >= (8, 3) else "", - self._pg_index_any("a.attnum", "ix.indkey") - ) + # version 8.3 here was based on observing the + # cast does not work in PG 8.2.4, does work in 8.3.0. + # nothing in PG changelogs regarding this. + "::varchar" if self.server_version_info >= (8, 3) else "", + self._pg_index_any("a.attnum", "ix.indkey") + ) t = sql.text(IDX_SQL, typemap={'attname': sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) @@ -2317,16 +2320,16 @@ class PGDialect(default.DefaultDialect): if expr: if idx_name != sv_idx_name: util.warn( - "Skipped unsupported reflection of " - "expression-based index %s" - % idx_name) + "Skipped unsupported reflection of " + "expression-based index %s" + % idx_name) sv_idx_name = idx_name continue if prd and not idx_name == sv_idx_name: util.warn( - "Predicate of partial index %s ignored during reflection" - % idx_name) + "Predicate of partial index %s ignored during reflection" + % idx_name) sv_idx_name = idx_name index = indexes[idx_name] @@ -2382,7 +2385,7 @@ class PGDialect(default.DefaultDialect): if not self.supports_native_enum: return {} - ## Load data types for enums: + # Load data types for enums: SQL_ENUMS = """ SELECT t.typname as "name", -- no enum defaults in 8.4 at least @@ -2398,8 +2401,8 @@ class PGDialect(default.DefaultDialect): """ s = sql.text(SQL_ENUMS, typemap={ - 'attname': sqltypes.Unicode, - 'label': sqltypes.Unicode}) + 'attname': sqltypes.Unicode, + 'label': sqltypes.Unicode}) c = connection.execute(s) enums = {} @@ -2417,13 +2420,13 @@ class PGDialect(default.DefaultDialect): enums[name]['labels'].append(enum['label']) else: enums[name] = { - 'labels': [enum['label']], - } + 'labels': [enum['label']], + } return enums def _load_domains(self, connection): - ## Load data types for domains: + # Load data types for domains: SQL_DOMAINS = """ SELECT t.typname as "name", pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", @@ -2441,7 +2444,7 @@ class PGDialect(default.DefaultDialect): domains = {} for domain in c.fetchall(): - ## strip (30) from character varying(30) + # strip (30) from character varying(30) attype = re.search('([^\(]+)', domain['attype']).group(1) if domain['visible']: # 'visible' just means whether or not the domain is in a @@ -2453,9 +2456,9 @@ class PGDialect(default.DefaultDialect): name = "%s.%s" % (domain['schema'], domain['name']) domains[name] = { - 'attype': attype, - 'nullable': domain['nullable'], - 'default': domain['default'] - } + 'attype': attype, + 'nullable': domain['nullable'], + 'default': domain['default'] + } return domains diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/constraints.py index 2eed2fb36..02d7a8998 100644 --- a/lib/sqlalchemy/dialects/postgresql/constraints.py +++ b/lib/sqlalchemy/dialects/postgresql/constraints.py @@ -6,6 +6,7 @@ from sqlalchemy.schema import ColumnCollectionConstraint from sqlalchemy.sql import expression + class ExcludeConstraint(ColumnCollectionConstraint): """A table-level EXCLUDE constraint. @@ -52,7 +53,7 @@ class ExcludeConstraint(ColumnCollectionConstraint): name=kw.get('name'), deferrable=kw.get('deferrable'), initially=kw.get('initially') - ) + ) self.operators = {} for col_or_string, op in elements: name = getattr(col_or_string, 'name', col_or_string) @@ -60,15 +61,14 @@ class ExcludeConstraint(ColumnCollectionConstraint): self.using = kw.get('using', 'gist') where = kw.get('where') if where: - self.where = expression._literal_as_text(where) + self.where = expression._literal_as_text(where) def copy(self, **kw): elements = [(col, self.operators[col]) for col in self.columns.keys()] c = self.__class__(*elements, - name=self.name, - deferrable=self.deferrable, - initially=self.initially) + name=self.name, + deferrable=self.deferrable, + initially=self.initially) c.dispatch._update(self.dispatch) return c - diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index f1fb3d308..8db55d6bc 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -73,7 +73,8 @@ def _parse_hstore(hstore_str): if pair_match.group('value_null'): value = None else: - value = pair_match.group('value').replace(r'\"', '"').replace("\\\\", "\\") + value = pair_match.group('value').replace( + r'\"', '"').replace("\\\\", "\\") result[key] = value pos += pair_match.end() @@ -272,6 +273,7 @@ class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): def bind_processor(self, dialect): if util.py2k: encoding = dialect.encoding + def process(value): if isinstance(value, dict): return _serialize_hstore(value).encode(encoding) @@ -288,6 +290,7 @@ class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): def result_processor(self, dialect, coltype): if util.py2k: encoding = dialect.encoding + def process(value): if value is not None: return _parse_hstore(value.decode(encoding)) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 902d0a80d..6e0c5a4b1 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -31,20 +31,23 @@ class JSONElement(elements.BinaryExpression): and :attr:`.JSONElement.astext`. """ + def __init__(self, left, right, astext=False, opstring=None, result_type=None): self._astext = astext if opstring is None: if hasattr(right, '__iter__') and \ - not isinstance(right, util.string_types): + not isinstance(right, util.string_types): opstring = "#>" - right = "{%s}" % (", ".join(util.text_type(elem) for elem in right)) + right = "{%s}" % ( + ", ".join(util.text_type(elem) for elem in right)) else: opstring = "->" self._json_opstring = opstring operator = custom_op(opstring, precedence=5) right = left._check_literal(left, operator, right) - super(JSONElement, self).__init__(left, right, operator, type_=result_type) + super(JSONElement, self).__init__( + left, right, operator, type_=result_type) @property def astext(self): @@ -64,12 +67,12 @@ class JSONElement(elements.BinaryExpression): return self else: return JSONElement( - self.left, - self.right, - astext=True, - opstring=self._json_opstring + ">", - result_type=sqltypes.String(convert_unicode=True) - ) + self.left, + self.right, + astext=True, + opstring=self._json_opstring + ">", + result_type=sqltypes.String(convert_unicode=True) + ) def cast(self, type_): """Convert this :class:`.JSONElement` to apply both the 'astext' operator @@ -178,6 +181,7 @@ class JSON(sqltypes.TypeEngine): json_serializer = dialect._json_serializer or json.dumps if util.py2k: encoding = dialect.encoding + def process(value): return json_serializer(value).encode(encoding) else: @@ -189,6 +193,7 @@ class JSON(sqltypes.TypeEngine): json_deserializer = dialect._json_deserializer or json.loads if util.py2k: encoding = dialect.encoding + def process(value): return json_deserializer(value.decode(encoding)) else: @@ -200,7 +205,6 @@ class JSON(sqltypes.TypeEngine): ischema_names['json'] = JSON - class JSONB(JSON): """Represent the Postgresql JSONB type. @@ -280,7 +284,8 @@ class JSONB(JSON): return JSONElement(self.expr, other) def _adapt_expression(self, op, other_comparator): - # How does one do equality?? jsonb also has "=" eg. '[1,2,3]'::jsonb = '[1,2,3]'::jsonb + # How does one do equality?? jsonb also has "=" eg. + # '[1,2,3]'::jsonb = '[1,2,3]'::jsonb if isinstance(op, custom_op): if op.opstring in ['?', '?&', '?|', '@>', '<@']: return op, sqltypes.Boolean @@ -317,4 +322,4 @@ class JSONB(JSON): """ return self.expr.op('<@')(other) -ischema_names['jsonb'] = JSONB
\ No newline at end of file +ischema_names['jsonb'] = JSONB diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index dc5ed6e73..512f3e1b0 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -165,6 +165,6 @@ class PGDialect_pg8000(PGDialect): "Invalid value '%s' for isolation_level. " "Valid isolation levels for %s are %s or AUTOCOMMIT" % (level, self.name, ", ".join(self._isolation_lookup)) - ) + ) dialect = PGDialect_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 0ab4abb09..b7971e8de 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -209,9 +209,9 @@ from ...engine import result as _result from ...sql import expression from ... import types as sqltypes from .base import PGDialect, PGCompiler, \ - PGIdentifierPreparer, PGExecutionContext, \ - ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ - _INT_TYPES + PGIdentifierPreparer, PGExecutionContext, \ + ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ + _INT_TYPES from .hstore import HSTORE from .json import JSON @@ -227,14 +227,14 @@ class _PGNumeric(sqltypes.Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, + self._effective_decimal_return_scale) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype) else: if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 @@ -243,7 +243,7 @@ class _PGNumeric(sqltypes.Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype) class _PGEnum(ENUM): @@ -255,6 +255,7 @@ class _PGEnum(ENUM): self.convert_unicode = "force_nocheck" return super(_PGEnum, self).result_processor(dialect, coltype) + class _PGHStore(HSTORE): def bind_processor(self, dialect): if dialect._has_native_hstore: @@ -293,13 +294,13 @@ class PGExecutionContext_psycopg2(PGExecutionContext): if self.dialect.server_side_cursors: is_server_side = \ self.execution_options.get('stream_results', True) and ( - (self.compiled and isinstance(self.compiled.statement, expression.Selectable) \ - or \ - ( + (self.compiled and isinstance(self.compiled.statement, expression.Selectable) + or + ( (not self.compiled or - isinstance(self.compiled.statement, expression.TextClause)) + isinstance(self.compiled.statement, expression.TextClause)) and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) - ) + ) ) else: is_server_side = \ @@ -336,7 +337,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): def visit_mod_binary(self, binary, operator, **kw): return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + self.process(binary.right, **kw) def post_process_text(self, text): return text.replace('%', '%%') @@ -354,7 +355,8 @@ class PGDialect_psycopg2(PGDialect): supports_unicode_statements = False default_paramstyle = 'pyformat' - supports_sane_multi_rowcount = False # set to true based on psycopg2 version + # set to true based on psycopg2 version + supports_sane_multi_rowcount = False execution_ctx_cls = PGExecutionContext_psycopg2 statement_compiler = PGCompiler_psycopg2 preparer = PGIdentifierPreparer_psycopg2 @@ -375,9 +377,9 @@ class PGDialect_psycopg2(PGDialect): ) def __init__(self, server_side_cursors=False, use_native_unicode=True, - client_encoding=None, - use_native_hstore=True, - **kwargs): + client_encoding=None, + use_native_hstore=True, + **kwargs): PGDialect.__init__(self, **kwargs) self.server_side_cursors = server_side_cursors self.use_native_unicode = use_native_unicode @@ -386,18 +388,18 @@ class PGDialect_psycopg2(PGDialect): self.client_encoding = client_encoding if self.dbapi and hasattr(self.dbapi, '__version__'): m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', - self.dbapi.__version__) + self.dbapi.__version__) if m: self.psycopg2_version = tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + int(x) + for x in m.group(1, 2, 3) + if x is not None) def initialize(self, connection): super(PGDialect_psycopg2, self).initialize(connection) self._has_native_hstore = self.use_native_hstore and \ - self._hstore_oids(connection.connection) \ - is not None + self._hstore_oids(connection.connection) \ + is not None self._has_native_json = self.psycopg2_version >= (2, 5) # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 @@ -427,7 +429,7 @@ class PGDialect_psycopg2(PGDialect): "Invalid value '%s' for isolation_level. " "Valid isolation levels for %s are %s" % (level, self.name, ", ".join(self._isolation_lookup)) - ) + ) connection.set_isolation_level(level) @@ -458,16 +460,17 @@ class PGDialect_psycopg2(PGDialect): oid, array_oid = hstore_oids if util.py2k: extras.register_hstore(conn, oid=oid, - array_oid=array_oid, - unicode=True) + array_oid=array_oid, + unicode=True) else: extras.register_hstore(conn, oid=oid, - array_oid=array_oid) + array_oid=array_oid) fns.append(on_connect) if self.dbapi and self._json_deserializer: def on_connect(conn): - extras.register_default_json(conn, loads=self._json_deserializer) + extras.register_default_json( + conn, loads=self._json_deserializer) fns.append(on_connect) if fns: diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 31434743c..28f80d000 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -9,6 +9,7 @@ from ... import types as sqltypes __all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE') + class RangeOperators(object): """ This mixin provides functionality for the Range Operators @@ -94,6 +95,7 @@ class RangeOperators(object): """ return self.expr.op('+')(other) + class INT4RANGE(RangeOperators, sqltypes.TypeEngine): """Represent the Postgresql INT4RANGE type. @@ -105,6 +107,7 @@ class INT4RANGE(RangeOperators, sqltypes.TypeEngine): ischema_names['int4range'] = INT4RANGE + class INT8RANGE(RangeOperators, sqltypes.TypeEngine): """Represent the Postgresql INT8RANGE type. @@ -116,6 +119,7 @@ class INT8RANGE(RangeOperators, sqltypes.TypeEngine): ischema_names['int8range'] = INT8RANGE + class NUMRANGE(RangeOperators, sqltypes.TypeEngine): """Represent the Postgresql NUMRANGE type. @@ -127,6 +131,7 @@ class NUMRANGE(RangeOperators, sqltypes.TypeEngine): ischema_names['numrange'] = NUMRANGE + class DATERANGE(RangeOperators, sqltypes.TypeEngine): """Represent the Postgresql DATERANGE type. @@ -138,6 +143,7 @@ class DATERANGE(RangeOperators, sqltypes.TypeEngine): ischema_names['daterange'] = DATERANGE + class TSRANGE(RangeOperators, sqltypes.TypeEngine): """Represent the Postgresql TSRANGE type. @@ -149,6 +155,7 @@ class TSRANGE(RangeOperators, sqltypes.TypeEngine): ischema_names['tsrange'] = TSRANGE + class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): """Represent the Postgresql TSTZRANGE type. diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 8daada528..7455bf3fe 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -214,6 +214,7 @@ class _DateTimeMixin(object): def literal_processor(self, dialect): bp = self.bind_processor(dialect) + def process(value): return "'%s'" % bp(value) return process @@ -619,7 +620,7 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): 'temporary', 'then', 'to', 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual', 'when', 'where', - ]) + ]) def format_index(self, index, use_schema=True, name=None): """Prepare a quoted index and schema name.""" @@ -716,7 +717,7 @@ class SQLiteDialect(default.DefaultDialect): "Invalid value '%s' for isolation_level. " "Valid isolation levels for %s are %s" % (level, self.name, ", ".join(self._isolation_lookup)) - ) + ) cursor = connection.cursor() cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) cursor.close() @@ -918,9 +919,9 @@ class SQLiteDialect(default.DefaultDialect): coltype = coltype(*[int(a) for a in args]) except TypeError: util.warn( - "Could not instantiate type %s with " - "reflected arguments %s; using no arguments." % - (coltype, args)) + "Could not instantiate type %s with " + "reflected arguments %s; using no arguments." % + (coltype, args)) coltype = coltype() else: coltype = coltype() diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 51e5f0cdf..54bc19763 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -331,6 +331,6 @@ class SQLiteDialect_pysqlite(SQLiteDialect): def is_disconnect(self, e, connection, cursor): return isinstance(e, self.dbapi.ProgrammingError) and \ - "Cannot operate on a closed database." in str(e) + "Cannot operate on a closed database." in str(e) dialect = SQLiteDialect_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py index a9263dc3f..eb313592b 100644 --- a/lib/sqlalchemy/dialects/sybase/__init__.py +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -11,11 +11,11 @@ from sqlalchemy.dialects.sybase import base, pysybase, pyodbc base.dialect = pyodbc.dialect from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ - BIGINT, INT, INTEGER, SMALLINT, BINARY,\ - VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\ - IMAGE, BIT, MONEY, SMALLMONEY, TINYINT,\ - dialect + TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ + BIGINT, INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\ + IMAGE, BIT, MONEY, SMALLMONEY, TINYINT,\ + dialect __all__ = ( diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 38f665838..713405e1b 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -32,10 +32,10 @@ from sqlalchemy import schema as sa_schema from sqlalchemy import util, sql, exc from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ - BIGINT, INT, INTEGER, SMALLINT, BINARY,\ - VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ - UnicodeText, REAL + TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ + BIGINT, INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ + UnicodeText, REAL RESERVED_WORDS = set([ "add", "all", "alter", "and", @@ -94,7 +94,7 @@ RESERVED_WORDS = set([ "when", "where", "while", "window", "with", "with_cube", "with_lparen", "with_rollup", "within", "work", "writetext", - ]) +]) class _SybaseUnitypeMixin(object): @@ -225,7 +225,7 @@ ischema_names = { 'image': IMAGE, 'bit': BIT, -# not in documentation for ASE 15.7 + # not in documentation for ASE 15.7 'long varchar': TEXT, # TODO 'timestamp': TIMESTAMP, 'uniqueidentifier': UNIQUEIDENTIFIER, @@ -268,13 +268,13 @@ class SybaseExecutionContext(default.DefaultExecutionContext): if insert_has_sequence: self._enable_identity_insert = \ - seq_column.key in self.compiled_parameters[0] + seq_column.key in self.compiled_parameters[0] else: self._enable_identity_insert = False if self._enable_identity_insert: self.cursor.execute("SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)) + self.dialect.identifier_preparer.format_table(tbl)) if self.isddl: # TODO: to enhance this, we can detect "ddl in tran" on the @@ -282,15 +282,15 @@ class SybaseExecutionContext(default.DefaultExecutionContext): # include a note about that. if not self.should_autocommit: raise exc.InvalidRequestError( - "The Sybase dialect only supports " - "DDL in 'autocommit' mode at this time.") + "The Sybase dialect only supports " + "DDL in 'autocommit' mode at this time.") self.root_connection.engine.logger.info( - "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") + "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") self.set_ddl_autocommit( - self.root_connection.connection.connection, - True) + self.root_connection.connection.connection, + True) def post_exec(self): if self.isddl: @@ -298,10 +298,10 @@ class SybaseExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: self.cursor.execute( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. - format_table(self.compiled.statement.table) - ) + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table) + ) def get_lastrowid(self): cursor = self.create_cursor() @@ -317,10 +317,10 @@ class SybaseSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond' - }) + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond' + }) def get_select_precolumns(self, select): s = select._distinct and "DISTINCT " or "" @@ -328,9 +328,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler): # bind params for FIRST / TOP limit = select._limit if limit: - #if select._limit == 1: + # if select._limit == 1: #s += "FIRST " - #else: + # else: #s += "TOP %s " % (select._limit,) s += "TOP %s " % (limit,) offset = select._offset @@ -352,7 +352,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) return 'DATEPART("%s", %s)' % ( - field, self.process(extract.expr, **kw)) + field, self.process(extract.expr, **kw)) def visit_now_func(self, fn, **kw): return "GETDATE()" @@ -376,21 +376,21 @@ class SybaseSQLCompiler(compiler.SQLCompiler): class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process(column.type) if column.table is None: raise exc.CompileError( - "The Sybase dialect requires Table-bound " - "columns in order to generate DDL") + "The Sybase dialect requires Table-bound " + "columns in order to generate DDL") seq_col = column.table._autoincrement_column # install a IDENTITY Sequence if we have an implicit IDENTITY column if seq_col is column: sequence = isinstance(column.default, sa_schema.Sequence) \ - and column.default + and column.default if sequence: start, increment = sequence.start or 1, \ - sequence.increment or 1 + sequence.increment or 1 else: start, increment = 1, 1 if (start, increment) == (1, 1): @@ -416,8 +416,8 @@ class SybaseDDLCompiler(compiler.DDLCompiler): return "\nDROP INDEX %s.%s" % ( self.preparer.quote_identifier(index.table.name), self._prepared_index_name(drop.element, - include_schema=False) - ) + include_schema=False) + ) class SybaseIdentifierPreparer(compiler.IdentifierPreparer): @@ -447,14 +447,14 @@ class SybaseDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return connection.scalar( - text("SELECT user_name() as user_name", - typemap={'user_name': Unicode}) - ) + text("SELECT user_name() as user_name", + typemap={'user_name': Unicode}) + ) def initialize(self, connection): super(SybaseDialect, self).initialize(connection) if self.server_version_info is not None and\ - self.server_version_info < (15, ): + self.server_version_info < (15, ): self.max_identifier_length = 30 else: self.max_identifier_length = 255 @@ -520,14 +520,15 @@ class SybaseDialect(default.DefaultDialect): for (name, type_, nullable, autoincrement, default, precision, scale, length) in results: col_info = self._get_column_info(name, type_, bool(nullable), - bool(autoincrement), default, precision, scale, - length) + bool( + autoincrement), default, precision, scale, + length) columns.append(col_info) return columns def _get_column_info(self, name, type_, nullable, autoincrement, default, - precision, scale, length): + precision, scale, length): coltype = self.ischema_names.get(type_, None) @@ -544,8 +545,8 @@ class SybaseDialect(default.DefaultDialect): if coltype: coltype = coltype(*args, **kwargs) - #is this necessary - #if is_array: + # is this necessary + # if is_array: # coltype = ARRAY(coltype) else: util.warn("Did not recognize type '%s' of column '%s'" % @@ -643,12 +644,12 @@ class SybaseDialect(default.DefaultDialect): referred_columns.append(reftable_columns[r["refkey%i" % i]]) fk_info = { - "constrained_columns": constrained_columns, - "referred_schema": reftable["schema"], - "referred_table": reftable["name"], - "referred_columns": referred_columns, - "name": r["name"] - } + "constrained_columns": constrained_columns, + "referred_schema": reftable["schema"], + "referred_table": reftable["name"], + "referred_columns": referred_columns, + "name": r["name"] + } foreign_keys.append(fk_info) diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index 3b849a680..b4c139ea0 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -34,7 +34,7 @@ Currently *not* supported are:: """ from sqlalchemy.dialects.sybase.base import SybaseDialect,\ - SybaseExecutionContext + SybaseExecutionContext from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy import types as sqltypes, processors import decimal @@ -51,7 +51,7 @@ class _SybNumeric_pyodbc(sqltypes.Numeric): def bind_processor(self, dialect): super_process = super(_SybNumeric_pyodbc, self).\ - bind_processor(dialect) + bind_processor(dialect) def process(value): if self.asdecimal and \ diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index 678c146d3..a60a8fea2 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -22,7 +22,7 @@ kind at this time. from sqlalchemy import types as sqltypes, processors from sqlalchemy.dialects.sybase.base import SybaseDialect, \ - SybaseExecutionContext, SybaseSQLCompiler + SybaseExecutionContext, SybaseSQLCompiler class _SybNumeric(sqltypes.Numeric): @@ -62,8 +62,8 @@ class SybaseDialect_pysybase(SybaseDialect): statement_compiler = SybaseSQLCompiler_pysybase colspecs = { - sqltypes.Numeric: _SybNumeric, - sqltypes.Float: sqltypes.Float + sqltypes.Numeric: _SybNumeric, + sqltypes.Float: sqltypes.Float } @classmethod @@ -90,7 +90,7 @@ class SybaseDialect_pysybase(SybaseDialect): def is_disconnect(self, e, connection, cursor): if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): + self.dbapi.ProgrammingError)): msg = str(e) return ('Unable to complete network request to host' in msg or 'Invalid connection state' in msg or |