diff options
-rw-r--r-- | CHANGES | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/processors.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/test/util.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/types.py | 12 | ||||
-rw-r--r-- | test/dialect/test_postgresql.py | 14 | ||||
-rw-r--r-- | test/sql/test_types.py | 28 |
6 files changed, 80 insertions, 22 deletions
@@ -171,6 +171,11 @@ CHANGES not new). An error is now raised if a Column() has no type and no foreign keys. [ticket:1705] + - the "scale" argument of the Numeric() type is honored when + coercing a returned floating point value into a string + on its way to Decimal - this allows accuracy to function + on SQLite, MySQL. [ticket:1717] + - engines - Added an optional C extension to speed up the sql layer by reimplementing RowProxy and the most common result processors. diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index 4cf6831bd..04fa5054a 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -38,9 +38,10 @@ try: return UnicodeResultProcessor(encoding, errors).process else: return UnicodeResultProcessor(encoding).process - - def to_decimal_processor_factory(target_class): - return DecimalResultProcessor(target_class).process + + # TODO: add scale argument + #def to_decimal_processor_factory(target_class): + # return DecimalResultProcessor(target_class).process except ImportError: def to_unicode_processor_factory(encoding, errors=None): @@ -57,13 +58,14 @@ except ImportError: return decoder(value, errors)[0] return process - def to_decimal_processor_factory(target_class): - def process(value): - if value is None: - return None - else: - return target_class(str(value)) - return process + # TODO: add scale argument + #def to_decimal_processor_factory(target_class): + # def process(value): + # if value is None: + # return None + # else: + # return target_class(str(value)) + # return process def to_float(value): if value is None: @@ -92,3 +94,13 @@ except ImportError: str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) + +def to_decimal_processor_factory(target_class, scale=10): + fstring = "%%.%df" % scale + + def process(value): + if value is None: + return None + else: + return target_class(fstring % value) + return process diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py index 5be00f906..8a3a0e745 100644 --- a/lib/sqlalchemy/test/util.py +++ b/lib/sqlalchemy/test/util.py @@ -39,4 +39,15 @@ def picklers(): for pickle in picklers: for protocol in -1, 0, 1, 2: yield pickle.loads, lambda d:pickle.dumps(d, protocol) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + import decimal + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \ + pow(10, prec)
\ No newline at end of file diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index d5f1d9f14..d7b8f9289 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -838,7 +838,7 @@ class Numeric(_DateAffinity, TypeEngine): # try: # from fastdec import mpd as Decimal # except ImportError: - return processors.to_decimal_processor_factory(_python_Decimal) + return processors.to_decimal_processor_factory(_python_Decimal, self.scale) else: return None @@ -877,6 +877,16 @@ class Float(Numeric): def adapt(self, impltype): return impltype(precision=self.precision, asdecimal=self.asdecimal) + def result_processor(self, dialect, coltype): + if self.asdecimal: + #XXX: use decimal from http://www.bytereef.org/libmpdec.html +# try: +# from fastdec import mpd as Decimal +# except ImportError: + return processors.to_decimal_processor_factory(_python_Decimal) + else: + return None + class DateTime(_DateAffinity, TypeEngine): """A type for ``datetime.datetime()`` objects. diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 1a21ec11f..b002e7f19 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -9,6 +9,7 @@ from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.engine.strategies import MockEngineStrategy from sqlalchemy.test import * +from sqlalchemy.test.util import round_decimal from sqlalchemy.sql import table, column from sqlalchemy.test.testing import eq_ from test.engine._base import TablesTest @@ -203,15 +204,6 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults): {'data':9}, ) - def _round(self, x): - if isinstance(x, float): - return round(x, 9) - elif isinstance(x, decimal.Decimal): - # really ? - # (can also use shift() here but that is 2.6 only) - x = (x * decimal.Decimal("1000000000")).to_integral() / pow(10, 9) - return x - @testing.resolve_artifact_names def test_float_coercion(self): for type_, result in [ @@ -226,14 +218,14 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults): ]) ).scalar() - eq_(self._round(ret), result) + eq_(round_decimal(ret, 9), result) ret = testing.db.execute( select([ cast(func.stddev_pop(data_table.c.data), type_) ]) ).scalar() - eq_(self._round(ret), result) + eq_(round_decimal(ret, 9), result) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 4b2370afc..53f4d8d91 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1024,6 +1024,34 @@ class NumericTest(TestBase, AssertsExecutionResults): (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")), ]) + @testing.fails_if(_missing_decimal) + def test_precision_decimal(self): + from decimal import Decimal + from sqlalchemy.test.util import round_decimal + + t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12))) + t.create(testing.db) + try: + numbers = set( + [ + decimal.Decimal("54.234246451650"), + decimal.Decimal("876734.594069654000"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ]) + + testing.db.execute(t.insert(), [{'x':x} for x in numbers]) + + ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()]) + + numbers = set(round_decimal(n, 11) for n in numbers) + ret = set(round_decimal(n, 11) for n in ret) + + eq_(numbers, ret) + finally: + t.drop(testing.db) + + def test_decimal_fallback(self): from decimal import Decimal |