diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/access.py | 72 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/informix.py | 85 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 180 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 217 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 130 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 68 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 74 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 91 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 63 | ||||
-rw-r--r-- | lib/sqlalchemy/types.py | 177 |
12 files changed, 697 insertions, 504 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 4d50b6a25..dd4065f39 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -252,24 +252,14 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): for a single statement execution, or one element of an executemany execution. """ - if self.parameters is not None: - bindparams = self.parameters.copy() - else: - bindparams = {} - bindparams.update(params) d = sql.ClauseParameters(self.dialect, self.positiontup) - for b in self.binds.values(): - name = self.bind_names[b] - d.set_parameter(b, b.value, name) - for key, value in bindparams.iteritems(): - try: - b = self.binds[key] - except KeyError: - continue - name = self.bind_names[b] - d.set_parameter(b, value, name) + pd = self.parameters or {} + pd.update(params) + for key, bind in self.binds.iteritems(): + d.set_parameter(bind, pd.get(key, bind.value), self.bind_names[bind]) + return d params = property(lambda self:self.construct_params({}), doc="""Return the `ClauseParameters` corresponding to this compiled object. diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 3c06822ae..6bf8b96e9 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -11,16 +11,18 @@ import sqlalchemy.engine.default as default class AcNumeric(types.Numeric): - def convert_result_value(self, value, dialect): - return value - - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) + def result_processor(self, dialect): + return None + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + def get_col_spec(self): return "NUMERIC" @@ -28,12 +30,14 @@ class AcFloat(types.Float): def get_col_spec(self): return "FLOAT" - def convert_bind_param(self, value, dialect): + def bind_processor(self, dialect): """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None - + def process(value): + if not value is None: + return str(value) + return None + return process + class AcInteger(types.Integer): def get_col_spec(self): return "INTEGER" @@ -72,11 +76,11 @@ class AcUnicode(types.Unicode): def get_col_spec(self): return "TEXT" + (self.length and ("(%d)" % self.length) or "") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class AcChar(types.CHAR): def get_col_spec(self): @@ -90,21 +94,25 @@ class AcBoolean(types.Boolean): def get_col_spec(self): return "YESNO" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + class AcTimeStamp(types.TIMESTAMP): def get_col_spec(self): return "TIMESTAMP" diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index a3ef99916..21ecf1538 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -61,27 +61,33 @@ class InfoDateTime(sqltypes.DateTime ): def get_col_spec(self): return "DATETIME YEAR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value - + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process + class InfoTime(sqltypes.Time ): def get_col_spec(self): return "DATETIME HOUR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process - def convert_result_value(self, value, dialect): - if isinstance( value , datetime.datetime ): - return value.time() - else: - return value + def result_processor(self, dialect): + def process(value): + if isinstance( value , datetime.datetime ): + return value.time() + else: + return value + return process class InfoText(sqltypes.String): def get_col_spec(self): @@ -91,36 +97,45 @@ class InfoString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} - def convert_bind_param( self , value , dialect ): - if value == '': - return None - else: - return value - + def bind_processor(self, dialect): + def process(value): + if value == '': + return None + else: + return value + return process + class InfoChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class InfoBinary(sqltypes.Binary): def get_col_spec(self): return "BYTE" + class InfoBoolean(sqltypes.Boolean): default_type = 'NUM' def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer : InfoInteger, diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 9ec0fbbc3..308a38a76 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -47,16 +47,18 @@ from sqlalchemy.engine import default import operator class MSNumeric(sqltypes.Numeric): - def convert_result_value(self, value, dialect): - return value - - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) + def result_processor(self, dialect): + return None + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + def get_col_spec(self): if self.precision is None: return "NUMERIC" @@ -67,12 +69,14 @@ class MSFloat(sqltypes.Float): def get_col_spec(self): return "FLOAT(%(precision)s)" % {'precision': self.precision} - def convert_bind_param(self, value, dialect): - """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None - + def bind_processor(self, dialect): + def process(value): + """By converting to string, we can use Decimal types round-trip.""" + if not value is None: + return str(value) + return None + return process + class MSInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -108,57 +112,71 @@ class MSTime(sqltypes.Time): def get_col_spec(self): return "DATETIME" - def convert_bind_param(self, value, dialect): - if isinstance(value, datetime.datetime): - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif isinstance(value, datetime.time): - value = datetime.datetime.combine(self.__zero_date, value) - return value - - def convert_result_value(self, value, dialect): - if isinstance(value, datetime.datetime): - return value.time() - elif isinstance(value, datetime.date): - return datetime.time(0, 0, 0) - return value - -class MSDateTime_adodbapi(MSDateTime): - def convert_result_value(self, value, dialect): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - return value - -class MSDateTime_pyodbc(MSDateTime): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) return value - -class MSDate_pyodbc(MSDate): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + return process + + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, datetime.date): + return datetime.time(0, 0, 0) return value - - def convert_result_value(self, value, dialect): - # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: + return process + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) return value - + return process + +class MSDateTime_pyodbc(MSDateTime): + def bind_processor(self, dialect): + def process(value): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class MSDate_pyodbc(MSDate): + def bind_processor(self, dialect): + def process(value): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + return process + class MSDate_pymssql(MSDate): - def convert_result_value(self, value, dialect): - # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: - return value - + def result_processor(self, dialect): + def process(value): + # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + return process + class MSText(sqltypes.TEXT): def get_col_spec(self): if self.dialect.text_as_varchar: @@ -181,11 +199,11 @@ class MSNVarchar(sqltypes.Unicode): class AdoMSNVarchar(MSNVarchar): """overrides bindparam/result processing to not convert any unicode strings""" - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class MSChar(sqltypes.CHAR): def get_col_spec(self): @@ -203,20 +221,24 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BIT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process class MSTimeStamp(sqltypes.TIMESTAMP): def get_col_spec(self): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 01d3fa6bc..6d6f32ead 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -294,14 +294,20 @@ class MSNumeric(sqltypes.Numeric, _NumericType): else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, util.decimal_type): - return float(value) + def result_processor(self, dialect): + if not self.asdecimal: + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process else: - return value + return None + class MSDecimal(MSNumeric): @@ -408,8 +414,8 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None class MSInteger(sqltypes.Integer, _NumericType): @@ -539,16 +545,17 @@ class MSBit(sqltypes.TypeEngine): def __init__(self, length=None): self.length = length - def convert_result_value(self, value, dialect): + def result_processor(self, dialect): """Convert a MySQL's 64 bit, variable length binary string to a long.""" - - if value is not None: - v = 0L - for i in map(ord, value): - v = v << 8 | i - value = v - return value - + def process(value): + if value is not None: + v = 0L + for i in map(ord, value): + v = v << 8 | i + value = v + return value + return process + def get_col_spec(self): if self.length is not None: return "BIT(%s)" % self.length @@ -576,13 +583,14 @@ class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - # convert from a timedelta value - if value is not None: - return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) - else: - return None - + def result_processor(self, dialect): + def process(value): + # convert from a timedelta value + if value is not None: + return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) + else: + return None + return process class MSTimeStamp(sqltypes.TIMESTAMP): """MySQL TIMESTAMP type. @@ -930,12 +938,13 @@ class _BinaryType(sqltypes.Binary): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSVarBinary(_BinaryType): """MySQL VARBINARY type, for variable length binary data.""" @@ -976,12 +985,13 @@ class MSBinary(_BinaryType): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSBlob(_BinaryType): """MySQL BLOB type, for binary data up to 2^16 bytes""" @@ -1002,13 +1012,15 @@ class MSBlob(_BinaryType): return "BLOB(%d)" % self.length else: return "BLOB" - - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) - + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process + def __repr__(self): return "%s()" % self.__class__.__name__ @@ -1094,12 +1106,18 @@ class MSEnum(MSString): length = max([len(v) for v in strip_enums]) super(MSEnum, self).__init__(length, **kw) - def convert_bind_param(self, value, engine): - if self.strict and value is not None and value not in self.enums: - raise exceptions.InvalidRequestError('"%s" not a valid value for ' - 'this enum' % value) - return super(MSEnum, self).convert_bind_param(value, engine) - + def bind_processor(self, dialect): + super_convert = super(MSEnum, self).bind_processor(dialect) + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exceptions.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process + def get_col_spec(self): return self._extend("ENUM(%s)" % ",".join(self.__ddl_values)) @@ -1155,36 +1173,44 @@ class MSSet(MSString): length = max([len(v) for v in strip_values] + [0]) super(MSSet, self).__init__(length, **kw) - def convert_result_value(self, value, dialect): - # The good news: - # No ',' quoting issues- commas aren't allowed in SET values - # The bad news: - # Plenty of driver inconsistencies here. - if isinstance(value, util.set_types): - # ..some versions convert '' to an empty set - if not value: - value.add('') - # ..some return sets.Set, even for pythons that have __builtin__.set - if not isinstance(value, util.Set): - value = util.Set(value) - return value - # ...and some versions return strings - if value is not None: - return util.Set(value.split(',')) - else: - return value - - def convert_bind_param(self, value, engine): - if value is None or isinstance(value, (int, long, basestring)): - pass - else: - if None in value: - value = util.Set(value) - value.remove(None) - value.add('') - value = ','.join(value) - return super(MSSet, self).convert_bind_param(value, engine) - + def result_processor(self, dialect): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. + if isinstance(value, util.set_types): + # ..some versions convert '' to an empty set + if not value: + value.add('') + # ..some return sets.Set, even for pythons that have __builtin__.set + if not isinstance(value, util.Set): + value = util.Set(value) + return value + # ...and some versions return strings + if value is not None: + return util.Set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(MSSet, self).bind_processor(dialect) + def process(value): + if value is None or isinstance(value, (int, long, basestring)): + pass + else: + if None in value: + value = util.Set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + def get_col_spec(self): return self._extend("SET(%s)" % ",".join(self.__ddl_values)) @@ -1195,21 +1221,24 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOL" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer: MSInteger, @@ -1284,7 +1313,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): re.I | re.UNICODE) def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if (not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None): self._last_inserted_ids = ([self.cursor.lastrowid] + diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 2c45c94e8..520332d45 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -32,26 +32,31 @@ class OracleSmallInteger(sqltypes.Smallinteger): class OracleDate(sqltypes.Date): def get_col_spec(self): return "DATE" - def convert_bind_param(self, value, dialect): - return value - def convert_result_value(self, value, dialect): - if not isinstance(value, datetime.datetime): - return value - else: - return value.date() + def bind_processor(self, dialect): + return None + def result_processor(self, dialect): + def process(value): + if not isinstance(value, datetime.datetime): + return value + else: + return value.date() + return process + class OracleDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process + # Note: # Oracle DATE == DATETIME # Oracle does not allow milliseconds in DATE @@ -65,14 +70,15 @@ class OracleTimestamp(sqltypes.TIMESTAMP): def get_dbapi_type(self, dialect): return dialect.TIMESTAMP - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process class OracleString(sqltypes.String): def get_col_spec(self): @@ -85,15 +91,23 @@ class OracleText(sqltypes.TEXT): def get_col_spec(self): return "CLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - elif hasattr(value, 'read'): - # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str - return super(OracleText, self).convert_result_value(value.read(), dialect) - else: - return super(OracleText, self).convert_result_value(value, dialect) - + def result_processor(self, dialect): + super_process = super(OracleText, self).result_processor(dialect) + def process(value): + if value is None: + return None + elif hasattr(value, 'read'): + # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str + if super_process: + return super_process(value.read()) + else: + return value.read() + else: + if super_process: + return super_process(value) + else: + return value + return process class OracleRaw(sqltypes.Binary): def get_col_spec(self): @@ -110,34 +124,40 @@ class OracleBinary(sqltypes.Binary): def get_col_spec(self): return "BLOB" - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return value.read() + def bind_processor(self, dialect): + return None + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read() + return process + class OracleBoolean(sqltypes.Boolean): def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + colspecs = { sqltypes.Integer : OracleInteger, sqltypes.Smallinteger : OracleSmallInteger, @@ -196,7 +216,7 @@ class OracleExecutionContext(default.DefaultExecutionContext): if self.compiled_parameters is not None: for k in self.out_parameters: type = self.compiled_parameters.get_type(k) - self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + self.out_parameters[k] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[k].getvalue()) else: for k in self.out_parameters: self.out_parameters[k] = self.out_parameters[k].getvalue() diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a30832b43..e4897bba6 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -22,14 +22,19 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, util.decimal_type): - return float(value) + def result_processor(self, dialect): + if self.asdecimal: + return None else: - return value + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process class PGFloat(sqltypes.Float): def get_col_spec(self): @@ -98,25 +103,38 @@ class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): impl.__dict__.update(self.__dict__) impl.item_type = self.item_type.dialect_impl(dialect) return impl - def convert_bind_param(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, (list,tuple)): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_bind_param(item, dialect) - return [convert_item(item) for item in value] - def convert_result_value(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, list): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_result_value(item, dialect) - # Could specialcase when item_type.convert_result_value is the default identity func - return [convert_item(item) for item in value] + + def bind_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process + + def result_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process def get_col_spec(self): return self.item_type.get_col_spec() + '[]' diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 7999cc403..3cc821a36 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -32,15 +32,17 @@ class SLSmallInteger(sqltypes.Smallinteger): return "SMALLINT" class DateTimeMixin(object): - def convert_bind_param(self, value, dialect): - if value is not None: - if getattr(value, 'microsecond', None) is not None: - return value.strftime(self.__format__ + "." + str(value.microsecond)) + def bind_processor(self, dialect): + def process(value): + if value is not None: + if getattr(value, 'microsecond', None) is not None: + return value.strftime(self.__format__ + "." + str(value.microsecond)) + else: + return value.strftime(self.__format__) else: - return value.strftime(self.__format__) - else: - return None - + return None + return process + def _cvt(self, value, dialect): if value is None: return None @@ -57,30 +59,36 @@ class SLDateTime(DateTimeMixin,sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.datetime(*tup) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.datetime(*tup) + return process + class SLDate(DateTimeMixin, sqltypes.Date): __format__ = "%Y-%m-%d" def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.date(*tup[0:3]) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.date(*tup[0:3]) + return process + class SLTime(DateTimeMixin, sqltypes.Time): __format__ = "%H:%M:%S" def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.time(*tup[3:7]) - + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.time(*tup[3:7]) + return process + class SLText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" @@ -101,16 +109,20 @@ class SLBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" - def convert_bind_param(self, value, dialect): - if value is None: - return None - return value and 1 or 0 - - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + colspecs = { sqltypes.Integer : SLInteger, sqltypes.Smallinteger : SLSmallInteger, @@ -150,7 +162,7 @@ def descriptor(): class SQLiteExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 284be6dfe..8fe34bf3f 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1127,20 +1127,17 @@ class ResultProxy(object): col3 = row[mytable.c.mycol] # access via Column object. ResultProxy also contains a map of TypeEngine objects and will - invoke the appropriate ``convert_result_value()`` method before + invoke the appropriate ``result_processor()`` method before returning columns, as well as the ExecutionContext corresponding to the statement execution. It provides several methods for which to obtain information from the underlying ExecutionContext. """ - class AmbiguousColumn(object): - def __init__(self, key): - self.key = key - def dialect_impl(self, dialect): - return self - def convert_result_value(self, arg, engine): - raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) - + def __ambiguous_processor(self, colname): + def process(value): + raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname) + return process + def __init__(self, context): """ResultProxy objects are constructed via the execute() method on SQLEngine.""" self.context = context @@ -1185,13 +1182,13 @@ class ResultProxy(object): else: type = typemap.get(item[1], types.NULLTYPE) - rec = (type, type.dialect_impl(self.dialect), i) + rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i) if rec[0] is None: raise exceptions.InvalidRequestError( "None for metadata " + colname) if self.__props.setdefault(colname.lower(), rec) is not rec: - self.__props[colname.lower()] = (type, ResultProxy.AmbiguousColumn(colname), 0) + self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0) self.__keys.append(colname) self.__props[i] = rec @@ -1298,7 +1295,10 @@ class ResultProxy(object): def _get_col(self, row, key): rec = self._key_cache[key] - return rec[1].convert_result_value(row[rec[2]], self.dialect) + if rec[1]: + return rec[1](row[rec[2]]) + else: + return row[rec[2]] def _fetchone_impl(self): return self.cursor.fetchone() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 02802452a..ccaf080e7 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -163,12 +163,17 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = unicode(compiled) if parameters is None: self.compiled_parameters = compiled.construct_params({}) + self.executemany = False elif not isinstance(parameters, (list, tuple)): self.compiled_parameters = compiled.construct_params(parameters) + self.executemany = False else: self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters] if len(self.compiled_parameters) == 1: self.compiled_parameters = self.compiled_parameters[0] + self.executemany = False + else: + self.executemany = True elif statement is not None: self.typemap = self.column_labels = None self.parameters = self.__encode_param_keys(parameters) @@ -206,22 +211,26 @@ class DefaultExecutionContext(base.ExecutionContext): return proc(params) def __convert_compiled_params(self, parameters): - executemany = parameters is not None and isinstance(parameters, list) encode = not self.dialect.supports_unicode_statements() # the bind params are a CompiledParams object. but all the DBAPI's hate # that object (or similar). so convert it to a clean # dictionary/list/tuple of dictionary/tuple of list if parameters is not None: - if self.dialect.positional: - if executemany: - parameters = [p.get_raw_list() for p in parameters] + if self.executemany: + processors = parameters[0].get_processors() + else: + processors = parameters.get_processors() + + if self.dialect.positional: + if self.executemany: + parameters = [p.get_raw_list(processors) for p in parameters] else: - parameters = parameters.get_raw_list() - else: - if executemany: - parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters] + parameters = parameters.get_raw_list(processors) + else: + if self.executemany: + parameters = [p.get_raw_dict(processors, encode_keys=encode) for p in parameters] else: - parameters = parameters.get_raw_dict(encode_keys=encode) + parameters = parameters.get_raw_dict(processors, encode_keys=encode) return parameters def is_select(self): @@ -311,28 +320,31 @@ class DefaultExecutionContext(base.ExecutionContext): """generate default values for compiled insert/update statements, and generate last_inserted_ids() collection.""" - # TODO: cleanup if self.isinsert: - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters - else: - plist = [self.compiled_parameters] drunner = self.dialect.defaultrunner(self) - for param in plist: + if self.executemany: + # executemany doesn't populate last_inserted_ids() + firstparam = self.compiled_parameters[0] + processors = firstparam.get_processors() + for c in self.compiled.statement.table.c: + if c.default is not None: + params = self.compiled_parameters + for param in params: + if not c.key in param or param.get_original(c.key) is None: + self.compiled_parameters = param + newid = drunner.get_column_default(c) + if newid is not None: + param.set_value(c.key, newid) + self.compiled_parameters = params + else: + param = self.compiled_parameters + processors = param.get_processors() last_inserted_ids = [] - # check the "default" status of each column in the table for c in self.compiled.statement.table.c: - # check if it will be populated by a SQL clause - we'll need that - # after execution. if c in self.compiled.inline_params: self._postfetch_cols.add(c) if c.primary_key: last_inserted_ids.append(None) - # check if its not present at all. see if theres a default - # and fire it off, and add to bind parameters. if - # its a pk, add the value to our last_inserted_ids list, - # or, if its a SQL-side default, let it fire off on the DB side, but we'll need - # the SQL-generated value after execution. elif not c.key in param or param.get_original(c.key) is None: if isinstance(c.default, schema.PassiveDefault): self._postfetch_cols.add(c) @@ -340,32 +352,33 @@ class DefaultExecutionContext(base.ExecutionContext): if newid is not None: param.set_value(c.key, newid) if c.primary_key: - last_inserted_ids.append(param.get_processed(c.key)) + last_inserted_ids.append(param.get_processed(c.key, processors)) elif c.primary_key: last_inserted_ids.append(None) - # its an explicitly passed pk value - add it to - # our last_inserted_ids list. elif c.primary_key: - last_inserted_ids.append(param.get_processed(c.key)) - # TODO: we arent accounting for executemany() situations - # here (hard to do since lastrowid doesnt support it either) + last_inserted_ids.append(param.get_processed(c.key, processors)) self._last_inserted_ids = last_inserted_ids self._last_inserted_params = param + + elif self.isupdate: - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters - else: - plist = [self.compiled_parameters] drunner = self.dialect.defaultrunner(self) - for param in plist: - # check the "onupdate" status of each column in the table + if self.executemany: + for c in self.compiled.statement.table.c: + if c.onupdate is not None: + params = self.compiled_parameters + for param in params: + if not c.key in param or param.get_original(c.key) is None: + self.compiled_parameters = param + value = drunner.get_column_onupdate(c) + if value is not None: + param.set_value(c.key, value) + self.compiled_parameters = params + else: + param = self.compiled_parameters for c in self.compiled.statement.table.c: - # it will be populated by a SQL clause - we'll need that - # after execution. if c in self.compiled.inline_params: self._postfetch_cols.add(c) - # its not in the bind parameters, and theres an "onupdate" defined for the column; - # execute it and add to bind params elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None): value = drunner.get_column_onupdate(c) if value is not None: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 3fc13a50d..994a877bd 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -812,7 +812,6 @@ class ClauseParameters(object): """ def __init__(self, dialect, positional=None): - super(ClauseParameters, self).__init__() self.dialect = dialect self.__binds = {} self.positional = positional or [] @@ -829,19 +828,31 @@ class ClauseParameters(object): def get_type(self, key): return self.__binds[key][0].type - def get_processed(self, key): - (bind, name, value) = self.__binds[key] - return bind.typeprocess(value, self.dialect) - + def get_processors(self): + """return a dictionary of bind 'processing' functions""" + return dict([ + (key, value) for key, value in + [( + key, + self.__binds[key][0].bind_processor(self.dialect) + ) for key in self.__binds] + if value is not None + ]) + + def get_processed(self, key, processors): + return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2] + def keys(self): return self.__binds.keys() def __iter__(self): return iter(self.keys()) - - def __getitem__(self, key): - return self.get_processed(key) + def __getitem__(self, key): + (bind, name, value) = self.__binds[key] + processor = bind.bind_processor(self.dialect) + return processor is not None and processor(value) or value + def __contains__(self, key): return key in self.__binds @@ -851,14 +862,36 @@ class ClauseParameters(object): def get_original_dict(self): return dict([(name, value) for (b, name, value) in self.__binds.values()]) - def get_raw_list(self): - return [self.get_processed(key) for key in self.positional] + def get_raw_list(self, processors): +# (bind, name, value) = self.__binds[key] + return [ + (key in processors) and + processors[key](self.__binds[key][2]) or + self.__binds[key][2] + for key in self.positional + ] - def get_raw_dict(self, encode_keys=False): + def get_raw_dict(self, processors, encode_keys=False): if encode_keys: - return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()]) + return dict([ + ( + key.encode(self.dialect.encoding), + (key in processors) and + processors[key](self.__binds[key][2]) or + self.__binds[key][2] + ) + for key in self.keys() + ]) else: - return dict([(key, self.get_processed(key)) for key in self.keys()]) + return dict([ + ( + key, + (key in processors) and + processors[key](self.__binds[key][2]) or + self.__binds[key][2] + ) + for key in self.keys() + ]) def __repr__(self): return self.__class__.__name__ + ":" + repr(self.get_original_dict()) @@ -1995,8 +2028,8 @@ class _BindParamClause(ClauseElement, _CompareMixin): def _get_from_objects(self, **modifiers): return [] - def typeprocess(self, value, dialect): - return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) + def bind_processor(self, dialect): + return self.type.dialect_impl(dialect).bind_processor(dialect) def _compare_type(self, obj): if not isinstance(self.type, sqltypes.NullType): diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index fe05910df..f3854e3e1 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -59,12 +59,13 @@ class TypeEngine(AbstractType): def get_col_spec(self): raise NotImplementedError() - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + return None + def adapt(self, cls): return cls() @@ -115,11 +116,11 @@ class TypeDecorator(AbstractType): def get_col_spec(self): return self.impl.get_col_spec() - def convert_bind_param(self, value, dialect): - return self.impl.convert_bind_param(value, dialect) + def bind_processor(self, dialect): + return self.impl.bind_processor(dialect) - def convert_result_value(self, value, dialect): - return self.impl.convert_result_value(value, dialect) + def result_processor(self, dialect): + return self.impl.result_processor(dialect) def copy(self): instance = self.__class__.__new__(self.__class__) @@ -183,11 +184,6 @@ class NullType(TypeEngine): def get_col_spec(self): raise NotImplementedError() - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value NullTypeEngine = NullType class Concatenable(object): @@ -202,11 +198,27 @@ class String(TypeEngine, Concatenable): def adapt(self, impltype): return impltype(length=self.length, convert_unicode=self.convert_unicode) - def convert_bind_param(self, value, dialect): - if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode): - return value + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process else: - return value.encode(dialect.encoding) + return None + + def result_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + def process(value): + if value is not None and not isinstance(value, unicode): + return value.decode(dialect.encoding) + else: + return value + return process + else: + return None def get_search_list(self): l = super(String, self).get_search_list() @@ -215,11 +227,6 @@ class String(TypeEngine, Concatenable): else: return l - def convert_result_value(self, value, dialect): - if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode): - return value - else: - return value.decode(dialect.encoding) def get_dbapi_type(self, dbapi): return dbapi.STRING @@ -254,17 +261,24 @@ class Numeric(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.NUMBER - def convert_bind_param(self, value, dialect): - if value is not None: - return float(value) - else: - return value - - def convert_result_value(self, value, dialect): - if value is not None and self.asdecimal: - return Decimal(str(value)) + def bind_processor(self, dialect): + def process(value): + if value is not None: + return float(value) + else: + return value + return process + + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return Decimal(str(value)) + else: + return value + return process else: - return value + return None class Float(Numeric): def __init__(self, precision = 10, asdecimal=False, **kwargs): @@ -308,15 +322,14 @@ class Binary(TypeEngine): def __init__(self, length=None): self.length = length - def convert_bind_param(self, value, dialect): - if value is not None: - return dialect.dbapi.Binary(value) - else: - return None - - def convert_result_value(self, value, dialect): - return value - + def bind_processor(self, dialect): + def process(value): + if value is not None: + return dialect.dbapi.Binary(value) + else: + return None + return process + def adapt(self, impltype): return impltype(length=self.length) @@ -332,17 +345,27 @@ class PickleType(MutableType, TypeDecorator): self.mutable = mutable super(PickleType, self).__init__() - def convert_result_value(self, value, dialect): - if value is None: - return None - buf = self.impl.convert_result_value(value, dialect) - return self.pickler.loads(str(buf)) - - def convert_bind_param(self, value, dialect): - if value is None: - return None - return self.impl.convert_bind_param(self.pickler.dumps(value, self.protocol), dialect) - + def bind_processor(self, dialect): + impl_process = self.impl.bind_processor(dialect) + def process(value): + if value is None: + return None + if impl_process is None: + return self.pickler.dumps(value, self.protocol) + else: + return impl_process(self.pickler.dumps(value, self.protocol)) + return process + + def result_processor(self, dialect): + impl_process = self.impl.result_processor(dialect) + def process(value): + if value is None: + return None + if impl_process is not None: + value = impl_process(value) + return self.pickler.loads(str(value)) + return process + def copy_value(self, value): if self.mutable: return self.pickler.loads(self.pickler.dumps(value, self.protocol)) @@ -370,8 +393,8 @@ class Interval(TypeDecorator): Converting is very simple - just use epoch(zero timestamp, 01.01.1970) as base, so if we need to store timedelta = 1 day (24 hours) in database it - will be stored as DateTime = '2nd Jan 1970 00:00', see convert_bind_param - and convert_result_value to actual conversion code + will be stored as DateTime = '2nd Jan 1970 00:00', see bind_processor + and result_processor to actual conversion code """ #Empty useless type, because at the moment of creation of instance we don't #know what type will be decorated - it depends on used dialect. @@ -396,25 +419,35 @@ class Interval(TypeDecorator): def __hasNativeImpl(self,dialect): return dialect.__class__ in self.__supported - - def convert_bind_param(self, value, dialect): - if value is None: - return None - if not self.__hasNativeImpl(dialect): - tmpval = dt.datetime.utcfromtimestamp(0) + value - return self.impl.convert_bind_param(tmpval,dialect) + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + if self.__hasNativeImpl(dialect): + return impl_processor else: - return self.impl.convert_bind_param(value,dialect) - - def convert_result_value(self, value, dialect): - if value is None: - return None - retval = self.impl.convert_result_value(value,dialect) - if not self.__hasNativeImpl(dialect): - return retval - dt.datetime.utcfromtimestamp(0) + def process(value): + if value is None: + return None + tmpval = dt.datetime.utcfromtimestamp(0) + value + if impl_processor is not None: + return impl_processor(tmpval) + else: + return tmpval + return process + + def result_processor(self, dialect): + impl_processor = self.impl.result_processor(dialect) + if self.__hasNativeImpl(dialect): + return impl_processor else: - return retval - + def process(value): + if value is None: + return None + if impl_processor is not None: + value = impl_processor(value) + return value - dt.datetime.utcfromtimestamp(0) + return process + class FLOAT(Float):pass class TEXT(String):pass class DECIMAL(Numeric):pass |