diff options
Diffstat (limited to 'lib/sqlalchemy/databases/oracle.py')
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 255 |
1 files changed, 122 insertions, 133 deletions
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 9d7d6a112..d3aa2e268 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -5,9 +5,9 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, re, warnings +import re, warnings, operator -from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging +from sqlalchemy import util, sql, schema, ansisql, exceptions, logging from sqlalchemy.engine import default, base import sqlalchemy.types as sqltypes @@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT): def convert_result_value(self, value, dialect): if value is None: return None - else: + elif hasattr(value, 'read'): + # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str return super(OracleText, self).convert_result_value(value.read(), dialect) + else: + return super(OracleText, self).convert_result_value(value, dialect) class OracleRaw(sqltypes.Binary): @@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext): super(OracleExecutionContext, self).pre_exec() if self.dialect.auto_setinputsizes: self.set_input_sizes() + if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list): + for key in self.compiled_parameters: + (bindparam, name, value) = self.compiled_parameters.get_parameter(key) + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[name] = self.out_parameters[name] def get_result_proxy(self): + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None: + for k in self.out_parameters: + type = self.compiled_parameters.get_type(k) + self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + else: + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() + if self.cursor.description is not None: - if self.dialect.auto_convert_lobs and self.typemap is None: - typemap = {} - binary = False - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - binary = True - typemap[column[0].lower()] = OracleBinary() - self.typemap = typemap - if binary: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: return base.BufferedColumnResultProxy(self) - else: - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - return base.BufferedColumnResultProxy(self) return base.ResultProxy(self) @@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect): self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes self.auto_convert_lobs = auto_convert_lobs + if self.dbapi is not None: self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] else: self.ORACLE_BINARY_TYPES = [] + def dbapi_type_map(self): + if self.dbapi is None or not self.auto_convert_lobs: + return {} + else: + return { + self.dbapi.NUMBER: OracleInteger(), + self.dbapi.CLOB: OracleText(), + self.dbapi.BLOB: OracleBinary(), + self.dbapi.STRING: OracleString(), + self.dbapi.TIMESTAMP: OracleTimestamp(), + self.dbapi.BINARY: OracleRaw(), + datetime.datetime: OracleDate() + } + def dbapi(cls): import cx_Oracle return cx_Oracle @@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect): return 30 def oid_column_name(self, column): - if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select): + if not isinstance(column.table, (sql.TableClause, sql.Select)): return None else: return "rowid" @@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect): return name, owner, dblink raise - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer if not preparer.should_quote(table): name = table.name.upper() @@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect): #print "ROW:" , row (colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + # if name comes back as all upper, assume its case folded + if (colname.upper() == colname): + colname = colname.lower() + + if include_columns and colname not in include_columns: + continue + # INTEGER if the scale is 0 and precision is null # NUMBER if the scale and precision are both null # NUMBER(9,2) if the precision is 9 and the scale is 2 @@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect): try: coltype = ischema_names[coltype] except KeyError: - raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname)) + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname))) + coltype = sqltypes.NULLTYPE colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - # if name comes back as all upper, assume its case folded - if (colname.upper() == colname): - colname = colname.lower() - table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) if not len(table.columns): @@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect): OracleDialect.logger = logging.class_logger(OracleDialect) +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + def __init__(self, column): + self.column = column + class OracleCompiler(ansisql.ANSICompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + } + ) + def __init__(self, *args, **kwargs): super(OracleCompiler, self).__init__(*args, **kwargs) - # we have to modify SELECT objects a little bit, so store state here - self._select_state = {} + self.__wheres = {} def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. @@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler): def apply_function_parens(self, func): return len(func.clauses) > 0 - def visit_join(self, join): + def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join) - - self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) - where = self.wheres.get(join.left, None) + return ansisql.ANSICompiler.visit_join(self, join, **kwargs) + + (where, parentjoin) = self.__wheres.get(join, (None, None)) + + class VisitOn(sql.ClauseVisitor): + def visit_binary(s, binary): + if binary.operator == operator.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + if where is not None: - self.wheres[join] = sql.and_(where, join.onclause) + self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin) else: - self.wheres[join] = join.onclause -# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) - self.strings[join] = self.froms[join] - - if join.isouter: - # if outer join, push on the right side table as the current "outertable" - self._outertable = join.right - - # now re-visit the onclause, which will be used as a where clause - # (the first visit occured via the Join object itself right before it called visit_join()) - self.traverse(join.onclause) - - self._outertable = None - - self.wheres[join].accept_visitor(self) + self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join) - def visit_insert_sequence(self, column, sequence, parameters): - """This is the `sequence` equivalent to ``ANSICompiler``'s - `visit_insert_column_default` which ensures that the column is - present in the generated column list. - """ - - parameters.setdefault(column.key, None) + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def get_whereclause(self, f): + if f in self.__wheres: + return self.__wheres[f][0] + else: + return None + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def uses_sequences_for_inserts(self): + return True - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name - self.strings[alias] = self.get_str(alias.original) - - def visit_column(self, column): - ansisql.ANSICompiler.visit_column(self, column) - if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: - self.strings[column] = self.strings[column] + "(+)" + + if asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name + else: + return self.process(alias.original, **kwargs) def visit_insert(self, insert): """``INSERT`` s are required to have the primary keys be explicitly present. @@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler): def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass - if getattr(select, '_oracle_visit', False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_compound_select(self, select) - return - - if select.limit is not None or select.offset is not None: - select._oracle_visit = True - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] - if not orderby: - orderby = select.oid_column - self.traverse(orderby) - orderby = self.strings[orderby] - class SelectVisitor(sql.NoColumnVisitor): - def visit_select(self, select): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - SelectVisitor().traverse(select) - limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) - else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] - else: - ansisql.ANSICompiler.visit_compound_select(self, select) - - def visit_select(self, select): + def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. """ - # TODO: put a real copy-container on Select and copy, or somehow make this - # not modify the Select statement - if self._select_state.get((select, 'visit'), False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_select(self, select) - return - - if select.limit is not None or select.offset is not None: - self._select_state[(select, 'visit')] = True + if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] + orderby = self.process(select._order_by_clause) if not orderby: orderby = select.oid_column self.traverse(orderby) - orderby = self.strings[orderby] - if not hasattr(select, '_oracle_visit'): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - select._oracle_visit = True + orderby = self.process(orderby) + + oldselect = select + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) + select._oracle_visit = True + limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) + if select._offset is not None: + limitselect.append_whereclause("ora_rn>%d" % select._offset) + if select._limit is not None: + limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] + limitselect.append_whereclause("ora_rn<=%d" % select._limit) + return self.process(limitselect) else: - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" @@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler): else: return super(OracleCompiler, self).for_update_clause(select) - def visit_binary(self, binary): - if binary.operator == '%': - self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right))) - else: - return ansisql.ANSICompiler.visit_binary(self, binary) - class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not self.dialect.has_sequence(self.connection, sequence.name): + if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class OracleSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if self.dialect.has_sequence(self.connection, sequence.name): + if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection) - return self.connection.execute_compiled(c).scalar() + c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) + return self.connection.execute(c).scalar() def visit_sequence(self, seq): - return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar() + return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar() dialect = OracleDialect |