summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES6
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py51
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--test/dialect/test_oracle.py53
-rw-r--r--test/sql/test_types.py2
6 files changed, 94 insertions, 21 deletions
diff --git a/CHANGES b/CHANGES
index cc005a271..27a38145e 100644
--- a/CHANGES
+++ b/CHANGES
@@ -253,6 +253,12 @@ CHANGES
- an NCLOB type is added to the base types.
+ - the Oracle dialect now features NUMBER which intends
+ to act justlike Oracle's NUMBER type. It is the primary
+ numeric type returned by table reflection and attempts
+ to return Decimal()/float/int based on the precision/scale
+ parameters. [ticket:885]
+
- func.char_length is a generic function for LENGTH
- ForeignKey() which includes onupdate=<value> will emit a
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 17b09e79c..a5ced0738 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -132,14 +132,26 @@ class NCLOB(sqltypes.Text):
VARCHAR2 = VARCHAR
NVARCHAR2 = NVARCHAR
-class NUMBER(sqltypes.Numeric):
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
__visit_name__ = 'NUMBER'
-class BFILE(sqltypes.Binary):
- __visit_name__ = 'BFILE'
-
+ def __init__(self, precision=None, scale=None, asdecimal=None):
+ if asdecimal is None:
+ asdecimal = bool(scale and scale > 0)
+
+ super(NUMBER, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal)
+
class DOUBLE_PRECISION(sqltypes.Numeric):
__visit_name__ = 'DOUBLE_PRECISION'
+ def __init__(self, precision=None, scale=None, asdecimal=None):
+ if asdecimal is None:
+ asdecimal = False
+
+ super(DOUBLE_PRECISION, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal)
+
+class BFILE(sqltypes.Binary):
+ __visit_name__ = 'BFILE'
class LONG(sqltypes.Text):
__visit_name__ = 'LONG'
@@ -200,13 +212,24 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
return self.visit_DATE(type_)
def visit_float(self, type_):
- if type_.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : 2}
+ return self.visit_FLOAT(type_)
def visit_unicode(self, type_):
return self.visit_NVARCHAR(type_)
+
+ def visit_DOUBLE_PRECISION(self, type_):
+ return self._generate_numeric(type_, "DOUBLE PRECISION")
+
+ def visit_NUMBER(self, type_):
+ return self._generate_numeric(type_, "NUMBER")
+
+ def _generate_numeric(self, type_, name):
+ if type_.precision is None:
+ return name
+ elif type_.scale is None:
+ return "%(name)s(%(precision)s)" % {'name':name,'precision': type_.precision}
+ else:
+ return "%(name)s(%(precision)s, %(scale)s)" % {'name':name,'precision': type_.precision, 'scale' : type_.scale}
def visit_VARCHAR(self, type_):
return "VARCHAR(%(length)s)" % {'length' : type_.length}
@@ -658,18 +681,8 @@ class OracleDialect(default.DefaultDialect):
(colname, coltype, length, precision, scale, nullable, default) = \
(self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
- # INTEGER if the scale is 0 and precision is null
- # NUMBER if the scale and precision are both null
- # NUMBER(9,2) if the precision is 9 and the scale is 2
- # NUMBER(3) if the precision is 3 and scale is 0
- #length is ignored except for CHAR and VARCHAR2
if coltype == 'NUMBER' :
- if precision is None and scale is None:
- coltype = sqltypes.NUMERIC
- elif precision is None and scale == 0:
- coltype = sqltypes.INTEGER
- else :
- coltype = sqltypes.NUMERIC(precision, scale)
+ coltype = NUMBER(precision, scale)
elif coltype=='CHAR' or coltype=='VARCHAR2':
coltype = self.ischema_names.get(coltype)(length)
else:
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 475d6559a..f40923591 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -178,6 +178,7 @@ colspecs = {
sqltypes.TIMESTAMP : _OracleTimestamp,
sqltypes.Integer : _OracleInteger, # this is only needed for OUT parameters.
# it would be nice if we could not use it otherwise.
+ oracle.NUMBER : oracle.NUMBER, # don't let this get converted
oracle.RAW: _OracleRaw,
}
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d6187bcde..403ec968b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1149,6 +1149,8 @@ class GenericTypeCompiler(engine.TypeCompiler):
def visit_NUMERIC(self, type_):
if type_.precision is None:
return "NUMERIC"
+ elif type_.scale is None:
+ return "NUMERIC(%(precision)s)" % {'precision': type_.precision}
else:
return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py
index 85c3097be..f8cfdf1fc 100644
--- a/test/dialect/test_oracle.py
+++ b/test/dialect/test_oracle.py
@@ -10,6 +10,7 @@ from sqlalchemy.test.engines import testing_engine
from sqlalchemy.dialects.oracle import cx_oracle, base as oracle
from sqlalchemy.engine import default
from sqlalchemy.util import jython
+from decimal import Decimal
import os
@@ -380,7 +381,57 @@ class TypesTest(TestBase, AssertsCompiledSQL):
assert isinstance(x, int)
finally:
t1.drop()
-
+
+ def test_numerics(self):
+ m = MetaData(testing.db)
+ t1 = Table('t1', m,
+ Column('intcol', Integer),
+ Column('numericcol', Numeric(precision=9, scale=2)),
+ Column('floatcol1', Float()),
+ Column('floatcol2', FLOAT()),
+ Column('doubleprec', oracle.DOUBLE_PRECISION),
+ Column('numbercol1', oracle.NUMBER(9)),
+ Column('numbercol2', oracle.NUMBER(9, 3)),
+ Column('numbercol3', oracle.NUMBER),
+
+ )
+ t1.create()
+ try:
+ t1.insert().execute(
+ intcol=1,
+ numericcol=5.2,
+ floatcol1=6.5,
+ floatcol2 = 8.5,
+ doubleprec = 9.5,
+ numbercol1=12,
+ numbercol2=14.85,
+ numbercol3=15.76
+ )
+
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True)
+
+ for row in (
+ t1.select().execute().first(),
+ t2.select().execute().first()
+ ):
+ for i, (val, type_) in enumerate((
+ (1, int),
+ (Decimal("5.2"), Decimal),
+ (6.5, float),
+ (8.5, float),
+ (9.5, float),
+ (12, int),
+ (Decimal("14.85"), Decimal),
+ (15.76, float),
+ )):
+ eq_(row[i], val)
+ assert isinstance(row[i], type_)
+
+ finally:
+ t1.drop()
+
+
def test_reflect_raw(self):
types_table = Table(
'all_types', MetaData(testing.db),
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index ccd6e5038..cede11cc5 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -205,7 +205,7 @@ class ColumnsTest(TestBase, AssertsExecutionResults):
db = testing.db
if testing.against('oracle'):
- expectedResults['float_column'] = 'float_column NUMERIC(25, 2)'
+ expectedResults['float_column'] = 'float_column FLOAT'
if testing.against('sqlite'):
expectedResults['float_column'] = 'float_column FLOAT'