diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-14 21:53:32 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-14 21:53:32 +0000 |
commit | 087f235c33c1be4e0778231e8344a50dc4005c59 (patch) | |
tree | d47c35d1e520e43c05ec869304870c0b6c87f736 /lib/sqlalchemy/databases | |
parent | e58063aa91d893d76e9f34fbc3ea21818185844d (diff) | |
download | sqlalchemy-087f235c33c1be4e0778231e8344a50dc4005c59.tar.gz |
- merged "fasttypes" branch. this branch changes the signature
of convert_bind_param() and convert_result_value() to callable-returning
bind_processor() and result_processor() methods. if no callable is
returned, no pre/post processing function is called.
- hooks added throughout base/sql/defaults to optimize the calling
of bind param/result processors so that method call overhead is minimized.
special cases added for executemany() scenarios such that unneeded "last row id"
logic doesn't kick in, parameters aren't excessively traversed.
- new performance tests show a combined mass-insert/mass-select test as having 68%
fewer function calls than the same test run against 0.3.
- general performance improvement of result set iteration is around 10-20%.
Diffstat (limited to 'lib/sqlalchemy/databases')
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 72 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 85 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 180 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 217 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 130 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 68 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 74 |
7 files changed, 475 insertions, 351 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 3c06822ae..6bf8b96e9 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -11,16 +11,18 @@ import sqlalchemy.engine.default as default class AcNumeric(types.Numeric): - def convert_result_value(self, value, dialect): - return value - - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) + def result_processor(self, dialect): + return None + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + def get_col_spec(self): return "NUMERIC" @@ -28,12 +30,14 @@ class AcFloat(types.Float): def get_col_spec(self): return "FLOAT" - def convert_bind_param(self, value, dialect): + def bind_processor(self, dialect): """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None - + def process(value): + if not value is None: + return str(value) + return None + return process + class AcInteger(types.Integer): def get_col_spec(self): return "INTEGER" @@ -72,11 +76,11 @@ class AcUnicode(types.Unicode): def get_col_spec(self): return "TEXT" + (self.length and ("(%d)" % self.length) or "") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class AcChar(types.CHAR): def get_col_spec(self): @@ -90,21 +94,25 @@ class AcBoolean(types.Boolean): def get_col_spec(self): return "YESNO" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + class AcTimeStamp(types.TIMESTAMP): def get_col_spec(self): return "TIMESTAMP" diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index a3ef99916..21ecf1538 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -61,27 +61,33 @@ class InfoDateTime(sqltypes.DateTime ): def get_col_spec(self): return "DATETIME YEAR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value - + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process + class InfoTime(sqltypes.Time ): def get_col_spec(self): return "DATETIME HOUR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process - def convert_result_value(self, value, dialect): - if isinstance( value , datetime.datetime ): - return value.time() - else: - return value + def result_processor(self, dialect): + def process(value): + if isinstance( value , datetime.datetime ): + return value.time() + else: + return value + return process class InfoText(sqltypes.String): def get_col_spec(self): @@ -91,36 +97,45 @@ class InfoString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} - def convert_bind_param( self , value , dialect ): - if value == '': - return None - else: - return value - + def bind_processor(self, dialect): + def process(value): + if value == '': + return None + else: + return value + return process + class InfoChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class InfoBinary(sqltypes.Binary): def get_col_spec(self): return "BYTE" + class InfoBoolean(sqltypes.Boolean): default_type = 'NUM' def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer : InfoInteger, diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 9ec0fbbc3..308a38a76 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -47,16 +47,18 @@ from sqlalchemy.engine import default import operator class MSNumeric(sqltypes.Numeric): - def convert_result_value(self, value, dialect): - return value - - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) + def result_processor(self, dialect): + return None + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + def get_col_spec(self): if self.precision is None: return "NUMERIC" @@ -67,12 +69,14 @@ class MSFloat(sqltypes.Float): def get_col_spec(self): return "FLOAT(%(precision)s)" % {'precision': self.precision} - def convert_bind_param(self, value, dialect): - """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None - + def bind_processor(self, dialect): + def process(value): + """By converting to string, we can use Decimal types round-trip.""" + if not value is None: + return str(value) + return None + return process + class MSInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -108,57 +112,71 @@ class MSTime(sqltypes.Time): def get_col_spec(self): return "DATETIME" - def convert_bind_param(self, value, dialect): - if isinstance(value, datetime.datetime): - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif isinstance(value, datetime.time): - value = datetime.datetime.combine(self.__zero_date, value) - return value - - def convert_result_value(self, value, dialect): - if isinstance(value, datetime.datetime): - return value.time() - elif isinstance(value, datetime.date): - return datetime.time(0, 0, 0) - return value - -class MSDateTime_adodbapi(MSDateTime): - def convert_result_value(self, value, dialect): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - return value - -class MSDateTime_pyodbc(MSDateTime): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) return value - -class MSDate_pyodbc(MSDate): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + return process + + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, datetime.date): + return datetime.time(0, 0, 0) return value - - def convert_result_value(self, value, dialect): - # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: + return process + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) return value - + return process + +class MSDateTime_pyodbc(MSDateTime): + def bind_processor(self, dialect): + def process(value): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class MSDate_pyodbc(MSDate): + def bind_processor(self, dialect): + def process(value): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + return process + class MSDate_pymssql(MSDate): - def convert_result_value(self, value, dialect): - # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: - return value - + def result_processor(self, dialect): + def process(value): + # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + return process + class MSText(sqltypes.TEXT): def get_col_spec(self): if self.dialect.text_as_varchar: @@ -181,11 +199,11 @@ class MSNVarchar(sqltypes.Unicode): class AdoMSNVarchar(MSNVarchar): """overrides bindparam/result processing to not convert any unicode strings""" - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class MSChar(sqltypes.CHAR): def get_col_spec(self): @@ -203,20 +221,24 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BIT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process class MSTimeStamp(sqltypes.TIMESTAMP): def get_col_spec(self): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 01d3fa6bc..6d6f32ead 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -294,14 +294,20 @@ class MSNumeric(sqltypes.Numeric, _NumericType): else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, util.decimal_type): - return float(value) + def result_processor(self, dialect): + if not self.asdecimal: + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process else: - return value + return None + class MSDecimal(MSNumeric): @@ -408,8 +414,8 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None class MSInteger(sqltypes.Integer, _NumericType): @@ -539,16 +545,17 @@ class MSBit(sqltypes.TypeEngine): def __init__(self, length=None): self.length = length - def convert_result_value(self, value, dialect): + def result_processor(self, dialect): """Convert a MySQL's 64 bit, variable length binary string to a long.""" - - if value is not None: - v = 0L - for i in map(ord, value): - v = v << 8 | i - value = v - return value - + def process(value): + if value is not None: + v = 0L + for i in map(ord, value): + v = v << 8 | i + value = v + return value + return process + def get_col_spec(self): if self.length is not None: return "BIT(%s)" % self.length @@ -576,13 +583,14 @@ class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - # convert from a timedelta value - if value is not None: - return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) - else: - return None - + def result_processor(self, dialect): + def process(value): + # convert from a timedelta value + if value is not None: + return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) + else: + return None + return process class MSTimeStamp(sqltypes.TIMESTAMP): """MySQL TIMESTAMP type. @@ -930,12 +938,13 @@ class _BinaryType(sqltypes.Binary): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSVarBinary(_BinaryType): """MySQL VARBINARY type, for variable length binary data.""" @@ -976,12 +985,13 @@ class MSBinary(_BinaryType): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSBlob(_BinaryType): """MySQL BLOB type, for binary data up to 2^16 bytes""" @@ -1002,13 +1012,15 @@ class MSBlob(_BinaryType): return "BLOB(%d)" % self.length else: return "BLOB" - - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process + def __repr__(self): return "%s()" % self.__class__.__name__ @@ -1094,12 +1106,18 @@ class MSEnum(MSString): length = max([len(v) for v in strip_enums]) super(MSEnum, self).__init__(length, **kw) - def convert_bind_param(self, value, engine): - if self.strict and value is not None and value not in self.enums: - raise exceptions.InvalidRequestError('"%s" not a valid value for ' - 'this enum' % value) - return super(MSEnum, self).convert_bind_param(value, engine) - + def bind_processor(self, dialect): + super_convert = super(MSEnum, self).bind_processor(dialect) + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exceptions.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process + def get_col_spec(self): return self._extend("ENUM(%s)" % ",".join(self.__ddl_values)) @@ -1155,36 +1173,44 @@ class MSSet(MSString): length = max([len(v) for v in strip_values] + [0]) super(MSSet, self).__init__(length, **kw) - def convert_result_value(self, value, dialect): - # The good news: - # No ',' quoting issues- commas aren't allowed in SET values - # The bad news: - # Plenty of driver inconsistencies here. - if isinstance(value, util.set_types): - # ..some versions convert '' to an empty set - if not value: - value.add('') - # ..some return sets.Set, even for pythons that have __builtin__.set - if not isinstance(value, util.Set): - value = util.Set(value) - return value - # ...and some versions return strings - if value is not None: - return util.Set(value.split(',')) - else: - return value - - def convert_bind_param(self, value, engine): - if value is None or isinstance(value, (int, long, basestring)): - pass - else: - if None in value: - value = util.Set(value) - value.remove(None) - value.add('') - value = ','.join(value) - return super(MSSet, self).convert_bind_param(value, engine) - + def result_processor(self, dialect): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. + if isinstance(value, util.set_types): + # ..some versions convert '' to an empty set + if not value: + value.add('') + # ..some return sets.Set, even for pythons that have __builtin__.set + if not isinstance(value, util.Set): + value = util.Set(value) + return value + # ...and some versions return strings + if value is not None: + return util.Set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(MSSet, self).bind_processor(dialect) + def process(value): + if value is None or isinstance(value, (int, long, basestring)): + pass + else: + if None in value: + value = util.Set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + def get_col_spec(self): return self._extend("SET(%s)" % ",".join(self.__ddl_values)) @@ -1195,21 +1221,24 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOL" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer: MSInteger, @@ -1284,7 +1313,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): re.I | re.UNICODE) def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if (not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None): self._last_inserted_ids = ([self.cursor.lastrowid] + diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 2c45c94e8..520332d45 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -32,26 +32,31 @@ class OracleSmallInteger(sqltypes.Smallinteger): class OracleDate(sqltypes.Date): def get_col_spec(self): return "DATE" - def convert_bind_param(self, value, dialect): - return value - def convert_result_value(self, value, dialect): - if not isinstance(value, datetime.datetime): - return value - else: - return value.date() + def bind_processor(self, dialect): + return None + def result_processor(self, dialect): + def process(value): + if not isinstance(value, datetime.datetime): + return value + else: + return value.date() + return process + class OracleDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process + # Note: # Oracle DATE == DATETIME # Oracle does not allow milliseconds in DATE @@ -65,14 +70,15 @@ class OracleTimestamp(sqltypes.TIMESTAMP): def get_dbapi_type(self, dialect): return dialect.TIMESTAMP - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process class OracleString(sqltypes.String): def get_col_spec(self): @@ -85,15 +91,23 @@ class OracleText(sqltypes.TEXT): def get_col_spec(self): return "CLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - elif hasattr(value, 'read'): - # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str - return super(OracleText, self).convert_result_value(value.read(), dialect) - else: - return super(OracleText, self).convert_result_value(value, dialect) - + def result_processor(self, dialect): + super_process = super(OracleText, self).result_processor(dialect) + def process(value): + if value is None: + return None + elif hasattr(value, 'read'): + # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str + if super_process: + return super_process(value.read()) + else: + return value.read() + else: + if super_process: + return super_process(value) + else: + return value + return process class OracleRaw(sqltypes.Binary): def get_col_spec(self): @@ -110,34 +124,40 @@ class OracleBinary(sqltypes.Binary): def get_col_spec(self): return "BLOB" - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return value.read() + def bind_processor(self, dialect): + return None + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read() + return process + class OracleBoolean(sqltypes.Boolean): def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + colspecs = { sqltypes.Integer : OracleInteger, sqltypes.Smallinteger : OracleSmallInteger, @@ -196,7 +216,7 @@ class OracleExecutionContext(default.DefaultExecutionContext): if self.compiled_parameters is not None: for k in self.out_parameters: type = self.compiled_parameters.get_type(k) - self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + self.out_parameters[k] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[k].getvalue()) else: for k in self.out_parameters: self.out_parameters[k] = self.out_parameters[k].getvalue() diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a30832b43..e4897bba6 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -22,14 +22,19 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, util.decimal_type): - return float(value) + def result_processor(self, dialect): + if self.asdecimal: + return None else: - return value + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process class PGFloat(sqltypes.Float): def get_col_spec(self): @@ -98,25 +103,38 @@ class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): impl.__dict__.update(self.__dict__) impl.item_type = self.item_type.dialect_impl(dialect) return impl - def convert_bind_param(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, (list,tuple)): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_bind_param(item, dialect) - return [convert_item(item) for item in value] - def convert_result_value(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, list): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_result_value(item, dialect) - # Could specialcase when item_type.convert_result_value is the default identity func - return [convert_item(item) for item in value] + + def bind_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process + + def result_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process def get_col_spec(self): return self.item_type.get_col_spec() + '[]' diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 7999cc403..3cc821a36 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -32,15 +32,17 @@ class SLSmallInteger(sqltypes.Smallinteger): return "SMALLINT" class DateTimeMixin(object): - def convert_bind_param(self, value, dialect): - if value is not None: - if getattr(value, 'microsecond', None) is not None: - return value.strftime(self.__format__ + "." + str(value.microsecond)) + def bind_processor(self, dialect): + def process(value): + if value is not None: + if getattr(value, 'microsecond', None) is not None: + return value.strftime(self.__format__ + "." + str(value.microsecond)) + else: + return value.strftime(self.__format__) else: - return value.strftime(self.__format__) - else: - return None - + return None + return process + def _cvt(self, value, dialect): if value is None: return None @@ -57,30 +59,36 @@ class SLDateTime(DateTimeMixin,sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.datetime(*tup) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.datetime(*tup) + return process + class SLDate(DateTimeMixin, sqltypes.Date): __format__ = "%Y-%m-%d" def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.date(*tup[0:3]) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.date(*tup[0:3]) + return process + class SLTime(DateTimeMixin, sqltypes.Time): __format__ = "%H:%M:%S" def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.time(*tup[3:7]) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.time(*tup[3:7]) + return process + class SLText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" @@ -101,16 +109,20 @@ class SLBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" - def convert_bind_param(self, value, dialect): - if value is None: - return None - return value and 1 or 0 - - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + colspecs = { sqltypes.Integer : SLInteger, sqltypes.Smallinteger : SLSmallInteger, @@ -150,7 +162,7 @@ def descriptor(): class SQLiteExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] |