diff options
-rw-r--r-- | CHANGES | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 72 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 85 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 180 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 217 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 130 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 68 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 74 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 91 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 63 | ||||
-rw-r--r-- | lib/sqlalchemy/types.py | 177 | ||||
-rw-r--r-- | test/orm/assorted_eager.py | 2 | ||||
-rw-r--r-- | test/orm/unitofwork.py | 7 | ||||
-rw-r--r-- | test/sql/defaults.py | 39 | ||||
-rw-r--r-- | test/sql/select.py | 6 | ||||
-rw-r--r-- | test/sql/testtypes.py | 43 | ||||
-rw-r--r-- | test/sql/unicode.py | 2 | ||||
-rw-r--r-- | test/testlib/testing.py | 2 |
20 files changed, 789 insertions, 534 deletions
@@ -1,5 +1,9 @@ 0.4.0 - orm + - speed ! along with recent speedups to ResultProxy, total number of + function calls significantly reduced for large loads. + test/perf/masseagerload.py reports 0.4 as having the fewest number + of function calls across all SA versions (0.1, 0.2, and 0.3) - new collection_class api and implementation [ticket:213] collections are now instrumented via decorations rather than @@ -135,11 +139,6 @@ - improved support for custom column_property() attributes which feature correlated subqueries...work better with eager loading now. - - along with recent speedups to ResultProxy, total number of - function calls significantly reduced for large loads. - test/perf/masseagerload.py reports 0.4 as having the fewest number - of function calls across all SA versions (0.1, 0.2, and 0.3) - - primary key "collapse" behavior; the mapper will analyze all columns in its given selectable for primary key "equivalence", that is, columns which are equivalent via foreign key relationship or via an @@ -182,6 +181,11 @@ style of Hibernate - sql + - speed ! clause compilation as well as the mechanics of SQL constructs + have been streamlined and simplified to a signficant degree, for a + 20-30% improvement of the statement construction/compilation overhead of + 0.3 + - all "type" keyword arguments, such as those to bindparam(), column(), Column(), and func.<something>(), renamed to "type_". those objects still name their "type" attribute as "type". @@ -276,8 +280,15 @@ semantics for "__contains__" [ticket:606] - engines + - speed ! the mechanics of result processing and bind parameter processing + have been overhauled, streamlined and optimized to issue as little method + calls as possible. bench tests for mass INSERT and mass rowset iteration + both show 0.4 to be over twice as fast as 0.3, using 68% fewer function + calls. + - You can now hook into the pool lifecycle and run SQL statements or other logic at new each DBAPI connection, pool check-out and check-in. + - Connections gain a .properties collection, with contents scoped to the lifetime of the underlying DBAPI connection - removed auto_close_cursors and disallow_open_cursors arguments from Pool; diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 4d50b6a25..dd4065f39 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -252,24 +252,14 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): for a single statement execution, or one element of an executemany execution. """ - if self.parameters is not None: - bindparams = self.parameters.copy() - else: - bindparams = {} - bindparams.update(params) d = sql.ClauseParameters(self.dialect, self.positiontup) - for b in self.binds.values(): - name = self.bind_names[b] - d.set_parameter(b, b.value, name) - for key, value in bindparams.iteritems(): - try: - b = self.binds[key] - except KeyError: - continue - name = self.bind_names[b] - d.set_parameter(b, value, name) + pd = self.parameters or {} + pd.update(params) + for key, bind in self.binds.iteritems(): + d.set_parameter(bind, pd.get(key, bind.value), self.bind_names[bind]) + return d params = property(lambda self:self.construct_params({}), doc="""Return the `ClauseParameters` corresponding to this compiled object. diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 3c06822ae..6bf8b96e9 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -11,16 +11,18 @@ import sqlalchemy.engine.default as default class AcNumeric(types.Numeric): - def convert_result_value(self, value, dialect): - return value - - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) + def result_processor(self, dialect): + return None + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + def get_col_spec(self): return "NUMERIC" @@ -28,12 +30,14 @@ class AcFloat(types.Float): def get_col_spec(self): return "FLOAT" - def convert_bind_param(self, value, dialect): + def bind_processor(self, dialect): """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None - + def process(value): + if not value is None: + return str(value) + return None + return process + class AcInteger(types.Integer): def get_col_spec(self): return "INTEGER" @@ -72,11 +76,11 @@ class AcUnicode(types.Unicode): def get_col_spec(self): return "TEXT" + (self.length and ("(%d)" % self.length) or "") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class AcChar(types.CHAR): def get_col_spec(self): @@ -90,21 +94,25 @@ class AcBoolean(types.Boolean): def get_col_spec(self): return "YESNO" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + class AcTimeStamp(types.TIMESTAMP): def get_col_spec(self): return "TIMESTAMP" diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index a3ef99916..21ecf1538 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -61,27 +61,33 @@ class InfoDateTime(sqltypes.DateTime ): def get_col_spec(self): return "DATETIME YEAR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value - + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process + class InfoTime(sqltypes.Time ): def get_col_spec(self): return "DATETIME HOUR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process - def convert_result_value(self, value, dialect): - if isinstance( value , datetime.datetime ): - return value.time() - else: - return value + def result_processor(self, dialect): + def process(value): + if isinstance( value , datetime.datetime ): + return value.time() + else: + return value + return process class InfoText(sqltypes.String): def get_col_spec(self): @@ -91,36 +97,45 @@ class InfoString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} - def convert_bind_param( self , value , dialect ): - if value == '': - return None - else: - return value - + def bind_processor(self, dialect): + def process(value): + if value == '': + return None + else: + return value + return process + class InfoChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class InfoBinary(sqltypes.Binary): def get_col_spec(self): return "BYTE" + class InfoBoolean(sqltypes.Boolean): default_type = 'NUM' def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer : InfoInteger, diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 9ec0fbbc3..308a38a76 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -47,16 +47,18 @@ from sqlalchemy.engine import default import operator class MSNumeric(sqltypes.Numeric): - def convert_result_value(self, value, dialect): - return value - - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) + def result_processor(self, dialect): + return None + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + def get_col_spec(self): if self.precision is None: return "NUMERIC" @@ -67,12 +69,14 @@ class MSFloat(sqltypes.Float): def get_col_spec(self): return "FLOAT(%(precision)s)" % {'precision': self.precision} - def convert_bind_param(self, value, dialect): - """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None - + def bind_processor(self, dialect): + def process(value): + """By converting to string, we can use Decimal types round-trip.""" + if not value is None: + return str(value) + return None + return process + class MSInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -108,57 +112,71 @@ class MSTime(sqltypes.Time): def get_col_spec(self): return "DATETIME" - def convert_bind_param(self, value, dialect): - if isinstance(value, datetime.datetime): - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif isinstance(value, datetime.time): - value = datetime.datetime.combine(self.__zero_date, value) - return value - - def convert_result_value(self, value, dialect): - if isinstance(value, datetime.datetime): - return value.time() - elif isinstance(value, datetime.date): - return datetime.time(0, 0, 0) - return value - -class MSDateTime_adodbapi(MSDateTime): - def convert_result_value(self, value, dialect): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - return value - -class MSDateTime_pyodbc(MSDateTime): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) return value - -class MSDate_pyodbc(MSDate): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + return process + + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, datetime.date): + return datetime.time(0, 0, 0) return value - - def convert_result_value(self, value, dialect): - # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: + return process + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) return value - + return process + +class MSDateTime_pyodbc(MSDateTime): + def bind_processor(self, dialect): + def process(value): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class MSDate_pyodbc(MSDate): + def bind_processor(self, dialect): + def process(value): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + return process + class MSDate_pymssql(MSDate): - def convert_result_value(self, value, dialect): - # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: - return value - + def result_processor(self, dialect): + def process(value): + # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + return process + class MSText(sqltypes.TEXT): def get_col_spec(self): if self.dialect.text_as_varchar: @@ -181,11 +199,11 @@ class MSNVarchar(sqltypes.Unicode): class AdoMSNVarchar(MSNVarchar): """overrides bindparam/result processing to not convert any unicode strings""" - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class MSChar(sqltypes.CHAR): def get_col_spec(self): @@ -203,20 +221,24 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BIT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process class MSTimeStamp(sqltypes.TIMESTAMP): def get_col_spec(self): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 01d3fa6bc..6d6f32ead 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -294,14 +294,20 @@ class MSNumeric(sqltypes.Numeric, _NumericType): else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, util.decimal_type): - return float(value) + def result_processor(self, dialect): + if not self.asdecimal: + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process else: - return value + return None + class MSDecimal(MSNumeric): @@ -408,8 +414,8 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None class MSInteger(sqltypes.Integer, _NumericType): @@ -539,16 +545,17 @@ class MSBit(sqltypes.TypeEngine): def __init__(self, length=None): self.length = length - def convert_result_value(self, value, dialect): + def result_processor(self, dialect): """Convert a MySQL's 64 bit, variable length binary string to a long.""" - - if value is not None: - v = 0L - for i in map(ord, value): - v = v << 8 | i - value = v - return value - + def process(value): + if value is not None: + v = 0L + for i in map(ord, value): + v = v << 8 | i + value = v + return value + return process + def get_col_spec(self): if self.length is not None: return "BIT(%s)" % self.length @@ -576,13 +583,14 @@ class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - # convert from a timedelta value - if value is not None: - return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) - else: - return None - + def result_processor(self, dialect): + def process(value): + # convert from a timedelta value + if value is not None: + return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) + else: + return None + return process class MSTimeStamp(sqltypes.TIMESTAMP): """MySQL TIMESTAMP type. @@ -930,12 +938,13 @@ class _BinaryType(sqltypes.Binary): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSVarBinary(_BinaryType): """MySQL VARBINARY type, for variable length binary data.""" @@ -976,12 +985,13 @@ class MSBinary(_BinaryType): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSBlob(_BinaryType): """MySQL BLOB type, for binary data up to 2^16 bytes""" @@ -1002,13 +1012,15 @@ class MSBlob(_BinaryType): return "BLOB(%d)" % self.length else: return "BLOB" - - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process + def __repr__(self): return "%s()" % self.__class__.__name__ @@ -1094,12 +1106,18 @@ class MSEnum(MSString): length = max([len(v) for v in strip_enums]) super(MSEnum, self).__init__(length, **kw) - def convert_bind_param(self, value, engine): - if self.strict and value is not None and value not in self.enums: - raise exceptions.InvalidRequestError('"%s" not a valid value for ' - 'this enum' % value) - return super(MSEnum, self).convert_bind_param(value, engine) - + def bind_processor(self, dialect): + super_convert = super(MSEnum, self).bind_processor(dialect) + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exceptions.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process + def get_col_spec(self): return self._extend("ENUM(%s)" % ",".join(self.__ddl_values)) @@ -1155,36 +1173,44 @@ class MSSet(MSString): length = max([len(v) for v in strip_values] + [0]) super(MSSet, self).__init__(length, **kw) - def convert_result_value(self, value, dialect): - # The good news: - # No ',' quoting issues- commas aren't allowed in SET values - # The bad news: - # Plenty of driver inconsistencies here. - if isinstance(value, util.set_types): - # ..some versions convert '' to an empty set - if not value: - value.add('') - # ..some return sets.Set, even for pythons that have __builtin__.set - if not isinstance(value, util.Set): - value = util.Set(value) - return value - # ...and some versions return strings - if value is not None: - return util.Set(value.split(',')) - else: - return value - - def convert_bind_param(self, value, engine): - if value is None or isinstance(value, (int, long, basestring)): - pass - else: - if None in value: - value = util.Set(value) - value.remove(None) - value.add('') - value = ','.join(value) - return super(MSSet, self).convert_bind_param(value, engine) - + def result_processor(self, dialect): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. + if isinstance(value, util.set_types): + # ..some versions convert '' to an empty set + if not value: + value.add('') + # ..some return sets.Set, even for pythons that have __builtin__.set + if not isinstance(value, util.Set): + value = util.Set(value) + return value + # ...and some versions return strings + if value is not None: + return util.Set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(MSSet, self).bind_processor(dialect) + def process(value): + if value is None or isinstance(value, (int, long, basestring)): + pass + else: + if None in value: + value = util.Set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + def get_col_spec(self): return self._extend("SET(%s)" % ",".join(self.__ddl_values)) @@ -1195,21 +1221,24 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOL" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer: MSInteger, @@ -1284,7 +1313,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): re.I | re.UNICODE) def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if (not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None): self._last_inserted_ids = ([self.cursor.lastrowid] + diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 2c45c94e8..520332d45 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -32,26 +32,31 @@ class OracleSmallInteger(sqltypes.Smallinteger): class OracleDate(sqltypes.Date): def get_col_spec(self): return "DATE" - def convert_bind_param(self, value, dialect): - return value - def convert_result_value(self, value, dialect): - if not isinstance(value, datetime.datetime): - return value - else: - return value.date() + def bind_processor(self, dialect): + return None + def result_processor(self, dialect): + def process(value): + if not isinstance(value, datetime.datetime): + return value + else: + return value.date() + return process + class OracleDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process + # Note: # Oracle DATE == DATETIME # Oracle does not allow milliseconds in DATE @@ -65,14 +70,15 @@ class OracleTimestamp(sqltypes.TIMESTAMP): def get_dbapi_type(self, dialect): return dialect.TIMESTAMP - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process class OracleString(sqltypes.String): def get_col_spec(self): @@ -85,15 +91,23 @@ class OracleText(sqltypes.TEXT): def get_col_spec(self): return "CLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - elif hasattr(value, 'read'): - # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str - return super(OracleText, self).convert_result_value(value.read(), dialect) - else: - return super(OracleText, self).convert_result_value(value, dialect) - + def result_processor(self, dialect): + super_process = super(OracleText, self).result_processor(dialect) + def process(value): + if value is None: + return None + elif hasattr(value, 'read'): + # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str + if super_process: + return super_process(value.read()) + else: + return value.read() + else: + if super_process: + return super_process(value) + else: + return value + return process class OracleRaw(sqltypes.Binary): def get_col_spec(self): @@ -110,34 +124,40 @@ class OracleBinary(sqltypes.Binary): def get_col_spec(self): return "BLOB" - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return value.read() + def bind_processor(self, dialect): + return None + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read() + return process + class OracleBoolean(sqltypes.Boolean): def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + colspecs = { sqltypes.Integer : OracleInteger, sqltypes.Smallinteger : OracleSmallInteger, @@ -196,7 +216,7 @@ class OracleExecutionContext(default.DefaultExecutionContext): if self.compiled_parameters is not None: for k in self.out_parameters: type = self.compiled_parameters.get_type(k) - self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + self.out_parameters[k] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[k].getvalue()) else: for k in self.out_parameters: self.out_parameters[k] = self.out_parameters[k].getvalue() diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a30832b43..e4897bba6 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -22,14 +22,19 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, util.decimal_type): - return float(value) + def result_processor(self, dialect): + if self.asdecimal: + return None else: - return value + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process class PGFloat(sqltypes.Float): def get_col_spec(self): @@ -98,25 +103,38 @@ class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): impl.__dict__.update(self.__dict__) impl.item_type = self.item_type.dialect_impl(dialect) return impl - def convert_bind_param(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, (list,tuple)): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_bind_param(item, dialect) - return [convert_item(item) for item in value] - def convert_result_value(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, list): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_result_value(item, dialect) - # Could specialcase when item_type.convert_result_value is the default identity func - return [convert_item(item) for item in value] + + def bind_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process + + def result_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process def get_col_spec(self): return self.item_type.get_col_spec() + '[]' diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 7999cc403..3cc821a36 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -32,15 +32,17 @@ class SLSmallInteger(sqltypes.Smallinteger): return "SMALLINT" class DateTimeMixin(object): - def convert_bind_param(self, value, dialect): - if value is not None: - if getattr(value, 'microsecond', None) is not None: - return value.strftime(self.__format__ + "." + str(value.microsecond)) + def bind_processor(self, dialect): + def process(value): + if value is not None: + if getattr(value, 'microsecond', None) is not None: + return value.strftime(self.__format__ + "." + str(value.microsecond)) + else: + return value.strftime(self.__format__) else: - return value.strftime(self.__format__) - else: - return None - + return None + return process + def _cvt(self, value, dialect): if value is None: return None @@ -57,30 +59,36 @@ class SLDateTime(DateTimeMixin,sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.datetime(*tup) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.datetime(*tup) + return process + class SLDate(DateTimeMixin, sqltypes.Date): __format__ = "%Y-%m-%d" def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.date(*tup[0:3]) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.date(*tup[0:3]) + return process + class SLTime(DateTimeMixin, sqltypes.Time): __format__ = "%H:%M:%S" def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.time(*tup[3:7]) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.time(*tup[3:7]) + return process + class SLText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" @@ -101,16 +109,20 @@ class SLBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" - def convert_bind_param(self, value, dialect): - if value is None: - return None - return value and 1 or 0 - - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + colspecs = { sqltypes.Integer : SLInteger, sqltypes.Smallinteger : SLSmallInteger, @@ -150,7 +162,7 @@ def descriptor(): class SQLiteExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 284be6dfe..8fe34bf3f 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1127,20 +1127,17 @@ class ResultProxy(object): col3 = row[mytable.c.mycol] # access via Column object. ResultProxy also contains a map of TypeEngine objects and will - invoke the appropriate ``convert_result_value()`` method before + invoke the appropriate ``result_processor()`` method before returning columns, as well as the ExecutionContext corresponding to the statement execution. It provides several methods for which to obtain information from the underlying ExecutionContext. """ - class AmbiguousColumn(object): - def __init__(self, key): - self.key = key - def dialect_impl(self, dialect): - return self - def convert_result_value(self, arg, engine): - raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) - + def __ambiguous_processor(self, colname): + def process(value): + raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname) + return process + def __init__(self, context): """ResultProxy objects are constructed via the execute() method on SQLEngine.""" self.context = context @@ -1185,13 +1182,13 @@ class ResultProxy(object): else: type = typemap.get(item[1], types.NULLTYPE) - rec = (type, type.dialect_impl(self.dialect), i) + rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i) if rec[0] is None: raise exceptions.InvalidRequestError( "None for metadata " + colname) if self.__props.setdefault(colname.lower(), rec) is not rec: - self.__props[colname.lower()] = (type, ResultProxy.AmbiguousColumn(colname), 0) + self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0) self.__keys.append(colname) self.__props[i] = rec @@ -1298,7 +1295,10 @@ class ResultProxy(object): def _get_col(self, row, key): rec = self._key_cache[key] - return rec[1].convert_result_value(row[rec[2]], self.dialect) + if rec[1]: + return rec[1](row[rec[2]]) + else: + return row[rec[2]] def _fetchone_impl(self): return self.cursor.fetchone() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 02802452a..ccaf080e7 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -163,12 +163,17 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = unicode(compiled) if parameters is None: self.compiled_parameters = compiled.construct_params({}) + self.executemany = False elif not isinstance(parameters, (list, tuple)): self.compiled_parameters = compiled.construct_params(parameters) + self.executemany = False else: self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters] if len(self.compiled_parameters) == 1: self.compiled_parameters = self.compiled_parameters[0] + self.executemany = False + else: + self.executemany = True elif statement is not None: self.typemap = self.column_labels = None self.parameters = self.__encode_param_keys(parameters) @@ -206,22 +211,26 @@ class DefaultExecutionContext(base.ExecutionContext): return proc(params) def __convert_compiled_params(self, parameters): - executemany = parameters is not None and isinstance(parameters, list) encode = not self.dialect.supports_unicode_statements() # the bind params are a CompiledParams object. but all the DBAPI's hate # that object (or similar). so convert it to a clean # dictionary/list/tuple of dictionary/tuple of list if parameters is not None: - if self.dialect.positional: - if executemany: - parameters = [p.get_raw_list() for p in parameters] + if self.executemany: + processors = parameters[0].get_processors() + else: + processors = parameters.get_processors() + + if self.dialect.positional: + if self.executemany: + parameters = [p.get_raw_list(processors) for p in parameters] else: - parameters = parameters.get_raw_list() - else: - if executemany: - parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters] + parameters = parameters.get_raw_list(processors) + else: + if self.executemany: + parameters = [p.get_raw_dict(processors, encode_keys=encode) for p in parameters] else: - parameters = parameters.get_raw_dict(encode_keys=encode) + parameters = parameters.get_raw_dict(processors, encode_keys=encode) return parameters def is_select(self): @@ -311,28 +320,31 @@ class DefaultExecutionContext(base.ExecutionContext): """generate default values for compiled insert/update statements, and generate last_inserted_ids() collection.""" - # TODO: cleanup if self.isinsert: - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters - else: - plist = [self.compiled_parameters] drunner = self.dialect.defaultrunner(self) - for param in plist: + if self.executemany: + # executemany doesn't populate last_inserted_ids() + firstparam = self.compiled_parameters[0] + processors = firstparam.get_processors() + for c in self.compiled.statement.table.c: + if c.default is not None: + params = self.compiled_parameters + for param in params: + if not c.key in param or param.get_original(c.key) is None: + self.compiled_parameters = param + newid = drunner.get_column_default(c) + if newid is not None: + param.set_value(c.key, newid) + self.compiled_parameters = params + else: + param = self.compiled_parameters + processors = param.get_processors() last_inserted_ids = [] - # check the "default" status of each column in the table for c in self.compiled.statement.table.c: - # check if it will be populated by a SQL clause - we'll need that - # after execution. if c in self.compiled.inline_params: self._postfetch_cols.add(c) if c.primary_key: last_inserted_ids.append(None) - # check if its not present at all. see if theres a default - # and fire it off, and add to bind parameters. if - # its a pk, add the value to our last_inserted_ids list, - # or, if its a SQL-side default, let it fire off on the DB side, but we'll need - # the SQL-generated value after execution. elif not c.key in param or param.get_original(c.key) is None: if isinstance(c.default, schema.PassiveDefault): self._postfetch_cols.add(c) @@ -340,32 +352,33 @@ class DefaultExecutionContext(base.ExecutionContext): if newid is not None: param.set_value(c.key, newid) if c.primary_key: - last_inserted_ids.append(param.get_processed(c.key)) + last_inserted_ids.append(param.get_processed(c.key, processors)) elif c.primary_key: last_inserted_ids.append(None) - # its an explicitly passed pk value - add it to - # our last_inserted_ids list. elif c.primary_key: - last_inserted_ids.append(param.get_processed(c.key)) - # TODO: we arent accounting for executemany() situations - # here (hard to do since lastrowid doesnt support it either) + last_inserted_ids.append(param.get_processed(c.key, processors)) self._last_inserted_ids = last_inserted_ids self._last_inserted_params = param + + elif self.isupdate: - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters - else: - plist = [self.compiled_parameters] drunner = self.dialect.defaultrunner(self) - for param in plist: - # check the "onupdate" status of each column in the table + if self.executemany: + for c in self.compiled.statement.table.c: + if c.onupdate is not None: + params = self.compiled_parameters + for param in params: + if not c.key in param or param.get_original(c.key) is None: + self.compiled_parameters = param + value = drunner.get_column_onupdate(c) + if value is not None: + param.set_value(c.key, value) + self.compiled_parameters = params + else: + param = self.compiled_parameters for c in self.compiled.statement.table.c: - # it will be populated by a SQL clause - we'll need that - # after execution. if c in self.compiled.inline_params: self._postfetch_cols.add(c) - # its not in the bind parameters, and theres an "onupdate" defined for the column; - # execute it and add to bind params elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None): value = drunner.get_column_onupdate(c) if value is not None: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 3fc13a50d..994a877bd 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -812,7 +812,6 @@ class ClauseParameters(object): """ def __init__(self, dialect, positional=None): - super(ClauseParameters, self).__init__() self.dialect = dialect self.__binds = {} self.positional = positional or [] @@ -829,19 +828,31 @@ class ClauseParameters(object): def get_type(self, key): return self.__binds[key][0].type - def get_processed(self, key): - (bind, name, value) = self.__binds[key] - return bind.typeprocess(value, self.dialect) - + def get_processors(self): + """return a dictionary of bind 'processing' functions""" + return dict([ + (key, value) for key, value in + [( + key, + self.__binds[key][0].bind_processor(self.dialect) + ) for key in self.__binds] + if value is not None + ]) + + def get_processed(self, key, processors): + return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2] + def keys(self): return self.__binds.keys() def __iter__(self): return iter(self.keys()) - - def __getitem__(self, key): - return self.get_processed(key) + def __getitem__(self, key): + (bind, name, value) = self.__binds[key] + processor = bind.bind_processor(self.dialect) + return processor is not None and processor(value) or value + def __contains__(self, key): return key in self.__binds @@ -851,14 +862,36 @@ class ClauseParameters(object): def get_original_dict(self): return dict([(name, value) for (b, name, value) in self.__binds.values()]) - def get_raw_list(self): - return [self.get_processed(key) for key in self.positional] + def get_raw_list(self, processors): +# (bind, name, value) = self.__binds[key] + return [ + (key in processors) and + processors[key](self.__binds[key][2]) or + self.__binds[key][2] + for key in self.positional + ] - def get_raw_dict(self, encode_keys=False): + def get_raw_dict(self, processors, encode_keys=False): if encode_keys: - return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()]) + return dict([ + ( + key.encode(self.dialect.encoding), + (key in processors) and + processors[key](self.__binds[key][2]) or + self.__binds[key][2] + ) + for key in self.keys() + ]) else: - return dict([(key, self.get_processed(key)) for key in self.keys()]) + return dict([ + ( + key, + (key in processors) and + processors[key](self.__binds[key][2]) or + self.__binds[key][2] + ) + for key in self.keys() + ]) def __repr__(self): return self.__class__.__name__ + ":" + repr(self.get_original_dict()) @@ -1995,8 +2028,8 @@ class _BindParamClause(ClauseElement, _CompareMixin): def _get_from_objects(self, **modifiers): return [] - def typeprocess(self, value, dialect): - return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) + def bind_processor(self, dialect): + return self.type.dialect_impl(dialect).bind_processor(dialect) def _compare_type(self, obj): if not isinstance(self.type, sqltypes.NullType): diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index fe05910df..f3854e3e1 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -59,12 +59,13 @@ class TypeEngine(AbstractType): def get_col_spec(self): raise NotImplementedError() - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + return None + def adapt(self, cls): return cls() @@ -115,11 +116,11 @@ class TypeDecorator(AbstractType): def get_col_spec(self): return self.impl.get_col_spec() - def convert_bind_param(self, value, dialect): - return self.impl.convert_bind_param(value, dialect) + def bind_processor(self, dialect): + return self.impl.bind_processor(dialect) - def convert_result_value(self, value, dialect): - return self.impl.convert_result_value(value, dialect) + def result_processor(self, dialect): + return self.impl.result_processor(dialect) def copy(self): instance = self.__class__.__new__(self.__class__) @@ -183,11 +184,6 @@ class NullType(TypeEngine): def get_col_spec(self): raise NotImplementedError() - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value NullTypeEngine = NullType class Concatenable(object): @@ -202,11 +198,27 @@ class String(TypeEngine, Concatenable): def adapt(self, impltype): return impltype(length=self.length, convert_unicode=self.convert_unicode) - def convert_bind_param(self, value, dialect): - if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode): - return value + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process else: - return value.encode(dialect.encoding) + return None + + def result_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + def process(value): + if value is not None and not isinstance(value, unicode): + return value.decode(dialect.encoding) + else: + return value + return process + else: + return None def get_search_list(self): l = super(String, self).get_search_list() @@ -215,11 +227,6 @@ class String(TypeEngine, Concatenable): else: return l - def convert_result_value(self, value, dialect): - if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode): - return value - else: - return value.decode(dialect.encoding) def get_dbapi_type(self, dbapi): return dbapi.STRING @@ -254,17 +261,24 @@ class Numeric(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.NUMBER - def convert_bind_param(self, value, dialect): - if value is not None: - return float(value) - else: - return value - - def convert_result_value(self, value, dialect): - if value is not None and self.asdecimal: - return Decimal(str(value)) + def bind_processor(self, dialect): + def process(value): + if value is not None: + return float(value) + else: + return value + return process + + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return Decimal(str(value)) + else: + return value + return process else: - return value + return None class Float(Numeric): def __init__(self, precision = 10, asdecimal=False, **kwargs): @@ -308,15 +322,14 @@ class Binary(TypeEngine): def __init__(self, length=None): self.length = length - def convert_bind_param(self, value, dialect): - if value is not None: - return dialect.dbapi.Binary(value) - else: - return None - - def convert_result_value(self, value, dialect): - return value - + def bind_processor(self, dialect): + def process(value): + if value is not None: + return dialect.dbapi.Binary(value) + else: + return None + return process + def adapt(self, impltype): return impltype(length=self.length) @@ -332,17 +345,27 @@ class PickleType(MutableType, TypeDecorator): self.mutable = mutable super(PickleType, self).__init__() - def convert_result_value(self, value, dialect): - if value is None: - return None - buf = self.impl.convert_result_value(value, dialect) - return self.pickler.loads(str(buf)) - - def convert_bind_param(self, value, dialect): - if value is None: - return None - return self.impl.convert_bind_param(self.pickler.dumps(value, self.protocol), dialect) - + def bind_processor(self, dialect): + impl_process = self.impl.bind_processor(dialect) + def process(value): + if value is None: + return None + if impl_process is None: + return self.pickler.dumps(value, self.protocol) + else: + return impl_process(self.pickler.dumps(value, self.protocol)) + return process + + def result_processor(self, dialect): + impl_process = self.impl.result_processor(dialect) + def process(value): + if value is None: + return None + if impl_process is not None: + value = impl_process(value) + return self.pickler.loads(str(value)) + return process + def copy_value(self, value): if self.mutable: return self.pickler.loads(self.pickler.dumps(value, self.protocol)) @@ -370,8 +393,8 @@ class Interval(TypeDecorator): Converting is very simple - just use epoch(zero timestamp, 01.01.1970) as base, so if we need to store timedelta = 1 day (24 hours) in database it - will be stored as DateTime = '2nd Jan 1970 00:00', see convert_bind_param - and convert_result_value to actual conversion code + will be stored as DateTime = '2nd Jan 1970 00:00', see bind_processor + and result_processor to actual conversion code """ #Empty useless type, because at the moment of creation of instance we don't #know what type will be decorated - it depends on used dialect. @@ -396,25 +419,35 @@ class Interval(TypeDecorator): def __hasNativeImpl(self,dialect): return dialect.__class__ in self.__supported - - def convert_bind_param(self, value, dialect): - if value is None: - return None - if not self.__hasNativeImpl(dialect): - tmpval = dt.datetime.utcfromtimestamp(0) + value - return self.impl.convert_bind_param(tmpval,dialect) + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + if self.__hasNativeImpl(dialect): + return impl_processor else: - return self.impl.convert_bind_param(value,dialect) - - def convert_result_value(self, value, dialect): - if value is None: - return None - retval = self.impl.convert_result_value(value,dialect) - if not self.__hasNativeImpl(dialect): - return retval - dt.datetime.utcfromtimestamp(0) + def process(value): + if value is None: + return None + tmpval = dt.datetime.utcfromtimestamp(0) + value + if impl_processor is not None: + return impl_processor(tmpval) + else: + return tmpval + return process + + def result_processor(self, dialect): + impl_processor = self.impl.result_processor(dialect) + if self.__hasNativeImpl(dialect): + return impl_processor else: - return retval - + def process(value): + if value is None: + return None + if impl_processor is not None: + value = impl_processor(value) + return value - dt.datetime.utcfromtimestamp(0) + return process + class FLOAT(Float):pass class TEXT(String):pass class DECIMAL(Numeric):pass diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index 652186b8e..ce17e8dfd 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -13,7 +13,7 @@ class EagerTest(AssertMixin): dbmeta = MetaData(testbase.db) # determine a literal value for "false" based on the dialect - false = Boolean().dialect_impl(testbase.db.dialect).convert_bind_param(False, testbase.db.dialect) + false = Boolean().dialect_impl(testbase.db.dialect).bind_processor(testbase.db.dialect)(False) owners = Table ( 'owners', dbmeta , Column ( 'id', Integer, primary_key=True, nullable=False ), diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 0ef64746f..c7a5c055a 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -460,7 +460,7 @@ class ClauseAttributesTest(UnitOfWorkTest): global metadata, users_table metadata = MetaData(testbase.db) users_table = Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('users_id_seq', optional=True), primary_key=True), Column('name', String(30)), Column('counter', Integer, default=1)) metadata.create_all() @@ -995,7 +995,10 @@ class SaveTest(UnitOfWorkTest): u = User() u.user_id=42 Session.commit() - + + # why no support on oracle ? because oracle doesn't save + # "blank" strings; it saves a single space character. + @testing.unsupported('oracle') def test_dont_update_blanks(self): mapper(User, users) u = User() diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 0df49ea39..76bd2c41f 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -9,14 +9,14 @@ import datetime class DefaultTest(PersistTest): def setUpAll(self): - global t, f, f2, ts, currenttime, metadata + global t, f, f2, ts, currenttime, metadata, default_generator db = testbase.db metadata = MetaData(db) - x = {'x':50} + default_generator = {'x':50} def mydefault(): - x['x'] += 1 - return x['x'] + default_generator['x'] += 1 + return default_generator['x'] def myupdate_with_ctx(ctx): return len(ctx.compiled_parameters['col2']) @@ -96,6 +96,7 @@ class DefaultTest(PersistTest): t.drop() def tearDown(self): + default_generator['x'] = 50 t.delete().execute() def testargsignature(self): @@ -125,7 +126,14 @@ class DefaultTest(PersistTest): t.insert().execute() ctexec = currenttime.scalar() - print "Currenttime "+ repr(ctexec) + l = t.select().execute() + today = datetime.date.today() + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)]) + + def testinsertmany(self): + r = t.insert().execute({}, {}, {}) + + ctexec = currenttime.scalar() l = t.select().execute() today = datetime.date.today() self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)]) @@ -135,6 +143,25 @@ class DefaultTest(PersistTest): l = t.select().execute() self.assert_(l.fetchone()['col3'] == 50) + def testupdatemany(self): + t.insert().execute({}, {}, {}) + + t.update(t.c.col1==bindparam('pkval')).execute( + {'pkval':51,'col7':None, 'col8':None, 'boolcol1':False}, + ) + + + t.update(t.c.col1==bindparam('pkval')).execute( + {'pkval':51,}, + {'pkval':52,}, + {'pkval':53,}, + ) + + l = t.select().execute() + ctexec = currenttime.scalar() + today = datetime.date.today() + self.assert_(l.fetchall() == [(51, 'im the update', f2, ts, ts, ctexec, False, False, 13, today), (52, 'im the update', f2, ts, ts, ctexec, True, False, 13, today), (53, 'im the update', f2, ts, ts, ctexec, True, False, 13, today)]) + def testupdate(self): r = t.insert().execute() @@ -147,7 +174,7 @@ class DefaultTest(PersistTest): self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today())) # mysql/other db's return 0 or 1 for count(1) self.assert_(14 <= f2 <= 15) - + def testupdatevalues(self): r = t.insert().execute() pk = r.last_inserted_ids()[0] diff --git a/test/sql/select.py b/test/sql/select.py index f5932d515..865f1ec48 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -902,9 +902,9 @@ EXISTS (select yay from foo where boo = lar)", self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect()) nonpositional = stmt.compile() positional = stmt.compile(dialect=sqlite.dialect()) - assert positional.get_params().get_raw_list() == expected_default_params_list - assert nonpositional.get_params(**test_param_dict).get_raw_dict() == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict())) - assert positional.get_params(**test_param_dict).get_raw_list() == expected_test_params_list + assert positional.get_params().get_raw_list({}) == expected_default_params_list + assert nonpositional.get_params(**test_param_dict).get_raw_dict({}) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict())) + assert positional.get_params(**test_param_dict).get_raw_list({}) == expected_test_params_list # check that params() doesnt modify original statement s = select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myotherid'))) diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 659033016..d917b4a18 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -10,28 +10,47 @@ from testlib import * class MyType(types.TypeEngine): def get_col_spec(self): return "VARCHAR(100)" - def convert_bind_param(self, value, engine): - return "BIND_IN"+ value - def convert_result_value(self, value, engine): - return value + "BIND_OUT" + def bind_processor(self, dialect): + def process(value): + return "BIND_IN"+ value + return process + def result_processor(self, dialect): + def process(value): + return value + "BIND_OUT" + return process def adapt(self, typeobj): return typeobj() class MyDecoratedType(types.TypeDecorator): impl = String - def convert_bind_param(self, value, dialect): - return "BIND_IN"+ super(MyDecoratedType, self).convert_bind_param(value, dialect) - def convert_result_value(self, value, dialect): - return super(MyDecoratedType, self).convert_result_value(value, dialect) + "BIND_OUT" + def bind_processor(self, dialect): + impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value) + def process(value): + return "BIND_IN"+ impl_processor(value) + return process + def result_processor(self, dialect): + impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value) + def process(value): + return impl_processor(value) + "BIND_OUT" + return process def copy(self): return MyDecoratedType() class MyUnicodeType(types.TypeDecorator): impl = Unicode - def convert_bind_param(self, value, dialect): - return "UNI_BIND_IN"+ super(MyUnicodeType, self).convert_bind_param(value, dialect) - def convert_result_value(self, value, dialect): - return super(MyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT" + + def bind_processor(self, dialect): + impl_processor = super(MyUnicodeType, self).bind_processor(dialect) + def process(value): + return "UNI_BIND_IN"+ impl_processor(value) + return process + + def result_processor(self, dialect): + impl_processor = super(MyUnicodeType, self).result_processor(dialect) + def process(value): + return impl_processor(value) + "UNI_BIND_OUT" + return process + def copy(self): return MyUnicodeType(self.impl.length) diff --git a/test/sql/unicode.py b/test/sql/unicode.py index b66d001be..19e78ed59 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -32,6 +32,8 @@ class UnicodeSchemaTest(PersistTest): Column(u'\u6e2c\u8a66_id', Integer, primary_key=True, autoincrement=False), Column(u'unitable1_\u6e2c\u8a66', Integer, + # lets leave these out for now so that PG tests pass, until + # the test can be broken out into a pg-passing version (or we figure it out) #ForeignKey(u'unitable1.\u6e2c\u8a66') ), Column(u'Unitéble2_b', Integer, diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 9ee201202..ba3670f4d 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -221,7 +221,7 @@ class SQLCompileTest(PersistTest): if checkparams is not None: if isinstance(checkparams, list): - self.assert_(c.get_params().get_raw_list() == checkparams, "params dont match ") + self.assert_(c.get_params().get_raw_list({}) == checkparams, "params dont match ") else: self.assert_(c.get_params().get_original_dict() == checkparams, "params dont match" + repr(c.get_params())) |