summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/mysql.py
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2007-07-29 16:13:23 +0000
committerJason Kirtland <jek@discorporate.us>2007-07-29 16:13:23 +0000
commit6a1d279a9a702a4c520446f52edeb1ee4864dd70 (patch)
tree4c69f76299ee30f3b715f7c43fcfc880d51803cd /lib/sqlalchemy/databases/mysql.py
parentc9cc90bbdc24f8d4d0429468404cd43de46fc07f (diff)
downloadsqlalchemy-6a1d279a9a702a4c520446f52edeb1ee4864dd70.tar.gz
Big MySQL dialect update, mostly efficiency and style.
Added TINYINT [ticket:691]- whoa, how did that one go missing for so long? Added a charset-fixing pool listener. The driver-level option doesn't help everyone with this one. New reflector code not quite done and omiited from this commit.
Diffstat (limited to 'lib/sqlalchemy/databases/mysql.py')
-rw-r--r--lib/sqlalchemy/databases/mysql.py426
1 files changed, 309 insertions, 117 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 53ef1a95b..2c54c2512 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -5,25 +5,21 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import re, datetime, inspect, warnings, weakref, operator
+from array import array as _array
+from decimal import Decimal
from sqlalchemy import sql, schema, ansisql
from sqlalchemy.engine import default
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
import sqlalchemy.util as util
-from array import array as _array
-from decimal import Decimal
-try:
- from threading import Lock
-except ImportError:
- from dummy_threading import Lock
RESERVED_WORDS = util.Set(
['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check',
- 'collate', 'column', 'condition', 'constraint', 'continue', 'convert',
+ 'collate', 'column', 'con dition', 'constraint', 'continue', 'convert',
'create', 'cross', 'current_date', 'current_time', 'current_timestamp',
'current_user', 'cursor', 'database', 'databases', 'day_hour',
'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal',
@@ -60,7 +56,7 @@ RESERVED_WORDS = util.Set(
'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
'read_only', 'read_write', # 5.1
])
-_per_connection_mutex = Lock()
+
class _NumericType(object):
"Base for MySQL numeric types."
@@ -78,6 +74,7 @@ class _NumericType(object):
spec += ' ZEROFILL'
return spec
+
class _StringType(object):
"Base for MySQL string types."
@@ -133,10 +130,11 @@ class _StringType(object):
return "%s(%s)" % (self.__class__.__name__,
','.join(['%s=%s' % (k, params[k]) for k in params]))
+
class MSNumeric(sqltypes.Numeric, _NumericType):
"""MySQL NUMERIC type"""
- def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
+ def __init__(self, precision=10, length=2, asdecimal=True, **kw):
"""Construct a NUMERIC.
precision
@@ -173,6 +171,7 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
else:
return value
+
class MSDecimal(MSNumeric):
"""MySQL DECIMAL type"""
@@ -205,6 +204,7 @@ class MSDecimal(MSNumeric):
else:
return self._extend("DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+
class MSDouble(MSNumeric):
"""MySQL DOUBLE type"""
@@ -240,6 +240,7 @@ class MSDouble(MSNumeric):
else:
return self._extend('DOUBLE')
+
class MSFloat(sqltypes.Float, _NumericType):
"""MySQL FLOAT type"""
@@ -307,6 +308,7 @@ class MSInteger(sqltypes.Integer, _NumericType):
else:
return self._extend("INTEGER")
+
class MSBigInteger(MSInteger):
"""MySQL BIGINTEGER type"""
@@ -333,6 +335,34 @@ class MSBigInteger(MSInteger):
else:
return self._extend("BIGINT")
+
+class MSTinyInteger(MSInteger):
+ """MySQL TINYINT type"""
+
+ def __init__(self, length=None, **kw):
+ """Construct a SMALLINTEGER.
+
+ length
+ Optional, maximum display width for this number.
+
+ unsigned
+ Optional.
+
+ zerofill
+ Optional. If true, values will be stored as strings left-padded with
+ zeros. Note that this does not effect the values returned by the
+ underlying database API, which continue to be numeric.
+ """
+
+ super(MSTinyInteger, self).__init__(length, **kw)
+
+ def get_col_spec(self):
+ if self.length is not None:
+ return self._extend("TINYINT(%s)" % self.length)
+ else:
+ return self._extend("TINYINT")
+
+
class MSSmallInteger(sqltypes.Smallinteger, _NumericType):
"""MySQL SMALLINTEGER type"""
@@ -361,18 +391,21 @@ class MSSmallInteger(sqltypes.Smallinteger, _NumericType):
else:
return self._extend("SMALLINT")
+
class MSDateTime(sqltypes.DateTime):
"""MySQL DATETIME type"""
def get_col_spec(self):
return "DATETIME"
+
class MSDate(sqltypes.Date):
"""MySQL DATE type"""
def get_col_spec(self):
return "DATE"
+
class MSTime(sqltypes.Time):
"""MySQL TIME type"""
@@ -386,6 +419,7 @@ class MSTime(sqltypes.Time):
else:
return None
+
class MSTimeStamp(sqltypes.TIMESTAMP):
"""MySQL TIMESTAMP type
@@ -399,6 +433,7 @@ class MSTimeStamp(sqltypes.TIMESTAMP):
def get_col_spec(self):
return "TIMESTAMP"
+
class MSYear(sqltypes.String):
"""MySQL YEAR type, for single byte storage of years 1901-2155"""
@@ -408,6 +443,7 @@ class MSYear(sqltypes.String):
else:
return "YEAR(%d)" % self.length
+
class MSText(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for text up to 2^16 characters"""
@@ -495,6 +531,7 @@ class MSTinyText(MSText):
def get_col_spec(self):
return self._extend("TINYTEXT")
+
class MSMediumText(MSText):
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters"""
@@ -533,6 +570,7 @@ class MSMediumText(MSText):
def get_col_spec(self):
return self._extend("MEDIUMTEXT")
+
class MSLongText(MSText):
"""MySQL LONGTEXT type, for text up to 2^32 characters"""
@@ -571,6 +609,7 @@ class MSLongText(MSText):
def get_col_spec(self):
return self._extend("LONGTEXT")
+
class MSString(_StringType, sqltypes.String):
"""MySQL VARCHAR type, for variable-length character data."""
@@ -617,6 +656,7 @@ class MSString(_StringType, sqltypes.String):
else:
return self._extend("TEXT")
+
class MSChar(_StringType, sqltypes.CHAR):
"""MySQL CHAR type, for fixed-length character data."""
@@ -642,6 +682,7 @@ class MSChar(_StringType, sqltypes.CHAR):
def get_col_spec(self):
return self._extend("CHAR(%(length)s)" % {'length' : self.length})
+
class MSNVarChar(_StringType, sqltypes.String):
"""MySQL NVARCHAR type, for variable-length character data in the
server's configured national character set.
@@ -673,6 +714,7 @@ class MSNVarChar(_StringType, sqltypes.String):
# of "NVARCHAR".
return self._extend("VARCHAR(%(length)s)" % {'length': self.length})
+
class MSNChar(_StringType, sqltypes.CHAR):
"""MySQL NCHAR type, for fixed-length character data in the
server's configured national character set.
@@ -701,6 +743,7 @@ class MSNChar(_StringType, sqltypes.CHAR):
# We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
return self._extend("CHAR(%(length)s)" % {'length': self.length})
+
class _BinaryType(sqltypes.Binary):
"""MySQL binary types"""
@@ -716,6 +759,7 @@ class _BinaryType(sqltypes.Binary):
else:
return buffer(value)
+
class MSVarBinary(_BinaryType):
"""MySQL VARBINARY type, for variable length binary data"""
@@ -733,6 +777,7 @@ class MSVarBinary(_BinaryType):
else:
return "BLOB"
+
class MSBinary(_BinaryType):
"""MySQL BINARY type, for fixed length binary data"""
@@ -760,10 +805,10 @@ class MSBinary(_BinaryType):
else:
return buffer(value)
+
class MSBlob(_BinaryType):
"""MySQL BLOB type, for binary data up to 2^16 bytes"""
-
def __init__(self, length=None, **kw):
"""Construct a BLOB. Arguments are:
@@ -790,24 +835,28 @@ class MSBlob(_BinaryType):
def __repr__(self):
return "%s()" % self.__class__.__name__
+
class MSTinyBlob(MSBlob):
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes"""
def get_col_spec(self):
return "TINYBLOB"
+
class MSMediumBlob(MSBlob):
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes"""
def get_col_spec(self):
return "MEDIUMBLOB"
+
class MSLongBlob(MSBlob):
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes"""
def get_col_spec(self):
return "LONGBLOB"
+
class MSEnum(MSString):
"""MySQL ENUM type."""
@@ -877,6 +926,7 @@ class MSEnum(MSString):
def get_col_spec(self):
return self._extend("ENUM(%s)" % ",".join(self.__ddl_values))
+
class MSBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOL"
@@ -899,17 +949,17 @@ class MSBoolean(sqltypes.Boolean):
# TODO: SET, BIT
colspecs = {
- sqltypes.Integer : MSInteger,
- sqltypes.Smallinteger : MSSmallInteger,
- sqltypes.Numeric : MSNumeric,
- sqltypes.Float : MSFloat,
- sqltypes.DateTime : MSDateTime,
- sqltypes.Date : MSDate,
- sqltypes.Time : MSTime,
- sqltypes.String : MSString,
- sqltypes.Binary : MSBlob,
- sqltypes.Boolean : MSBoolean,
- sqltypes.TEXT : MSText,
+ sqltypes.Integer: MSInteger,
+ sqltypes.Smallinteger: MSSmallInteger,
+ sqltypes.Numeric: MSNumeric,
+ sqltypes.Float: MSFloat,
+ sqltypes.DateTime: MSDateTime,
+ sqltypes.Date: MSDate,
+ sqltypes.Time: MSTime,
+ sqltypes.String: MSString,
+ sqltypes.Binary: MSBlob,
+ sqltypes.Boolean: MSBoolean,
+ sqltypes.TEXT: MSText,
sqltypes.CHAR: MSChar,
sqltypes.NCHAR: MSNChar,
sqltypes.TIMESTAMP: MSTimeStamp,
@@ -919,37 +969,37 @@ colspecs = {
ischema_names = {
- 'bigint' : MSBigInteger,
- 'binary' : MSBinary,
- 'blob' : MSBlob,
+ 'bigint': MSBigInteger,
+ 'binary': MSBinary,
+ 'blob': MSBlob,
'boolean':MSBoolean,
- 'char' : MSChar,
- 'date' : MSDate,
- 'datetime' : MSDateTime,
- 'decimal' : MSDecimal,
- 'double' : MSDouble,
+ 'char': MSChar,
+ 'date': MSDate,
+ 'datetime': MSDateTime,
+ 'decimal': MSDecimal,
+ 'double': MSDouble,
'enum': MSEnum,
'fixed': MSDecimal,
- 'float' : MSFloat,
- 'int' : MSInteger,
- 'integer' : MSInteger,
+ 'float': MSFloat,
+ 'int': MSInteger,
+ 'integer': MSInteger,
'longblob': MSLongBlob,
'longtext': MSLongText,
'mediumblob': MSMediumBlob,
- 'mediumint' : MSInteger,
+ 'mediumint': MSInteger,
'mediumtext': MSMediumText,
'nchar': MSNChar,
'nvarchar': MSNVarChar,
- 'numeric' : MSNumeric,
- 'smallint' : MSSmallInteger,
- 'text' : MSText,
- 'time' : MSTime,
- 'timestamp' : MSTimeStamp,
+ 'numeric': MSNumeric,
+ 'smallint': MSSmallInteger,
+ 'text': MSText,
+ 'time': MSTime,
+ 'timestamp': MSTimeStamp,
'tinyblob': MSTinyBlob,
- 'tinyint' : MSSmallInteger,
- 'tinytext' : MSTinyText,
- 'varbinary' : MSVarBinary,
- 'varchar' : MSString,
+ 'tinyint': MSTinyInteger,
+ 'tinytext': MSTinyText,
+ 'varbinary': MSVarBinary,
+ 'varchar': MSString,
}
def descriptor():
@@ -962,19 +1012,26 @@ def descriptor():
('host',"Hostname", None),
]}
+
class MySQLExecutionContext(default.DefaultExecutionContext):
+ _my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)',
+ re.I | re.UNICODE)
+
def post_exec(self):
if self.compiled.isinsert:
- if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
- self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+ if (not len(self._last_inserted_ids) or
+ self._last_inserted_ids[0] is None):
+ self._last_inserted_ids = ([self.cursor.lastrowid] +
+ self._last_inserted_ids[1:])
def is_select(self):
- return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
+ return self._my_is_select.match(self.statement) is not None
+
class MySQLDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
- ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
- self.per_connection = weakref.WeakKeyDictionary()
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='format',
+ **kwargs)
def dbapi(cls):
import MySQLdb as mysql
@@ -989,12 +1046,15 @@ class MySQLDialect(ansisql.ANSIDialect):
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'client_flag', int)
util.coerce_kw_type(opts, 'local_infile', int)
- # note: these two could break SA Unicode type
+ # Note: using either of the below will cause all strings to be returned
+ # as Unicode, both in raw SQL operations and with column types like
+ # String and MSString.
util.coerce_kw_type(opts, 'use_unicode', bool)
util.coerce_kw_type(opts, 'charset', str)
- # TODO: cursorclass and conv: support via query string or punt?
+
+ # Rich values 'cursorclass' and 'conv' are not supported via
+ # query string.
- # ssl
ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
if key in opts:
@@ -1004,7 +1064,7 @@ class MySQLDialect(ansisql.ANSIDialect):
if len(ssl):
opts['ssl'] = ssl
- # FOUND_ROWS must be set in CLIENT_FLAGS for to enable
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get('client_flag', 0)
if self.dbapi is not None:
@@ -1041,7 +1101,8 @@ class MySQLDialect(ansisql.ANSIDialect):
def preparer(self):
return MySQLIdentifierPreparer(self)
- def do_executemany(self, cursor, statement, parameters, context=None, **kwargs):
+ def do_executemany(self, cursor, statement, parameters,
+ context=None, **kwargs):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
@@ -1060,28 +1121,31 @@ class MySQLDialect(ansisql.ANSIDialect):
pass
def do_begin_twophase(self, connection, xid):
- connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute("XA BEGIN %s", xid)
def do_prepare_twophase(self, connection, xid):
- connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
- connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute("XA END %s", xid)
+ connection.execute("XA PREPARE %s", xid)
- def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ def do_rollback_twophase(self, connection, xid, is_prepared=True,
+ recover=False):
if not is_prepared:
- connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
- connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute("XA END %s", xid)
+ connection.execute("XA ROLLBACK %s", xid)
- def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ def do_commit_twophase(self, connection, xid, is_prepared=True,
+ recover=False):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
- connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+ connection.execute("XA COMMIT %s", xid)
def do_recover_twophase(self, connection):
- resultset = connection.execute(sql.text("XA RECOVER"))
+ resultset = connection.execute("XA RECOVER")
return [row['data'][0:row['gtrid_length']] for row in resultset]
def is_disconnect(self, e):
- return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
+ return isinstance(e, self.dbapi.OperationalError) and \
+ e.args[0] in (2006, 2013, 2014, 2045, 2055)
def get_default_schema_name(self, connection):
try:
@@ -1102,7 +1166,7 @@ class MySQLDialect(ansisql.ANSIDialect):
# on macosx (and maybe win?) with multibyte table names.
#
# TODO: if this is not a problem on win, make the strategy swappable
- # based on platform. DESCRIBE is much slower.
+ # based on platform. DESCRIBE is slower.
if schema is not None:
st = "DESCRIBE `%s`.`%s`" % (schema, table_name)
else:
@@ -1118,12 +1182,14 @@ class MySQLDialect(ansisql.ANSIDialect):
raise
def get_version_info(self, connectable):
+ """A tuple of the database server version."""
+
if hasattr(connectable, 'connect'):
- con = connectable.connect().connection
+ dbapi_con = connectable.connect().connection
else:
- con = connectable
+ dbapi_con = connectable
version = []
- for n in con.get_server_info().split('.'):
+ for n in dbapi_con.get_server_info().split('.'):
try:
version.append(int(n))
except ValueError:
@@ -1140,8 +1206,9 @@ class MySQLDialect(ansisql.ANSIDialect):
table.name = table.name.lower()
table.metadata.tables[table.name]= table
+ table_name = '.'.join(self.identifier_preparer.format_table_seq(table))
try:
- rp = connection.execute("DESCRIBE " + self._escape_table_name(table))
+ rp = connection.execute("DESCRIBE " + table_name)
except exceptions.SQLError, e:
if e.orig.args[0] == 1146:
raise exceptions.NoSuchTableError(table.fullname)
@@ -1166,7 +1233,9 @@ class MySQLDialect(ansisql.ANSIDialect):
try:
coltype = ischema_names[col_type]
except KeyError:
- warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name)))
+ warnings.warn(RuntimeWarning(
+ "Did not recognize type '%s' of column '%s'" %
+ (col_type, name)))
coltype = sqltypes.NULLTYPE
kw = {}
@@ -1194,24 +1263,28 @@ class MySQLDialect(ansisql.ANSIDialect):
nullable=nullable,
)))
- tabletype = self.moretableinfo(connection, table, decode_from)
- table.kwargs['mysql_engine'] = tabletype
+ table_options = self.moretableinfo(connection, table, decode_from)
+ table.kwargs.update(table_options)
def moretableinfo(self, connection, table, charset=None):
"""SHOW CREATE TABLE to get foreign key/table options."""
- rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {})
+ table_name = '.'.join(self.identifier_preparer.format_table_seq(table))
+ rp = connection.execute("SHOW CREATE TABLE " + table_name)
row = _compat_fetchone(rp, charset=charset)
if not row:
raise exceptions.NoSuchTableError(table.fullname)
desc = row[1].strip()
+ row.close()
+
+ table_options = {}
- tabletype = ''
lastparen = re.search(r'\)[^\)]*\Z', desc)
if lastparen:
- match = re.search(r'\b(?:TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I)
+ match = re.search(r'\b(?P<spec>TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I)
if match:
- tabletype = match.group('ttype')
+ table_options["mysql_%s" % match.group('spec')] = \
+ match.group('ttype')
# \x27 == ' (single quote) (avoid xemacs syntax highlighting issue)
fkpat = r'''CONSTRAINT [`"\x27](?P<name>.+?)[`"\x27] FOREIGN KEY \((?P<columns>.+?)\) REFERENCES [`"\x27](?P<reftable>.+?)[`"\x27] \((?P<refcols>.+?)\)'''
@@ -1222,21 +1295,32 @@ class MySQLDialect(ansisql.ANSIDialect):
constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name'))
table.append_constraint(constraint)
- return tabletype
-
- def _escape_table_name(self, table):
- if table.schema is not None:
- return '`%s`.`%s`' % (table.schema, table.name)
- else:
- return '`%s`' % table.name
+ return table_options
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
+ # Allow user override, won't sniff if force_charset is set.
+ if 'force_charset' in connection.properties:
+ return connection.properties['force_charset']
+
# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
-
- rs = connection.execute("show variables like 'character_set%%'")
+ if self.get_version_info(connection) < (4, 1, 0):
+ try:
+ return connection.connection.character_set_name()
+ except AttributeError:
+ # < 1.2.1 final MySQL-python drivers have no charset support.
+ # a query is needed.
+ pass
+
+ # Prefer 'character_set_results' for the current connection over the
+ # value in the driver. SET NAMES or individual variable SETs will
+ # change the charset without updating the driver's view of the world.
+ #
+ # If it's decided that issuing that sort of SQL leaves you SOL, then
+ # this can prefer the driver value.
+ rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)])
if 'character_set_results' in opts:
@@ -1244,11 +1328,14 @@ class MySQLDialect(ansisql.ANSIDialect):
try:
return connection.connection.character_set_name()
except AttributeError:
- # < 1.2.1 final MySQL-python drivers have no charset support
+ # Still no charset on < 1.2.1 final...
if 'character_set' in opts:
return opts['character_set']
else:
- warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python. MySQL-python >= 1.2.2 is recommended. Assuming latin1."))
+ warnings.warn(RuntimeWarning(
+ "Could not detect the connection character set with this "
+ "combination of MySQL server and MySQL-python. "
+ "MySQL-python >= 1.2.2 is recommended. Assuming latin1."))
return 'latin1'
def _detect_case_sensitive(self, connection, charset=None):
@@ -1257,25 +1344,41 @@ class MySQLDialect(ansisql.ANSIDialect):
Cached per-connection. This value can not change without a server
restart.
"""
+
# http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
- _per_connection_mutex.acquire()
try:
- raw_connection = connection.connection.connection
- cache = self.per_connection.get(raw_connection, {})
- if 'lower_case_table_names' not in cache:
- row = _compat_fetchone(connection.execute(
- "SHOW VARIABLES LIKE 'lower_case_table_names'"),
- charset=charset)
- if not row:
- cs = True
- else:
- cs = row[1] in ('0', 'OFF' 'off')
- cache['lower_case_table_names'] = cs
- self.per_connection[raw_connection] = cache
- return cache.get('lower_case_table_names')
- finally:
- _per_connection_mutex.release()
+ return connection.properties['lower_case_table_names']
+ except KeyError:
+ row = _compat_fetchone(connection.execute(
+ "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+ charset=charset)
+ if not row:
+ cs = True
+ else:
+ cs = row[1] in ('0', 'OFF' 'off')
+ row.close()
+ connection.properties['lower_case_table_names'] = cs
+ return cs
+
+ def _detect_collations(self, connection, charset=None):
+ """Pull the active COLLATIONS list from the server.
+
+ Cached per-connection.
+ """
+
+ try:
+ return connection.properties['collations']
+ except KeyError:
+ collations = {}
+ if self.get_version_info(connection) < (4, 1, 0):
+ pass
+ else:
+ rs = connection.execute('SHOW COLLATION')
+ for row in _compat_fetchall(rs, charset):
+ collations[row[0]] = row[1]
+ connection.properties['collations'] = collations
+ return collations
def _compat_fetchall(rp, charset=None):
"""Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
@@ -1291,6 +1394,10 @@ def _compat_fetchone(rp, charset=None):
class _MySQLPythonRowProxy(object):
"""Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings."""
+ # Some MySQL-python versions can return some columns as
+ # sets.Set(['value']) (seriously) but thankfully that doesn't
+ # seem to come up in DDL queries.
+
def __init__(self, rowproxy, charset):
self.rowproxy = rowproxy
self.charset = charset
@@ -1316,13 +1423,15 @@ class MySQLCompiler(ansisql.ANSICompiler):
operators = ansisql.ANSICompiler.operators.copy()
operators.update(
{
- sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
- operator.mod : '%%'
+ sql.ColumnOperators.concat_op: \
+ lambda x, y: "concat(%s, %s)" % (x, y),
+ operator.mod: '%%'
}
)
def visit_cast(self, cast, **kwargs):
- if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
+ if isinstance(cast.type, (sqltypes.Date, sqltypes.Time,
+ sqltypes.DateTime)):
return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
else:
# so just skip the CAST altogether for now.
@@ -1348,26 +1457,45 @@ class MySQLCompiler(ansisql.ANSICompiler):
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
- def get_column_specification(self, column, override_pk=False, first_pk=False):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
+ def get_column_specification(self, column, override_pk=False,
+ first_pk=False):
+ """Builds column DDL."""
+
+ colspec = [self.preparer.format_column(column),
+ column.type.dialect_impl(self.dialect).get_col_spec()]
+
default = self.get_column_default_string(column)
if default is not None:
- colspec += " DEFAULT " + default
+ colspec.append('DEFAULT ' + default)
if not column.nullable:
- colspec += " NOT NULL"
+ colspec.append('NOT NULL')
+
+ # FIXME: #649, also #612 with regard to SHOW CREATE
if column.primary_key:
- if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer):
- colspec += " AUTO_INCREMENT"
- return colspec
+ if (len(column.foreign_keys)==0
+ and first_pk
+ and column.autoincrement
+ and isinstance(column.type, sqltypes.Integer)):
+ colspec.append('AUTO_INCREMENT')
+
+ return ' '.join(colspec)
def post_create_table(self, table):
- args = ""
+ """Build table-level CREATE options like ENGINE and COLLATE."""
+
+ table_opts = []
for k in table.kwargs:
if k.startswith('mysql_'):
- opt = k[6:]
- args += " %s=%s" % (opt.upper(), table.kwargs[k])
- return args
+ opt = k[6:].upper()
+ joiner = '='
+ if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
+ 'CHARACTER SET', 'COLLATE'):
+ joiner = ' '
+
+ table_opts.append(joiner.join((opt, table.kwargs[k])))
+ return ' '.join(table_opts)
+
class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
@@ -1382,9 +1510,11 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
self.preparer.format_constraint(constraint)))
self.execute()
+
class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
- super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`')
+ super(MySQLIdentifierPreparer, self).__init__(dialect,
+ initial_quote='`')
def _reserved_words(self):
return RESERVED_WORDS
@@ -1393,7 +1523,69 @@ class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
return value.replace('`', '``')
def _fold_identifier_case(self, value):
- #TODO: determine MySQL's case folding rules
+ # TODO: determine MySQL's case folding rules
+ #
+ # For compatability with sql.text() issued statements, maybe it's best
+ # to just leave things as-is. When lower_case_table_names > 0 it
+ # looks a good idea to lc everything, but more importantly the casing
+ # of all identifiers in an expression must be consistent. So for now,
+ # just leave everything as-is.
return value
+ def format_table_seq(self, table, use_schema=True):
+ """Format table name and schema as a tuple."""
+
+ if use_schema and getattr(table, 'schema', None):
+ return (self.quote_identifier(table.schema),
+ self.format_table(table, use_schema=False))
+ else:
+ return (self.format_table(table, use_schema=False), )
+
+
+class MySQLCharsetOnConnect(object):
+ """Use an alternate connection character set automatically."""
+
+ def __init__(self, charset, collation=None):
+ """Creates a pool listener that decorates new database connections.
+
+ Sets the connection character set on MySQL connections. Strings
+ sent to and from the server will use this encoding, and if a collation
+ is provided it will be used as the default.
+
+ There is also a MySQL-python 'charset' keyword for connections,
+ however that keyword has the side-effect of turning all strings into
+ Unicode.
+
+ This class is a ``Pool`` listener. To use, pass an insstance to the
+ ``listeners`` argument to create_engine or Pool constructor, or
+ manually add it to a pool with ``add_listener()``.
+
+ charset:
+ The character set to use
+
+ collation:
+ Optional, use a non-default collation for the given charset
+ """
+
+ self.charset = charset
+ self.collation = collation
+
+ def connect(self, dbapi_con, con_record):
+ cr = dbapi_con.cursor()
+ try:
+ if self.collation is None:
+ if hasattr(dbapi_con, 'set_character_set'):
+ dbapi_con.set_character_set(self.charset)
+ else:
+ cr.execute("SET NAMES %s" % self.charset)
+ else:
+ if hasattr(dbapi_con, 'set_character_set'):
+ dbapi_con.set_character_set(self.charset)
+ cr.execute("SET NAMES %s COLLATE %s" % (self.charset,
+ self.collation))
+ # let SQL errors (1064 if SET NAMES is not supported) raise
+ finally:
+ cr.close()
+
+
dialect = MySQLDialect