summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES5
-rw-r--r--lib/sqlalchemy/processors.py32
-rw-r--r--lib/sqlalchemy/test/util.py11
-rw-r--r--lib/sqlalchemy/types.py12
-rw-r--r--test/dialect/test_postgresql.py14
-rw-r--r--test/sql/test_types.py28
6 files changed, 80 insertions, 22 deletions
diff --git a/CHANGES b/CHANGES
index 5e0d3662d..05d53ffa5 100644
--- a/CHANGES
+++ b/CHANGES
@@ -171,6 +171,11 @@ CHANGES
not new). An error is now raised if a Column() has no type
and no foreign keys. [ticket:1705]
+ - the "scale" argument of the Numeric() type is honored when
+ coercing a returned floating point value into a string
+ on its way to Decimal - this allows accuracy to function
+ on SQLite, MySQL. [ticket:1717]
+
- engines
- Added an optional C extension to speed up the sql layer by
reimplementing RowProxy and the most common result processors.
diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py
index 4cf6831bd..04fa5054a 100644
--- a/lib/sqlalchemy/processors.py
+++ b/lib/sqlalchemy/processors.py
@@ -38,9 +38,10 @@ try:
return UnicodeResultProcessor(encoding, errors).process
else:
return UnicodeResultProcessor(encoding).process
-
- def to_decimal_processor_factory(target_class):
- return DecimalResultProcessor(target_class).process
+
+ # TODO: add scale argument
+ #def to_decimal_processor_factory(target_class):
+ # return DecimalResultProcessor(target_class).process
except ImportError:
def to_unicode_processor_factory(encoding, errors=None):
@@ -57,13 +58,14 @@ except ImportError:
return decoder(value, errors)[0]
return process
- def to_decimal_processor_factory(target_class):
- def process(value):
- if value is None:
- return None
- else:
- return target_class(str(value))
- return process
+ # TODO: add scale argument
+ #def to_decimal_processor_factory(target_class):
+ # def process(value):
+ # if value is None:
+ # return None
+ # else:
+ # return target_class(str(value))
+ # return process
def to_float(value):
if value is None:
@@ -92,3 +94,13 @@ except ImportError:
str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)
+
+def to_decimal_processor_factory(target_class, scale=10):
+ fstring = "%%.%df" % scale
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ return target_class(fstring % value)
+ return process
diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py
index 5be00f906..8a3a0e745 100644
--- a/lib/sqlalchemy/test/util.py
+++ b/lib/sqlalchemy/test/util.py
@@ -39,4 +39,15 @@ def picklers():
for pickle in picklers:
for protocol in -1, 0, 1, 2:
yield pickle.loads, lambda d:pickle.dumps(d, protocol)
+
+
+def round_decimal(value, prec):
+ if isinstance(value, float):
+ return round(value, prec)
+
+ import decimal
+
+ # can also use shift() here but that is 2.6 only
+ return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \
+ pow(10, prec)
\ No newline at end of file
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index d5f1d9f14..d7b8f9289 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -838,7 +838,7 @@ class Numeric(_DateAffinity, TypeEngine):
# try:
# from fastdec import mpd as Decimal
# except ImportError:
- return processors.to_decimal_processor_factory(_python_Decimal)
+ return processors.to_decimal_processor_factory(_python_Decimal, self.scale)
else:
return None
@@ -877,6 +877,16 @@ class Float(Numeric):
def adapt(self, impltype):
return impltype(precision=self.precision, asdecimal=self.asdecimal)
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ #XXX: use decimal from http://www.bytereef.org/libmpdec.html
+# try:
+# from fastdec import mpd as Decimal
+# except ImportError:
+ return processors.to_decimal_processor_factory(_python_Decimal)
+ else:
+ return None
+
class DateTime(_DateAffinity, TypeEngine):
"""A type for ``datetime.datetime()`` objects.
diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py
index 1a21ec11f..b002e7f19 100644
--- a/test/dialect/test_postgresql.py
+++ b/test/dialect/test_postgresql.py
@@ -9,6 +9,7 @@ from sqlalchemy import exc, schema, types
from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.engine.strategies import MockEngineStrategy
from sqlalchemy.test import *
+from sqlalchemy.test.util import round_decimal
from sqlalchemy.sql import table, column
from sqlalchemy.test.testing import eq_
from test.engine._base import TablesTest
@@ -203,15 +204,6 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults):
{'data':9},
)
- def _round(self, x):
- if isinstance(x, float):
- return round(x, 9)
- elif isinstance(x, decimal.Decimal):
- # really ?
- # (can also use shift() here but that is 2.6 only)
- x = (x * decimal.Decimal("1000000000")).to_integral() / pow(10, 9)
- return x
-
@testing.resolve_artifact_names
def test_float_coercion(self):
for type_, result in [
@@ -226,14 +218,14 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults):
])
).scalar()
- eq_(self._round(ret), result)
+ eq_(round_decimal(ret, 9), result)
ret = testing.db.execute(
select([
cast(func.stddev_pop(data_table.c.data), type_)
])
).scalar()
- eq_(self._round(ret), result)
+ eq_(round_decimal(ret, 9), result)
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 4b2370afc..53f4d8d91 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1024,6 +1024,34 @@ class NumericTest(TestBase, AssertsExecutionResults):
(2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")),
])
+ @testing.fails_if(_missing_decimal)
+ def test_precision_decimal(self):
+ from decimal import Decimal
+ from sqlalchemy.test.util import round_decimal
+
+ t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12)))
+ t.create(testing.db)
+ try:
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("876734.594069654000"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ])
+
+ testing.db.execute(t.insert(), [{'x':x} for x in numbers])
+
+ ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()])
+
+ numbers = set(round_decimal(n, 11) for n in numbers)
+ ret = set(round_decimal(n, 11) for n in ret)
+
+ eq_(numbers, ret)
+ finally:
+ t.drop(testing.db)
+
+
def test_decimal_fallback(self):
from decimal import Decimal