summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/mysql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
commited4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
treec1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/databases/mysql.py
parent3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
downloadsqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/databases/mysql.py')
-rw-r--r--lib/sqlalchemy/databases/mysql.py107
1 files changed, 75 insertions, 32 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index bac0e5e12..26800e32b 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, datetime, inspect, warnings, weakref
+import re, datetime, inspect, warnings, weakref, operator
from sqlalchemy import sql, schema, ansisql
from sqlalchemy.engine import default
@@ -12,13 +12,13 @@ import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
import sqlalchemy.util as util
from array import array as _array
+from decimal import Decimal
try:
from threading import Lock
except ImportError:
from dummy_threading import Lock
-
RESERVED_WORDS = util.Set(
['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
@@ -60,7 +60,6 @@ RESERVED_WORDS = util.Set(
'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
'read_only', 'read_write', # 5.1
])
-
_per_connection_mutex = Lock()
class _NumericType(object):
@@ -137,7 +136,7 @@ class _StringType(object):
class MSNumeric(sqltypes.Numeric, _NumericType):
"""MySQL NUMERIC type"""
- def __init__(self, precision = 10, length = 2, **kw):
+ def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
"""Construct a NUMERIC.
precision
@@ -157,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
"""
_NumericType.__init__(self, **kw)
- sqltypes.Numeric.__init__(self, precision, length)
-
+ sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal)
+
def get_col_spec(self):
if self.precision is None:
return self._extend("NUMERIC")
else:
return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+ def convert_bind_param(self, value, dialect):
+ return value
+
+ def convert_result_value(self, value, dialect):
+ if not self.asdecimal and isinstance(value, Decimal):
+ return float(value)
+ else:
+ return value
+
class MSDecimal(MSNumeric):
"""MySQL DECIMAL type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DECIMAL.
precision
@@ -187,7 +195,7 @@ class MSDecimal(MSNumeric):
underlying database API, which continue to be numeric.
"""
- super(MSDecimal, self).__init__(precision, length, **kw)
+ super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is None:
@@ -200,7 +208,7 @@ class MSDecimal(MSNumeric):
class MSDouble(MSNumeric):
"""MySQL DOUBLE type"""
- def __init__(self, precision=10, length=2, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a DOUBLE.
precision
@@ -222,7 +230,7 @@ class MSDouble(MSNumeric):
if ((precision is None and length is not None) or
(precision is not None and length is None)):
raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
- super(MSDouble, self).__init__(precision, length, **kw)
+ super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw)
def get_col_spec(self):
if self.precision is not None and self.length is not None:
@@ -235,7 +243,7 @@ class MSDouble(MSNumeric):
class MSFloat(sqltypes.Float, _NumericType):
"""MySQL FLOAT type"""
- def __init__(self, precision=10, length=None, **kw):
+ def __init__(self, precision=10, length=None, asdecimal=False, **kw):
"""Construct a FLOAT.
precision
@@ -257,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType):
if length is not None:
self.length=length
_NumericType.__init__(self, **kw)
- sqltypes.Float.__init__(self, precision)
+ sqltypes.Float.__init__(self, precision, asdecimal=asdecimal)
def get_col_spec(self):
if hasattr(self, 'length') and self.length is not None:
@@ -267,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType):
else:
return self._extend("FLOAT")
+ def convert_bind_param(self, value, dialect):
+ return value
+
+
class MSInteger(sqltypes.Integer, _NumericType):
"""MySQL INTEGER type"""
@@ -955,7 +967,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
if self.compiled.isinsert:
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
+
+ def is_select(self):
+ return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
+
class MySQLDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
@@ -1044,6 +1059,27 @@ class MySQLDialect(ansisql.ANSIDialect):
except:
pass
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("XA RECOVER"))
+ return [row['data'][0:row['gtrid_length']] for row in resultset]
+
def is_disconnect(self, e):
return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
@@ -1088,7 +1124,7 @@ class MySQLDialect(ansisql.ANSIDialect):
version.append(n)
return tuple(version)
- def reflecttable(self, connection, table):
+ def reflecttable(self, connection, table, include_columns):
"""Load column definitions from the server."""
decode_from = self._detect_charset(connection)
@@ -1111,6 +1147,9 @@ class MySQLDialect(ansisql.ANSIDialect):
# leave column names as unicode
name = name.decode(decode_from)
+
+ if include_columns and name not in include_columns:
+ continue
match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
col_type = match.group(1)
@@ -1118,7 +1157,11 @@ class MySQLDialect(ansisql.ANSIDialect):
extra_1 = match.group(3)
extra_2 = match.group(4)
- coltype = ischema_names.get(col_type, MSString)
+ try:
+ coltype = ischema_names[col_type]
+ except KeyError:
+ warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name)))
+ coltype = sqltypes.NULLTYPE
kw = {}
if extra_1 is not None:
@@ -1156,7 +1199,6 @@ class MySQLDialect(ansisql.ANSIDialect):
if not row:
raise exceptions.NoSuchTableError(table.fullname)
desc = row[1].strip()
- row.close()
tabletype = ''
lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1223,7 +1265,6 @@ class MySQLDialect(ansisql.ANSIDialect):
cs = True
else:
cs = row[1] in ('0', 'OFF' 'off')
- row.close()
cache['lower_case_table_names'] = cs
self.per_connection[raw_connection] = cache
return cache.get('lower_case_table_names')
@@ -1266,14 +1307,21 @@ class _MySQLPythonRowProxy(object):
class MySQLCompiler(ansisql.ANSICompiler):
- def visit_cast(self, cast):
-
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
+ operator.mod : '%%'
+ }
+ )
+
+ def visit_cast(self, cast, **kwargs):
if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
- return super(MySQLCompiler, self).visit_cast(cast)
+ return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
else:
# so just skip the CAST altogether for now.
# TODO: put whatever MySQL does for CAST here.
- self.strings[cast] = self.strings[cast.clause]
+ return self.process(cast.clause)
def for_update_clause(self, select):
if select.for_update == 'read':
@@ -1283,20 +1331,15 @@ class MySQLCompiler(ansisql.ANSICompiler):
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
- # striaght from the MySQL docs, I kid you not
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ # straight from the MySQL docs, I kid you not
text += " \n LIMIT 18446744073709551615"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
- def binary_operator_string(self, binary):
- if binary.operator == '%':
- return '%%'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):