diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/oracle/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 1023 |
1 files changed, 606 insertions, 417 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index b5aea4386..944fe21c3 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -353,49 +353,63 @@ from sqlalchemy.sql import compiler, visitors, expression, util as sql_util from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql.elements import quoted_name from sqlalchemy import types as sqltypes, schema as sa_schema -from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ - BLOB, CLOB, TIMESTAMP, FLOAT, INTEGER +from sqlalchemy.types import ( + VARCHAR, + NVARCHAR, + CHAR, + BLOB, + CLOB, + TIMESTAMP, + FLOAT, + INTEGER, +) from itertools import groupby -RESERVED_WORDS = \ - set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN ' - 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ' - 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ' - 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE ' - 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES ' - 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS ' - 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER ' - 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR ' - 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) +RESERVED_WORDS = set( + "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " + "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED " + "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE " + "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE " + "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES " + "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS " + "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER " + "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR " + "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split() +) -NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' - 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) +NO_ARG_FNS = set( + "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() +) class RAW(sqltypes._Binary): - __visit_name__ = 'RAW' + __visit_name__ = "RAW" + + OracleRaw = RAW class NCLOB(sqltypes.Text): - __visit_name__ = 'NCLOB' + __visit_name__ = "NCLOB" class VARCHAR2(VARCHAR): - __visit_name__ = 'VARCHAR2' + __visit_name__ = "VARCHAR2" + NVARCHAR2 = NVARCHAR class NUMBER(sqltypes.Numeric, sqltypes.Integer): - __visit_name__ = 'NUMBER' + __visit_name__ = "NUMBER" def __init__(self, precision=None, scale=None, asdecimal=None): if asdecimal is None: asdecimal = bool(scale and scale > 0) super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal) + precision=precision, scale=scale, asdecimal=asdecimal + ) def adapt(self, impltype): ret = super(NUMBER, self).adapt(impltype) @@ -412,23 +426,23 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): class DOUBLE_PRECISION(sqltypes.Float): - __visit_name__ = 'DOUBLE_PRECISION' + __visit_name__ = "DOUBLE_PRECISION" class BINARY_DOUBLE(sqltypes.Float): - __visit_name__ = 'BINARY_DOUBLE' + __visit_name__ = "BINARY_DOUBLE" class BINARY_FLOAT(sqltypes.Float): - __visit_name__ = 'BINARY_FLOAT' + __visit_name__ = "BINARY_FLOAT" class BFILE(sqltypes.LargeBinary): - __visit_name__ = 'BFILE' + __visit_name__ = "BFILE" class LONG(sqltypes.Text): - __visit_name__ = 'LONG' + __visit_name__ = "LONG" class DATE(sqltypes.DateTime): @@ -441,18 +455,17 @@ class DATE(sqltypes.DateTime): .. versionadded:: 0.9.4 """ - __visit_name__ = 'DATE' + + __visit_name__ = "DATE" def _compare_type_affinity(self, other): return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) class INTERVAL(sqltypes.TypeEngine): - __visit_name__ = 'INTERVAL' + __visit_name__ = "INTERVAL" - def __init__(self, - day_precision=None, - second_precision=None): + def __init__(self, day_precision=None, second_precision=None): """Construct an INTERVAL. Note that only DAY TO SECOND intervals are currently supported. @@ -471,8 +484,10 @@ class INTERVAL(sqltypes.TypeEngine): @classmethod def _adapt_from_generic_interval(cls, interval): - return INTERVAL(day_precision=interval.day_precision, - second_precision=interval.second_precision) + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) @property def _type_affinity(self): @@ -485,38 +500,40 @@ class ROWID(sqltypes.TypeEngine): When used in a cast() or similar, generates ROWID. """ - __visit_name__ = 'ROWID' + + __visit_name__ = "ROWID" class _OracleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMBER + colspecs = { sqltypes.Boolean: _OracleBoolean, sqltypes.Interval: INTERVAL, - sqltypes.DateTime: DATE + sqltypes.DateTime: DATE, } ischema_names = { - 'VARCHAR2': VARCHAR, - 'NVARCHAR2': NVARCHAR, - 'CHAR': CHAR, - 'DATE': DATE, - 'NUMBER': NUMBER, - 'BLOB': BLOB, - 'BFILE': BFILE, - 'CLOB': CLOB, - 'NCLOB': NCLOB, - 'TIMESTAMP': TIMESTAMP, - 'TIMESTAMP WITH TIME ZONE': TIMESTAMP, - 'INTERVAL DAY TO SECOND': INTERVAL, - 'RAW': RAW, - 'FLOAT': FLOAT, - 'DOUBLE PRECISION': DOUBLE_PRECISION, - 'LONG': LONG, - 'BINARY_DOUBLE': BINARY_DOUBLE, - 'BINARY_FLOAT': BINARY_FLOAT + "VARCHAR2": VARCHAR, + "NVARCHAR2": NVARCHAR, + "CHAR": CHAR, + "DATE": DATE, + "NUMBER": NUMBER, + "BLOB": BLOB, + "BFILE": BFILE, + "CLOB": CLOB, + "NCLOB": NCLOB, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP WITH TIME ZONE": TIMESTAMP, + "INTERVAL DAY TO SECOND": INTERVAL, + "RAW": RAW, + "FLOAT": FLOAT, + "DOUBLE PRECISION": DOUBLE_PRECISION, + "LONG": LONG, + "BINARY_DOUBLE": BINARY_DOUBLE, + "BINARY_FLOAT": BINARY_FLOAT, } @@ -540,12 +557,12 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_INTERVAL(self, type_, **kw): return "INTERVAL DAY%s TO SECOND%s" % ( - type_.day_precision is not None and - "(%d)" % type_.day_precision or - "", - type_.second_precision is not None and - "(%d)" % type_.second_precision or - "", + type_.day_precision is not None + and "(%d)" % type_.day_precision + or "", + type_.second_precision is not None + and "(%d)" % type_.second_precision + or "", ) def visit_LONG(self, type_, **kw): @@ -569,52 +586,53 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_FLOAT(self, type_, **kw): # don't support conversion between decimal/binary # precision yet - kw['no_precision'] = True + kw["no_precision"] = True return self._generate_numeric(type_, "FLOAT", **kw) def visit_NUMBER(self, type_, **kw): return self._generate_numeric(type_, "NUMBER", **kw) def _generate_numeric( - self, type_, name, precision=None, - scale=None, no_precision=False, **kw): + self, type_, name, precision=None, scale=None, no_precision=False, **kw + ): if precision is None: precision = type_.precision if scale is None: - scale = getattr(type_, 'scale', None) + scale = getattr(type_, "scale", None) if no_precision or precision is None: return name elif scale is None: n = "%(name)s(%(precision)s)" - return n % {'name': name, 'precision': precision} + return n % {"name": name, "precision": precision} else: n = "%(name)s(%(precision)s, %(scale)s)" - return n % {'name': name, 'precision': precision, 'scale': scale} + return n % {"name": name, "precision": precision, "scale": scale} def visit_string(self, type_, **kw): return self.visit_VARCHAR2(type_, **kw) def visit_VARCHAR2(self, type_, **kw): - return self._visit_varchar(type_, '', '2') + return self._visit_varchar(type_, "", "2") def visit_NVARCHAR2(self, type_, **kw): - return self._visit_varchar(type_, 'N', '2') + return self._visit_varchar(type_, "N", "2") + visit_NVARCHAR = visit_NVARCHAR2 def visit_VARCHAR(self, type_, **kw): - return self._visit_varchar(type_, '', '') + return self._visit_varchar(type_, "", "") def _visit_varchar(self, type_, n, num): if not type_.length: - return "%(n)sVARCHAR%(two)s" % {'two': num, 'n': n} + return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n} elif not n and self.dialect._supports_char_length: varchar = "VARCHAR%(two)s(%(length)s CHAR)" - return varchar % {'length': type_.length, 'two': num} + return varchar % {"length": type_.length, "two": num} else: varchar = "%(n)sVARCHAR%(two)s(%(length)s)" - return varchar % {'length': type_.length, 'two': num, 'n': n} + return varchar % {"length": type_.length, "two": num, "n": n} def visit_text(self, type_, **kw): return self.visit_CLOB(type_, **kw) @@ -636,7 +654,7 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_RAW(self, type_, **kw): if type_.length: - return "RAW(%(length)s)" % {'length': type_.length} + return "RAW(%(length)s)" % {"length": type_.length} else: return "RAW" @@ -652,9 +670,7 @@ class OracleCompiler(compiler.SQLCompiler): compound_keywords = util.update_copy( compiler.SQLCompiler.compound_keywords, - { - expression.CompoundSelect.EXCEPT: 'MINUS' - } + {expression.CompoundSelect.EXCEPT: "MINUS"}, ) def __init__(self, *args, **kwargs): @@ -663,8 +679,10 @@ class OracleCompiler(compiler.SQLCompiler): super(OracleCompiler, self).__init__(*args, **kwargs) def visit_mod_binary(self, binary, operator, **kw): - return "mod(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" @@ -673,22 +691,22 @@ class OracleCompiler(compiler.SQLCompiler): return "LENGTH" + self.function_argspec(fn, **kw) def visit_match_op_binary(self, binary, operator, **kw): - return "CONTAINS (%s, %s)" % (self.process(binary.left), - self.process(binary.right)) + return "CONTAINS (%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def get_cte_preamble(self, recursive): return "WITH" def get_select_hint_text(self, byfroms): - return " ".join( - "/*+ %s */" % text for table, text in byfroms.items() - ) + return " ".join("/*+ %s */" % text for table, text in byfroms.items()) def function_argspec(self, fn, **kw): if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: @@ -709,13 +727,16 @@ class OracleCompiler(compiler.SQLCompiler): if self.dialect.use_ansi: return compiler.SQLCompiler.visit_join(self, join, **kwargs) else: - kwargs['asfrom'] = True + kwargs["asfrom"] = True if isinstance(join.right, expression.FromGrouping): right = join.right.element else: right = join.right - return self.process(join.left, **kwargs) + \ - ", " + self.process(right, **kwargs) + return ( + self.process(join.left, **kwargs) + + ", " + + self.process(right, **kwargs) + ) def _get_nonansi_join_whereclause(self, froms): clauses = [] @@ -727,14 +748,20 @@ class OracleCompiler(compiler.SQLCompiler): # the join condition in the WHERE clause" - that is, # unconditionally regardless of operator or the other side def visit_binary(binary): - if isinstance(binary.left, expression.ColumnClause) \ - and join.right.is_derived_from(binary.left.table): + if isinstance( + binary.left, expression.ColumnClause + ) and join.right.is_derived_from(binary.left.table): binary.left = _OuterJoinColumn(binary.left) - elif isinstance(binary.right, expression.ColumnClause) \ - and join.right.is_derived_from(binary.right.table): + elif isinstance( + binary.right, expression.ColumnClause + ) and join.right.is_derived_from(binary.right.table): binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.cloned_traverse( - join.onclause, {}, {'binary': visit_binary})) + + clauses.append( + visitors.cloned_traverse( + join.onclause, {}, {"binary": visit_binary} + ) + ) else: clauses.append(join.onclause) @@ -757,8 +784,9 @@ class OracleCompiler(compiler.SQLCompiler): return self.process(vc.column, **kw) + "(+)" def visit_sequence(self, seq, **kw): - return (self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval") + return ( + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + ) def get_render_as_alias_suffix(self, alias_name_text): """Oracle doesn't like ``FROM table AS alias``""" @@ -770,7 +798,8 @@ class OracleCompiler(compiler.SQLCompiler): binds = [] for i, column in enumerate( - expression._select_iterables(returning_cols)): + expression._select_iterables(returning_cols) + ): if column.type._has_column_expression: col_expr = column.type.column_expression(column) else: @@ -779,19 +808,22 @@ class OracleCompiler(compiler.SQLCompiler): outparam = sql.outparam("ret_%d" % i, type_=column.type) self.binds[outparam.key] = outparam binds.append( - self.bindparam_string(self._truncate_bindparam(outparam))) - columns.append( - self.process(col_expr, within_columns_clause=False)) + self.bindparam_string(self._truncate_bindparam(outparam)) + ) + columns.append(self.process(col_expr, within_columns_clause=False)) self._add_to_result_map( - getattr(col_expr, 'name', col_expr.anon_label), - getattr(col_expr, 'name', col_expr.anon_label), - (column, getattr(column, 'name', None), - getattr(column, 'key', None)), - column.type + getattr(col_expr, "name", col_expr.anon_label), + getattr(col_expr, "name", col_expr.anon_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, ) - return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a @@ -804,10 +836,11 @@ class OracleCompiler(compiler.SQLCompiler): so tries to wrap it in a subquery with ``rownum`` criterion. """ - if not getattr(select, '_oracle_visit', None): + if not getattr(select, "_oracle_visit", None): if not self.dialect.use_ansi: froms = self._display_froms_for_select( - select, kwargs.get('asfrom', False)) + select, kwargs.get("asfrom", False) + ) whereclause = self._get_nonansi_join_whereclause(froms) if whereclause is not None: select = select.where(whereclause) @@ -828,18 +861,20 @@ class OracleCompiler(compiler.SQLCompiler): # Outer select and "ROWNUM as ora_rn" can be dropped if # limit=0 - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._oracle_visit = True # Wrap the middle select and add the hint limitselect = sql.select([c for c in select.c]) - if limit_clause is not None and \ - self.dialect.optimize_limits and \ - select._simple_int_limit: + if ( + limit_clause is not None + and self.dialect.optimize_limits + and select._simple_int_limit + ): limitselect = limitselect.prefix_with( - "/*+ FIRST_ROWS(%d) */" % - select._limit) + "/*+ FIRST_ROWS(%d) */" % select._limit + ) limitselect._oracle_visit = True limitselect._is_wrapper = True @@ -855,8 +890,8 @@ class OracleCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(select) for_update.of = [ - adapter.traverse(elem) - for elem in for_update.of] + adapter.traverse(elem) for elem in for_update.of + ] # If needed, add the limiting clause if limit_clause is not None: @@ -873,7 +908,8 @@ class OracleCompiler(compiler.SQLCompiler): if offset_clause is not None: max_row = max_row + offset_clause limitselect.append_whereclause( - sql.literal_column("ROWNUM") <= max_row) + sql.literal_column("ROWNUM") <= max_row + ) # If needed, add the ora_rn, and wrap again with offset. if offset_clause is None: @@ -881,12 +917,14 @@ class OracleCompiler(compiler.SQLCompiler): select = limitselect else: limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) + sql.literal_column("ROWNUM").label("ora_rn") + ) limitselect._oracle_visit = True limitselect._is_wrapper = True offsetselect = sql.select( - [c for c in limitselect.c if c.key != 'ora_rn']) + [c for c in limitselect.c if c.key != "ora_rn"] + ) offsetselect._oracle_visit = True offsetselect._is_wrapper = True @@ -897,9 +935,11 @@ class OracleCompiler(compiler.SQLCompiler): if not self.dialect.use_binds_for_limits: offset_clause = sql.literal_column( - "%d" % select._offset) + "%d" % select._offset + ) offsetselect.append_whereclause( - sql.literal_column("ora_rn") > offset_clause) + sql.literal_column("ora_rn") > offset_clause + ) offsetselect._for_update_arg = for_update select = offsetselect @@ -910,18 +950,17 @@ class OracleCompiler(compiler.SQLCompiler): return "" def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM DUAL WHERE 1!=1' + return "SELECT 1 FROM DUAL WHERE 1!=1" def for_update_clause(self, select, **kw): if self.is_subquery(): return "" - tmp = ' FOR UPDATE' + tmp = " FOR UPDATE" if select._for_update_arg.of: - tmp += ' OF ' + ', '.join( - self.process(elem, **kw) for elem in - select._for_update_arg.of + tmp += " OF " + ", ".join( + self.process(elem, **kw) for elem in select._for_update_arg.of ) if select._for_update_arg.nowait: @@ -933,7 +972,6 @@ class OracleCompiler(compiler.SQLCompiler): class OracleDDLCompiler(compiler.DDLCompiler): - def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -947,7 +985,8 @@ class OracleDDLCompiler(compiler.DDLCompiler): "Oracle does not contain native UPDATE CASCADE " "functionality - onupdates will not be rendered for foreign " "keys. Consider using deferrable=True, initially='deferred' " - "or triggers.") + "or triggers." + ) return text @@ -958,75 +997,79 @@ class OracleDDLCompiler(compiler.DDLCompiler): text = "CREATE " if index.unique: text += "UNIQUE " - if index.dialect_options['oracle']['bitmap']: + if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), - ', '.join( + ", ".join( self.sql_compiler.process( - expr, - include_table=False, literal_binds=True) - for expr in index.expressions) + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), ) - if index.dialect_options['oracle']['compress'] is not False: - if index.dialect_options['oracle']['compress'] is True: + if index.dialect_options["oracle"]["compress"] is not False: + if index.dialect_options["oracle"]["compress"] is True: text += " COMPRESS" else: text += " COMPRESS %d" % ( - index.dialect_options['oracle']['compress'] + index.dialect_options["oracle"]["compress"] ) return text def post_create_table(self, table): table_opts = [] - opts = table.dialect_options['oracle'] + opts = table.dialect_options["oracle"] - if opts['on_commit']: - on_commit_options = opts['on_commit'].replace("_", " ").upper() - table_opts.append('\n ON COMMIT %s' % on_commit_options) + if opts["on_commit"]: + on_commit_options = opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) - if opts['compress']: - if opts['compress'] is True: + if opts["compress"]: + if opts["compress"] is True: table_opts.append("\n COMPRESS") else: - table_opts.append("\n COMPRESS FOR %s" % ( - opts['compress'] - )) + table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) - return ''.join(table_opts) + return "".join(table_opts) class OracleIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} - illegal_initial_characters = {str(dig) for dig in range(0, 10)} \ - .union(["_", "$"]) + illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union( + ["_", "$"] + ) def _bindparam_requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - ) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + ) def format_savepoint(self, savepoint): - name = savepoint.ident.lstrip('_') - return super( - OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + name = savepoint.ident.lstrip("_") + return super(OracleIdentifierPreparer, self).format_savepoint( + savepoint, name + ) class OracleExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( - "SELECT " + - self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval FROM DUAL", type_) + "SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", + type_, + ) class OracleDialect(default.DefaultDialect): - name = 'oracle' + name = "oracle" supports_alter = True supports_unicode_statements = False supports_unicode_binds = False @@ -1039,7 +1082,7 @@ class OracleDialect(default.DefaultDialect): sequences_optional = False postfetch_lastrowid = False - default_paramstyle = 'named' + default_paramstyle = "named" colspecs = colspecs ischema_names = ischema_names requires_name_normalize = True @@ -1054,29 +1097,27 @@ class OracleDialect(default.DefaultDialect): preparer = OracleIdentifierPreparer execution_ctx_cls = OracleExecutionContext - reflection_options = ('oracle_resolve_synonyms', ) + reflection_options = ("oracle_resolve_synonyms",) _use_nchar_for_unicode = False construct_arguments = [ - (sa_schema.Table, { - "resolve_synonyms": False, - "on_commit": None, - "compress": False - }), - (sa_schema.Index, { - "bitmap": False, - "compress": False - }) + ( + sa_schema.Table, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, + ), + (sa_schema.Index, {"bitmap": False, "compress": False}), ] - def __init__(self, - use_ansi=True, - optimize_limits=False, - use_binds_for_limits=True, - use_nchar_for_unicode=False, - exclude_tablespaces=('SYSTEM', 'SYSAUX', ), - **kwargs): + def __init__( + self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=True, + use_nchar_for_unicode=False, + exclude_tablespaces=("SYSTEM", "SYSAUX"), + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self._use_nchar_for_unicode = use_nchar_for_unicode self.use_ansi = use_ansi @@ -1087,8 +1128,7 @@ class OracleDialect(default.DefaultDialect): def initialize(self, connection): super(OracleDialect, self).initialize(connection) self.implicit_returning = self.__dict__.get( - 'implicit_returning', - self.server_version_info > (10, ) + "implicit_returning", self.server_version_info > (10,) ) if self._is_oracle_8: @@ -1098,18 +1138,15 @@ class OracleDialect(default.DefaultDialect): @property def _is_oracle_8(self): - return self.server_version_info and \ - self.server_version_info < (9, ) + return self.server_version_info and self.server_version_info < (9,) @property def _supports_table_compression(self): - return self.server_version_info and \ - self.server_version_info >= (10, 1, ) + return self.server_version_info and self.server_version_info >= (10, 1) @property def _supports_table_compress_for(self): - return self.server_version_info and \ - self.server_version_info >= (11, ) + return self.server_version_info and self.server_version_info >= (11,) @property def _supports_char_length(self): @@ -1123,31 +1160,38 @@ class OracleDialect(default.DefaultDialect): additional_tests = [ expression.cast( expression.literal_column("'test nvarchar2 returns'"), - sqltypes.NVARCHAR(60) - ), + sqltypes.NVARCHAR(60), + ) ] return super(OracleDialect, self)._check_unicode_returns( - connection, additional_tests) + connection, additional_tests + ) def has_table(self, connection, table_name, schema=None): if not schema: schema = self.default_schema_name cursor = connection.execute( - sql.text("SELECT table_name FROM all_tables " - "WHERE table_name = :name AND owner = :schema_name"), + sql.text( + "SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name" + ), name=self.denormalize_name(table_name), - schema_name=self.denormalize_name(schema)) + schema_name=self.denormalize_name(schema), + ) return cursor.first() is not None def has_sequence(self, connection, sequence_name, schema=None): if not schema: schema = self.default_schema_name cursor = connection.execute( - sql.text("SELECT sequence_name FROM all_sequences " - "WHERE sequence_name = :name AND " - "sequence_owner = :schema_name"), + sql.text( + "SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND " + "sequence_owner = :schema_name" + ), name=self.denormalize_name(sequence_name), - schema_name=self.denormalize_name(schema)) + schema_name=self.denormalize_name(schema), + ) return cursor.first() is not None def normalize_name(self, name): @@ -1156,8 +1200,9 @@ class OracleDialect(default.DefaultDialect): if util.py2k: if isinstance(name, str): name = name.decode(self.encoding) - if name.upper() == name and not \ - self.identifier_preparer._requires_quotes(name.lower()): + if name.upper() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.lower() elif name.lower() == name: return quoted_name(name, quote=True) @@ -1167,8 +1212,9 @@ class OracleDialect(default.DefaultDialect): def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and not \ - self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): name = name.upper() if util.py2k: if not self.supports_unicode_binds: @@ -1179,10 +1225,16 @@ class OracleDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return self.normalize_name( - connection.execute('SELECT USER FROM DUAL').scalar()) + connection.execute("SELECT USER FROM DUAL").scalar() + ) - def _resolve_synonym(self, connection, desired_owner=None, - desired_synonym=None, desired_table=None): + def _resolve_synonym( + self, + connection, + desired_owner=None, + desired_synonym=None, + desired_table=None, + ): """search for a local synonym matching the given desired owner/name. if desired_owner is None, attempts to locate a distinct owner. @@ -1191,19 +1243,21 @@ class OracleDialect(default.DefaultDialect): found. """ - q = "SELECT owner, table_owner, table_name, db_link, "\ + q = ( + "SELECT owner, table_owner, table_name, db_link, " "synonym_name FROM all_synonyms WHERE " + ) clauses = [] params = {} if desired_synonym: clauses.append("synonym_name = :synonym_name") - params['synonym_name'] = desired_synonym + params["synonym_name"] = desired_synonym if desired_owner: clauses.append("owner = :desired_owner") - params['desired_owner'] = desired_owner + params["desired_owner"] = desired_owner if desired_table: clauses.append("table_name = :tname") - params['tname'] = desired_table + params["tname"] = desired_table q += " AND ".join(clauses) @@ -1211,8 +1265,12 @@ class OracleDialect(default.DefaultDialect): if desired_owner: row = result.first() if row: - return (row['table_name'], row['table_owner'], - row['db_link'], row['synonym_name']) + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) else: return None, None, None, None else: @@ -1220,23 +1278,35 @@ class OracleDialect(default.DefaultDialect): if len(rows) > 1: raise AssertionError( "There are multiple tables visible to the schema, you " - "must specify owner") + "must specify owner" + ) elif len(rows) == 1: row = rows[0] - return (row['table_name'], row['table_owner'], - row['db_link'], row['synonym_name']) + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) else: return None, None, None, None @reflection.cache - def _prepare_reflection_args(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): + def _prepare_reflection_args( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): if resolve_synonyms: actual_name, owner, dblink, synonym = self._resolve_synonym( connection, desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name) + desired_synonym=self.denormalize_name(table_name), ) else: actual_name, owner, dblink, synonym = None, None, None, None @@ -1250,18 +1320,21 @@ class OracleDialect(default.DefaultDialect): # will need to hear from more users if we are doing # the right thing here. See [ticket:2619] owner = connection.scalar( - sql.text("SELECT username FROM user_db_links " - "WHERE db_link=:link"), link=dblink) + sql.text( + "SELECT username FROM user_db_links " "WHERE db_link=:link" + ), + link=dblink, + ) dblink = "@" + dblink elif not owner: owner = self.denormalize_name(schema or self.default_schema_name) - return (actual_name, owner, dblink or '', synonym) + return (actual_name, owner, dblink or "", synonym) @reflection.cache def get_schema_names(self, connection, **kw): s = "SELECT username FROM all_users ORDER BY username" - cursor = connection.execute(s,) + cursor = connection.execute(s) return [self.normalize_name(row[0]) for row in cursor] @reflection.cache @@ -1276,14 +1349,12 @@ class OracleDialect(default.DefaultDialect): if self.exclude_tablespaces: sql_str += ( "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " % ( - ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces]) - ) + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) ) sql_str += ( - "OWNER = :owner " - "AND IOT_NAME IS NULL " - "AND DURATION IS NULL") + "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + ) cursor = connection.execute(sql.text(sql_str), owner=schema) return [self.normalize_name(row[0]) for row in cursor] @@ -1296,14 +1367,14 @@ class OracleDialect(default.DefaultDialect): if self.exclude_tablespaces: sql_str += ( "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " % ( - ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces]) - ) + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) ) sql_str += ( "OWNER = :owner " "AND IOT_NAME IS NULL " - "AND DURATION IS NOT NULL") + "AND DURATION IS NOT NULL" + ) cursor = connection.execute(sql.text(sql_str), owner=schema) return [self.normalize_name(row[0]) for row in cursor] @@ -1319,14 +1390,18 @@ class OracleDialect(default.DefaultDialect): def get_table_options(self, connection, table_name, schema=None, **kw): options = {} - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) params = {"table_name": table_name} @@ -1336,14 +1411,16 @@ class OracleDialect(default.DefaultDialect): if self._supports_table_compress_for: columns.append("compress_for") - text = "SELECT %(columns)s "\ - "FROM ALL_TABLES%(dblink)s "\ + text = ( + "SELECT %(columns)s " + "FROM ALL_TABLES%(dblink)s " "WHERE table_name = :table_name" + ) if schema is not None: - params['owner'] = schema + params["owner"] = schema text += " AND owner = :owner " - text = text % {'dblink': dblink, 'columns': ", ".join(columns)} + text = text % {"dblink": dblink, "columns": ", ".join(columns)} result = connection.execute(sql.text(text), **params) @@ -1353,9 +1430,9 @@ class OracleDialect(default.DefaultDialect): if row: if "compression" in row and enabled.get(row.compression, False): if "compress_for" in row: - options['oracle_compress'] = row.compress_for + options["oracle_compress"] = row.compress_for else: - options['oracle_compress'] = True + options["oracle_compress"] = True return options @@ -1371,19 +1448,23 @@ class OracleDialect(default.DefaultDialect): """ - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) columns = [] if self._supports_char_length: - char_length_col = 'char_length' + char_length_col = "char_length" else: - char_length_col = 'data_length' + char_length_col = "data_length" params = {"table_name": table_name} text = """ @@ -1398,10 +1479,10 @@ class OracleDialect(default.DefaultDialect): WHERE col.table_name = :table_name """ if schema is not None: - params['owner'] = schema + params["owner"] = schema text += " AND col.owner = :owner " text += " ORDER BY col.column_id" - text = text % {'dblink': dblink, 'char_length_col': char_length_col} + text = text % {"dblink": dblink, "char_length_col": char_length_col} c = connection.execute(sql.text(text), **params) @@ -1412,54 +1493,67 @@ class OracleDialect(default.DefaultDialect): length = row[2] precision = row[3] scale = row[4] - nullable = row[5] == 'Y' + nullable = row[5] == "Y" default = row[6] comment = row[7] - if coltype == 'NUMBER': + if coltype == "NUMBER": if precision is None and scale == 0: coltype = INTEGER() else: coltype = NUMBER(precision, scale) - elif coltype == 'FLOAT': + elif coltype == "FLOAT": # TODO: support "precision" here as "binary_precision" coltype = FLOAT() - elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'): + elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR"): coltype = self.ischema_names.get(coltype)(length) - elif 'WITH TIME ZONE' in coltype: + elif "WITH TIME ZONE" in coltype: coltype = TIMESTAMP(timezone=True) else: - coltype = re.sub(r'\(\d+\)', '', coltype) + coltype = re.sub(r"\(\d+\)", "", coltype) try: coltype = self.ischema_names[coltype] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, colname)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (coltype, colname) + ) coltype = sqltypes.NULLTYPE cdict = { - 'name': colname, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': 'auto', - 'comment': comment, + "name": colname, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "comment": comment, } if orig_colname.lower() == orig_colname: - cdict['quote'] = True + cdict["quote"] = True columns.append(cdict) return columns @reflection.cache - def get_table_comment(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - - info_cache = kw.get('info_cache') - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_table_comment( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) COMMENT_SQL = """ SELECT comments @@ -1471,67 +1565,90 @@ class OracleDialect(default.DefaultDialect): return {"text": c.scalar()} @reflection.cache - def get_indexes(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - - info_cache = kw.get('info_cache') - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_indexes( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) indexes = [] - params = {'table_name': table_name} - text = \ - "SELECT a.index_name, a.column_name, "\ - "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\ - "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ - "\nALL_INDEXES%(dblink)s b "\ - "\nWHERE "\ - "\na.index_name = b.index_name "\ - "\nAND a.table_owner = b.table_owner "\ - "\nAND a.table_name = b.table_name "\ + params = {"table_name": table_name} + text = ( + "SELECT a.index_name, a.column_name, " + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " + "\nFROM ALL_IND_COLUMNS%(dblink)s a, " + "\nALL_INDEXES%(dblink)s b " + "\nWHERE " + "\na.index_name = b.index_name " + "\nAND a.table_owner = b.table_owner " + "\nAND a.table_name = b.table_name " "\nAND a.table_name = :table_name " + ) if schema is not None: - params['schema'] = schema + params["schema"] = schema text += "AND a.table_owner = :schema " text += "ORDER BY a.index_name, a.column_position" - text = text % {'dblink': dblink} + text = text % {"dblink": dblink} q = sql.text(text) rp = connection.execute(q, **params) indexes = [] last_index_name = None pk_constraint = self.get_pk_constraint( - connection, table_name, schema, resolve_synonyms=resolve_synonyms, - dblink=dblink, info_cache=kw.get('info_cache')) - pkeys = pk_constraint['constrained_columns'] + connection, + table_name, + schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get("info_cache"), + ) + pkeys = pk_constraint["constrained_columns"] uniqueness = dict(NONUNIQUE=False, UNIQUE=True) enabled = dict(DISABLED=False, ENABLED=True) - oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) index = None for rset in rp: if rset.index_name != last_index_name: - index = dict(name=self.normalize_name(rset.index_name), - column_names=[], dialect_options={}) + index = dict( + name=self.normalize_name(rset.index_name), + column_names=[], + dialect_options={}, + ) indexes.append(index) - index['unique'] = uniqueness.get(rset.uniqueness, False) + index["unique"] = uniqueness.get(rset.uniqueness, False) - if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'): - index['dialect_options']['oracle_bitmap'] = True + if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): + index["dialect_options"]["oracle_bitmap"] = True if enabled.get(rset.compression, False): - index['dialect_options']['oracle_compress'] = rset.prefix_length + index["dialect_options"][ + "oracle_compress" + ] = rset.prefix_length # filter out Oracle SYS_NC names. could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): - index['column_names'].append( - self.normalize_name(rset.column_name)) + index["column_names"].append( + self.normalize_name(rset.column_name) + ) last_index_name = rset.index_name def upper_name_set(names): @@ -1539,18 +1656,21 @@ class OracleDialect(default.DefaultDialect): pk_names = upper_name_set(pkeys) if pk_names: + def is_pk_index(index): # don't include the primary key index - return upper_name_set(index['column_names']) == pk_names + return upper_name_set(index["column_names"]) == pk_names + indexes = [idx for idx in indexes if not is_pk_index(idx)] return indexes @reflection.cache - def _get_constraint_data(self, connection, table_name, schema=None, - dblink='', **kw): + def _get_constraint_data( + self, connection, table_name, schema=None, dblink="", **kw + ): - params = {'table_name': table_name} + params = {"table_name": table_name} text = ( "SELECT" @@ -1572,7 +1692,7 @@ class OracleDialect(default.DefaultDialect): ) if schema is not None: - params['owner'] = schema + params["owner"] = schema text += "\nAND ac.owner = :owner" text += ( @@ -1584,35 +1704,49 @@ class OracleDialect(default.DefaultDialect): "\nORDER BY ac.constraint_name, loc.position" ) - text = text % {'dblink': dblink} + text = text % {"dblink": dblink} rp = connection.execute(sql.text(text), **params) constraint_data = rp.fetchall() return constraint_data @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) pkeys = [] constraint_name = None constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) for row in constraint_data: - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - if cons_type == 'P': + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == "P": if constraint_name is None: constraint_name = self.normalize_name(cons_name) pkeys.append(local_column) - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -1626,74 +1760,94 @@ class OracleDialect(default.DefaultDialect): """ requested_schema = schema # to check later on - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [], - 'options': {}, + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, } fkeys = util.defaultdict(fkey_rec) for row in constraint_data: - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) cons_name = self.normalize_name(cons_name) - if cons_type == 'R': + if cons_type == "R": if remote_table is None: # ticket 363 util.warn( - ("Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?") % {'dblink': dblink}) + ( + "Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?" + ) + % {"dblink": dblink} + ) continue rec = fkeys[cons_name] - rec['name'] = cons_name - local_cols, remote_cols = rec[ - 'constrained_columns'], rec['referred_columns'] + rec["name"] = cons_name + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) - if not rec['referred_table']: + if not rec["referred_table"]: if resolve_synonyms: - ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ - self._resolve_synonym( - connection, - desired_owner=self.denormalize_name( - remote_owner), - desired_table=self.denormalize_name( - remote_table) - ) + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table), + ) if ref_synonym: remote_table = self.normalize_name(ref_synonym) remote_owner = self.normalize_name( - ref_remote_owner) + ref_remote_owner + ) - rec['referred_table'] = remote_table + rec["referred_table"] = remote_table - if requested_schema is not None or \ - self.denormalize_name(remote_owner) != schema: - rec['referred_schema'] = remote_owner + if ( + requested_schema is not None + or self.denormalize_name(remote_owner) != schema + ): + rec["referred_schema"] = remote_owner - if row[9] != 'NO ACTION': - rec['options']['ondelete'] = row[9] + if row[9] != "NO ACTION": + rec["options"]["ondelete"] = row[9] local_cols.append(local_column) remote_cols.append(remote_column) @@ -1701,54 +1855,82 @@ class OracleDialect(default.DefaultDialect): return list(fkeys.values()) @reflection.cache - def get_unique_constraints(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) - unique_keys = filter(lambda x: x[1] == 'U', constraint_data) + unique_keys = filter(lambda x: x[1] == "U", constraint_data) uniques_group = groupby(unique_keys, lambda x: x[0]) - index_names = set([ix['name'] for ix in self.get_indexes(connection, table_name, schema=schema)]) + index_names = set( + [ + ix["name"] + for ix in self.get_indexes( + connection, table_name, schema=schema + ) + ] + ) return [ { - 'name': name, - 'column_names': cols, - 'duplicates_index': name if name in index_names else None + "name": name, + "column_names": cols, + "duplicates_index": name if name in index_names else None, } - for name, cols in - [ + for name, cols in [ [ self.normalize_name(i[0]), - [self.normalize_name(x[2]) for x in i[1]] - ] for i in uniques_group + [self.normalize_name(x[2]) for x in i[1]], + ] + for i in uniques_group ] ] @reflection.cache - def get_view_definition(self, connection, view_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - info_cache = kw.get('info_cache') - (view_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, view_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) - - params = {'view_name': view_name} + def get_view_definition( + self, + connection, + view_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + info_cache = kw.get("info_cache") + (view_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + view_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + params = {"view_name": view_name} text = "SELECT text FROM all_views WHERE view_name=:view_name" if schema is not None: text += " AND owner = :schema" - params['schema'] = schema + params["schema"] = schema rp = connection.execute(sql.text(text), **params).scalar() if rp: @@ -1759,34 +1941,41 @@ class OracleDialect(default.DefaultDialect): return None @reflection.cache - def get_check_constraints(self, connection, table_name, schema=None, - include_all=False, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_check_constraints( + self, connection, table_name, schema=None, include_all=False, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) - check_constraints = filter(lambda x: x[1] == 'C', constraint_data) + check_constraints = filter(lambda x: x[1] == "C", constraint_data) return [ - { - 'name': self.normalize_name(cons[0]), - 'sqltext': cons[8], - } - for cons in check_constraints if include_all or - not re.match(r'..+?. IS NOT NULL$', cons[8])] + {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} + for cons in check_constraints + if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8]) + ] class _OuterJoinColumn(sql.ClauseElement): - __visit_name__ = 'outer_join_column' + __visit_name__ = "outer_join_column" def __init__(self, column): self.column = column |