summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/oracle.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/databases/oracle.py')
-rw-r--r--lib/sqlalchemy/databases/oracle.py255
1 files changed, 122 insertions, 133 deletions
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 9d7d6a112..d3aa2e268 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -5,9 +5,9 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, re, warnings
+import re, warnings, operator
-from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
from sqlalchemy.engine import default, base
import sqlalchemy.types as sqltypes
@@ -88,8 +88,11 @@ class OracleText(sqltypes.TEXT):
def convert_result_value(self, value, dialect):
if value is None:
return None
- else:
+ elif hasattr(value, 'read'):
+ # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str
return super(OracleText, self).convert_result_value(value.read(), dialect)
+ else:
+ return super(OracleText, self).convert_result_value(value, dialect)
class OracleRaw(sqltypes.Binary):
@@ -178,25 +181,31 @@ class OracleExecutionContext(default.DefaultExecutionContext):
super(OracleExecutionContext, self).pre_exec()
if self.dialect.auto_setinputsizes:
self.set_input_sizes()
+ if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list):
+ for key in self.compiled_parameters:
+ (bindparam, name, value) = self.compiled_parameters.get_parameter(key)
+ if bindparam.isoutparam:
+ dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+ if not hasattr(self, 'out_parameters'):
+ self.out_parameters = {}
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[name] = self.out_parameters[name]
def get_result_proxy(self):
+ if hasattr(self, 'out_parameters'):
+ if self.compiled_parameters is not None:
+ for k in self.out_parameters:
+ type = self.compiled_parameters.get_type(k)
+ self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+ else:
+ for k in self.out_parameters:
+ self.out_parameters[k] = self.out_parameters[k].getvalue()
+
if self.cursor.description is not None:
- if self.dialect.auto_convert_lobs and self.typemap is None:
- typemap = {}
- binary = False
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- binary = True
- typemap[column[0].lower()] = OracleBinary()
- self.typemap = typemap
- if binary:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
return base.BufferedColumnResultProxy(self)
- else:
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- return base.BufferedColumnResultProxy(self)
return base.ResultProxy(self)
@@ -208,11 +217,26 @@ class OracleDialect(ansisql.ANSIDialect):
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
+
if self.dbapi is not None:
self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
else:
self.ORACLE_BINARY_TYPES = []
+ def dbapi_type_map(self):
+ if self.dbapi is None or not self.auto_convert_lobs:
+ return {}
+ else:
+ return {
+ self.dbapi.NUMBER: OracleInteger(),
+ self.dbapi.CLOB: OracleText(),
+ self.dbapi.BLOB: OracleBinary(),
+ self.dbapi.STRING: OracleString(),
+ self.dbapi.TIMESTAMP: OracleTimestamp(),
+ self.dbapi.BINARY: OracleRaw(),
+ datetime.datetime: OracleDate()
+ }
+
def dbapi(cls):
import cx_Oracle
return cx_Oracle
@@ -251,7 +275,7 @@ class OracleDialect(ansisql.ANSIDialect):
return 30
def oid_column_name(self, column):
- if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select):
+ if not isinstance(column.table, (sql.TableClause, sql.Select)):
return None
else:
return "rowid"
@@ -341,7 +365,7 @@ class OracleDialect(ansisql.ANSIDialect):
return name, owner, dblink
raise
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
if not preparer.should_quote(table):
name = table.name.upper()
@@ -363,6 +387,13 @@ class OracleDialect(ansisql.ANSIDialect):
#print "ROW:" , row
(colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+ # if name comes back as all upper, assume its case folded
+ if (colname.upper() == colname):
+ colname = colname.lower()
+
+ if include_columns and colname not in include_columns:
+ continue
+
# INTEGER if the scale is 0 and precision is null
# NUMBER if the scale and precision are both null
# NUMBER(9,2) if the precision is 9 and the scale is 2
@@ -382,16 +413,13 @@ class OracleDialect(ansisql.ANSIDialect):
try:
coltype = ischema_names[coltype]
except KeyError:
- raise exceptions.AssertionError("Can't get coltype for type '%s' on colname '%s'" % (coltype, colname))
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname)))
+ coltype = sqltypes.NULLTYPE
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
- # if name comes back as all upper, assume its case folded
- if (colname.upper() == colname):
- colname = colname.lower()
-
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
if not len(table.columns):
@@ -458,16 +486,27 @@ class OracleDialect(ansisql.ANSIDialect):
OracleDialect.logger = logging.class_logger(OracleDialect)
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = 'outer_join_column'
+ def __init__(self, column):
+ self.column = column
+
class OracleCompiler(ansisql.ANSICompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ }
+ )
+
def __init__(self, *args, **kwargs):
super(OracleCompiler, self).__init__(*args, **kwargs)
- # we have to modify SELECT objects a little bit, so store state here
- self._select_state = {}
+ self.__wheres = {}
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
@@ -480,49 +519,46 @@ class OracleCompiler(ansisql.ANSICompiler):
def apply_function_parens(self, func):
return len(func.clauses) > 0
- def visit_join(self, join):
+ def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join)
-
- self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
- where = self.wheres.get(join.left, None)
+ return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+
+ (where, parentjoin) = self.__wheres.get(join, (None, None))
+
+ class VisitOn(sql.ClauseVisitor):
+ def visit_binary(s, binary):
+ if binary.operator == operator.eq:
+ if binary.left.table is join.right:
+ binary.left = _OuterJoinColumn(binary.left)
+ elif binary.right.table is join.right:
+ binary.right = _OuterJoinColumn(binary.right)
+
if where is not None:
- self.wheres[join] = sql.and_(where, join.onclause)
+ self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
else:
- self.wheres[join] = join.onclause
-# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
- self.strings[join] = self.froms[join]
-
- if join.isouter:
- # if outer join, push on the right side table as the current "outertable"
- self._outertable = join.right
-
- # now re-visit the onclause, which will be used as a where clause
- # (the first visit occured via the Join object itself right before it called visit_join())
- self.traverse(join.onclause)
-
- self._outertable = None
-
- self.wheres[join].accept_visitor(self)
+ self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
- def visit_insert_sequence(self, column, sequence, parameters):
- """This is the `sequence` equivalent to ``ANSICompiler``'s
- `visit_insert_column_default` which ensures that the column is
- present in the generated column list.
- """
-
- parameters.setdefault(column.key, None)
+ return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+ def get_whereclause(self, f):
+ if f in self.__wheres:
+ return self.__wheres[f][0]
+ else:
+ return None
+
+ def visit_outer_join_column(self, vc):
+ return self.process(vc.column) + "(+)"
+
+ def uses_sequences_for_inserts(self):
+ return True
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
- self.strings[alias] = self.get_str(alias.original)
-
- def visit_column(self, column):
- ansisql.ANSICompiler.visit_column(self, column)
- if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
- self.strings[column] = self.strings[column] + "(+)"
+
+ if asfrom:
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name
+ else:
+ return self.process(alias.original, **kwargs)
def visit_insert(self, insert):
"""``INSERT`` s are required to have the primary keys be explicitly present.
@@ -539,76 +575,35 @@ class OracleCompiler(ansisql.ANSICompiler):
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+ pass
- if getattr(select, '_oracle_visit', False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_compound_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
- if not orderby:
- orderby = select.oid_column
- self.traverse(orderby)
- orderby = self.strings[orderby]
- class SelectVisitor(sql.NoColumnVisitor):
- def visit_select(self, select):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- SelectVisitor().traverse(select)
- limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
- else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
- else:
- ansisql.ANSICompiler.visit_compound_select(self, select)
-
- def visit_select(self, select):
+ def visit_select(self, select, **kwargs):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
- # TODO: put a real copy-container on Select and copy, or somehow make this
- # not modify the Select statement
- if self._select_state.get((select, 'visit'), False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- self._select_state[(select, 'visit')] = True
+ if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
+ orderby = self.process(select._order_by_clause)
if not orderby:
orderby = select.oid_column
self.traverse(orderby)
- orderby = self.strings[orderby]
- if not hasattr(select, '_oracle_visit'):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- select._oracle_visit = True
+ orderby = self.process(orderby)
+
+ oldselect = select
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+ select._oracle_visit = True
+
limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
+ if select._offset is not None:
+ limitselect.append_whereclause("ora_rn>%d" % select._offset)
+ if select._limit is not None:
+ limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
+ limitselect.append_whereclause("ora_rn<=%d" % select._limit)
+ return self.process(limitselect)
else:
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
@@ -619,12 +614,6 @@ class OracleCompiler(ansisql.ANSICompiler):
else:
return super(OracleCompiler, self).for_update_clause(select)
- def visit_binary(self, binary):
- if binary.operator == '%':
- self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right)))
- else:
- return ansisql.ANSICompiler.visit_binary(self, binary)
-
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
@@ -639,22 +628,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
def visit_sequence(self, sequence):
- if not self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
- if self.dialect.has_sequence(self.connection, sequence.name):
+ if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
- c = sql.select([default.arg], from_obj=["DUAL"]).compile(engine=self.connection)
- return self.connection.execute_compiled(c).scalar()
+ c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
+ return self.connection.execute(c).scalar()
def visit_sequence(self, seq):
- return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
+ return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
dialect = OracleDialect