From ed4fc64bb0ac61c27bc4af32962fb129e74a36bf Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 27 Jul 2007 04:08:53 +0000 Subject: merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3. --- lib/sqlalchemy/databases/mysql.py | 107 ++++++++++++++++++++++++++------------ 1 file changed, 75 insertions(+), 32 deletions(-) (limited to 'lib/sqlalchemy/databases/mysql.py') diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index bac0e5e12..26800e32b 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, datetime, inspect, warnings, weakref +import re, datetime, inspect, warnings, weakref, operator from sqlalchemy import sql, schema, ansisql from sqlalchemy.engine import default @@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import sqlalchemy.util as util from array import array as _array +from decimal import Decimal try: from threading import Lock except ImportError: from dummy_threading import Lock - RESERVED_WORDS = util.Set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', @@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set( 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', 'read_only', 'read_write', # 5.1 ]) - _per_connection_mutex = Lock() class _NumericType(object): @@ -137,7 +136,7 @@ class _StringType(object): class MSNumeric(sqltypes.Numeric, _NumericType): """MySQL NUMERIC type""" - def __init__(self, precision = 10, length = 2, **kw): + def __init__(self, precision = 10, length = 2, asdecimal=True, **kw): """Construct a NUMERIC. precision @@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType): """ _NumericType.__init__(self, **kw) - sqltypes.Numeric.__init__(self, precision, length) - + sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal) + def get_col_spec(self): if self.precision is None: return self._extend("NUMERIC") else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class MSDecimal(MSNumeric): """MySQL DECIMAL type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DECIMAL. precision @@ -187,7 +195,7 @@ class MSDecimal(MSNumeric): underlying database API, which continue to be numeric. """ - super(MSDecimal, self).__init__(precision, length, **kw) + super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is None: @@ -200,7 +208,7 @@ class MSDecimal(MSNumeric): class MSDouble(MSNumeric): """MySQL DOUBLE type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DOUBLE. precision @@ -222,7 +230,7 @@ class MSDouble(MSNumeric): if ((precision is None and length is not None) or (precision is not None and length is None)): raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") - super(MSDouble, self).__init__(precision, length, **kw) + super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -235,7 +243,7 @@ class MSDouble(MSNumeric): class MSFloat(sqltypes.Float, _NumericType): """MySQL FLOAT type""" - def __init__(self, precision=10, length=None, **kw): + def __init__(self, precision=10, length=None, asdecimal=False, **kw): """Construct a FLOAT. precision @@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType): if length is not None: self.length=length _NumericType.__init__(self, **kw) - sqltypes.Float.__init__(self, precision) + sqltypes.Float.__init__(self, precision, asdecimal=asdecimal) def get_col_spec(self): if hasattr(self, 'length') and self.length is not None: @@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") + def convert_bind_param(self, value, dialect): + return value + + class MSInteger(sqltypes.Integer, _NumericType): """MySQL INTEGER type""" @@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - + + def is_select(self): + return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None + class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) @@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect): except: pass + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("XA RECOVER")) + return [row['data'][0:row['gtrid_length']] for row in resultset] + def is_disconnect(self, e): return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) @@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect): version.append(n) return tuple(version) - def reflecttable(self, connection, table): + def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" decode_from = self._detect_charset(connection) @@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect): # leave column names as unicode name = name.decode(decode_from) + + if include_columns and name not in include_columns: + continue match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type) col_type = match.group(1) @@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect): extra_1 = match.group(3) extra_2 = match.group(4) - coltype = ischema_names.get(col_type, MSString) + try: + coltype = ischema_names[col_type] + except KeyError: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name))) + coltype = sqltypes.NULLTYPE kw = {} if extra_1 is not None: @@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect): if not row: raise exceptions.NoSuchTableError(table.fullname) desc = row[1].strip() - row.close() tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) @@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect): cs = True else: cs = row[1] in ('0', 'OFF' 'off') - row.close() cache['lower_case_table_names'] = cs self.per_connection[raw_connection] = cache return cache.get('lower_case_table_names') @@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object): class MySQLCompiler(ansisql.ANSICompiler): - def visit_cast(self, cast): - + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y), + operator.mod : '%%' + } + ) + + def visit_cast(self, cast, **kwargs): if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): - return super(MySQLCompiler, self).visit_cast(cast) + return super(MySQLCompiler, self).visit_cast(cast, **kwargs) else: # so just skip the CAST altogether for now. # TODO: put whatever MySQL does for CAST here. - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def for_update_clause(self, select): if select.for_update == 'read': @@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: - # striaght from the MySQL docs, I kid you not + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + # straight from the MySQL docs, I kid you not text += " \n LIMIT 18446744073709551615" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def binary_operator_string(self, binary): - if binary.operator == '%': - return '%%' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): -- cgit v1.2.1