diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-07-29 16:13:23 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-07-29 16:13:23 +0000 |
commit | 6a1d279a9a702a4c520446f52edeb1ee4864dd70 (patch) | |
tree | 4c69f76299ee30f3b715f7c43fcfc880d51803cd /lib/sqlalchemy/databases/mysql.py | |
parent | c9cc90bbdc24f8d4d0429468404cd43de46fc07f (diff) | |
download | sqlalchemy-6a1d279a9a702a4c520446f52edeb1ee4864dd70.tar.gz |
Big MySQL dialect update, mostly efficiency and style.
Added TINYINT [ticket:691]- whoa, how did that one go missing for so long?
Added a charset-fixing pool listener. The driver-level option doesn't help everyone with this one.
New reflector code not quite done and omiited from this commit.
Diffstat (limited to 'lib/sqlalchemy/databases/mysql.py')
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 426 |
1 files changed, 309 insertions, 117 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 53ef1a95b..2c54c2512 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -5,25 +5,21 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import re, datetime, inspect, warnings, weakref, operator +from array import array as _array +from decimal import Decimal from sqlalchemy import sql, schema, ansisql from sqlalchemy.engine import default import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import sqlalchemy.util as util -from array import array as _array -from decimal import Decimal -try: - from threading import Lock -except ImportError: - from dummy_threading import Lock RESERVED_WORDS = util.Set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', - 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', + 'collate', 'column', 'con dition', 'constraint', 'continue', 'convert', 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', 'current_user', 'cursor', 'database', 'databases', 'day_hour', 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', @@ -60,7 +56,7 @@ RESERVED_WORDS = util.Set( 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', 'read_only', 'read_write', # 5.1 ]) -_per_connection_mutex = Lock() + class _NumericType(object): "Base for MySQL numeric types." @@ -78,6 +74,7 @@ class _NumericType(object): spec += ' ZEROFILL' return spec + class _StringType(object): "Base for MySQL string types." @@ -133,10 +130,11 @@ class _StringType(object): return "%s(%s)" % (self.__class__.__name__, ','.join(['%s=%s' % (k, params[k]) for k in params])) + class MSNumeric(sqltypes.Numeric, _NumericType): """MySQL NUMERIC type""" - def __init__(self, precision = 10, length = 2, asdecimal=True, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a NUMERIC. precision @@ -173,6 +171,7 @@ class MSNumeric(sqltypes.Numeric, _NumericType): else: return value + class MSDecimal(MSNumeric): """MySQL DECIMAL type""" @@ -205,6 +204,7 @@ class MSDecimal(MSNumeric): else: return self._extend("DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + class MSDouble(MSNumeric): """MySQL DOUBLE type""" @@ -240,6 +240,7 @@ class MSDouble(MSNumeric): else: return self._extend('DOUBLE') + class MSFloat(sqltypes.Float, _NumericType): """MySQL FLOAT type""" @@ -307,6 +308,7 @@ class MSInteger(sqltypes.Integer, _NumericType): else: return self._extend("INTEGER") + class MSBigInteger(MSInteger): """MySQL BIGINTEGER type""" @@ -333,6 +335,34 @@ class MSBigInteger(MSInteger): else: return self._extend("BIGINT") + +class MSTinyInteger(MSInteger): + """MySQL TINYINT type""" + + def __init__(self, length=None, **kw): + """Construct a SMALLINTEGER. + + length + Optional, maximum display width for this number. + + unsigned + Optional. + + zerofill + Optional. If true, values will be stored as strings left-padded with + zeros. Note that this does not effect the values returned by the + underlying database API, which continue to be numeric. + """ + + super(MSTinyInteger, self).__init__(length, **kw) + + def get_col_spec(self): + if self.length is not None: + return self._extend("TINYINT(%s)" % self.length) + else: + return self._extend("TINYINT") + + class MSSmallInteger(sqltypes.Smallinteger, _NumericType): """MySQL SMALLINTEGER type""" @@ -361,18 +391,21 @@ class MSSmallInteger(sqltypes.Smallinteger, _NumericType): else: return self._extend("SMALLINT") + class MSDateTime(sqltypes.DateTime): """MySQL DATETIME type""" def get_col_spec(self): return "DATETIME" + class MSDate(sqltypes.Date): """MySQL DATE type""" def get_col_spec(self): return "DATE" + class MSTime(sqltypes.Time): """MySQL TIME type""" @@ -386,6 +419,7 @@ class MSTime(sqltypes.Time): else: return None + class MSTimeStamp(sqltypes.TIMESTAMP): """MySQL TIMESTAMP type @@ -399,6 +433,7 @@ class MSTimeStamp(sqltypes.TIMESTAMP): def get_col_spec(self): return "TIMESTAMP" + class MSYear(sqltypes.String): """MySQL YEAR type, for single byte storage of years 1901-2155""" @@ -408,6 +443,7 @@ class MSYear(sqltypes.String): else: return "YEAR(%d)" % self.length + class MSText(_StringType, sqltypes.TEXT): """MySQL TEXT type, for text up to 2^16 characters""" @@ -495,6 +531,7 @@ class MSTinyText(MSText): def get_col_spec(self): return self._extend("TINYTEXT") + class MSMediumText(MSText): """MySQL MEDIUMTEXT type, for text up to 2^24 characters""" @@ -533,6 +570,7 @@ class MSMediumText(MSText): def get_col_spec(self): return self._extend("MEDIUMTEXT") + class MSLongText(MSText): """MySQL LONGTEXT type, for text up to 2^32 characters""" @@ -571,6 +609,7 @@ class MSLongText(MSText): def get_col_spec(self): return self._extend("LONGTEXT") + class MSString(_StringType, sqltypes.String): """MySQL VARCHAR type, for variable-length character data.""" @@ -617,6 +656,7 @@ class MSString(_StringType, sqltypes.String): else: return self._extend("TEXT") + class MSChar(_StringType, sqltypes.CHAR): """MySQL CHAR type, for fixed-length character data.""" @@ -642,6 +682,7 @@ class MSChar(_StringType, sqltypes.CHAR): def get_col_spec(self): return self._extend("CHAR(%(length)s)" % {'length' : self.length}) + class MSNVarChar(_StringType, sqltypes.String): """MySQL NVARCHAR type, for variable-length character data in the server's configured national character set. @@ -673,6 +714,7 @@ class MSNVarChar(_StringType, sqltypes.String): # of "NVARCHAR". return self._extend("VARCHAR(%(length)s)" % {'length': self.length}) + class MSNChar(_StringType, sqltypes.CHAR): """MySQL NCHAR type, for fixed-length character data in the server's configured national character set. @@ -701,6 +743,7 @@ class MSNChar(_StringType, sqltypes.CHAR): # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". return self._extend("CHAR(%(length)s)" % {'length': self.length}) + class _BinaryType(sqltypes.Binary): """MySQL binary types""" @@ -716,6 +759,7 @@ class _BinaryType(sqltypes.Binary): else: return buffer(value) + class MSVarBinary(_BinaryType): """MySQL VARBINARY type, for variable length binary data""" @@ -733,6 +777,7 @@ class MSVarBinary(_BinaryType): else: return "BLOB" + class MSBinary(_BinaryType): """MySQL BINARY type, for fixed length binary data""" @@ -760,10 +805,10 @@ class MSBinary(_BinaryType): else: return buffer(value) + class MSBlob(_BinaryType): """MySQL BLOB type, for binary data up to 2^16 bytes""" - def __init__(self, length=None, **kw): """Construct a BLOB. Arguments are: @@ -790,24 +835,28 @@ class MSBlob(_BinaryType): def __repr__(self): return "%s()" % self.__class__.__name__ + class MSTinyBlob(MSBlob): """MySQL TINYBLOB type, for binary data up to 2^8 bytes""" def get_col_spec(self): return "TINYBLOB" + class MSMediumBlob(MSBlob): """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes""" def get_col_spec(self): return "MEDIUMBLOB" + class MSLongBlob(MSBlob): """MySQL LONGBLOB type, for binary data up to 2^32 bytes""" def get_col_spec(self): return "LONGBLOB" + class MSEnum(MSString): """MySQL ENUM type.""" @@ -877,6 +926,7 @@ class MSEnum(MSString): def get_col_spec(self): return self._extend("ENUM(%s)" % ",".join(self.__ddl_values)) + class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOL" @@ -899,17 +949,17 @@ class MSBoolean(sqltypes.Boolean): # TODO: SET, BIT colspecs = { - sqltypes.Integer : MSInteger, - sqltypes.Smallinteger : MSSmallInteger, - sqltypes.Numeric : MSNumeric, - sqltypes.Float : MSFloat, - sqltypes.DateTime : MSDateTime, - sqltypes.Date : MSDate, - sqltypes.Time : MSTime, - sqltypes.String : MSString, - sqltypes.Binary : MSBlob, - sqltypes.Boolean : MSBoolean, - sqltypes.TEXT : MSText, + sqltypes.Integer: MSInteger, + sqltypes.Smallinteger: MSSmallInteger, + sqltypes.Numeric: MSNumeric, + sqltypes.Float: MSFloat, + sqltypes.DateTime: MSDateTime, + sqltypes.Date: MSDate, + sqltypes.Time: MSTime, + sqltypes.String: MSString, + sqltypes.Binary: MSBlob, + sqltypes.Boolean: MSBoolean, + sqltypes.TEXT: MSText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, sqltypes.TIMESTAMP: MSTimeStamp, @@ -919,37 +969,37 @@ colspecs = { ischema_names = { - 'bigint' : MSBigInteger, - 'binary' : MSBinary, - 'blob' : MSBlob, + 'bigint': MSBigInteger, + 'binary': MSBinary, + 'blob': MSBlob, 'boolean':MSBoolean, - 'char' : MSChar, - 'date' : MSDate, - 'datetime' : MSDateTime, - 'decimal' : MSDecimal, - 'double' : MSDouble, + 'char': MSChar, + 'date': MSDate, + 'datetime': MSDateTime, + 'decimal': MSDecimal, + 'double': MSDouble, 'enum': MSEnum, 'fixed': MSDecimal, - 'float' : MSFloat, - 'int' : MSInteger, - 'integer' : MSInteger, + 'float': MSFloat, + 'int': MSInteger, + 'integer': MSInteger, 'longblob': MSLongBlob, 'longtext': MSLongText, 'mediumblob': MSMediumBlob, - 'mediumint' : MSInteger, + 'mediumint': MSInteger, 'mediumtext': MSMediumText, 'nchar': MSNChar, 'nvarchar': MSNVarChar, - 'numeric' : MSNumeric, - 'smallint' : MSSmallInteger, - 'text' : MSText, - 'time' : MSTime, - 'timestamp' : MSTimeStamp, + 'numeric': MSNumeric, + 'smallint': MSSmallInteger, + 'text': MSText, + 'time': MSTime, + 'timestamp': MSTimeStamp, 'tinyblob': MSTinyBlob, - 'tinyint' : MSSmallInteger, - 'tinytext' : MSTinyText, - 'varbinary' : MSVarBinary, - 'varchar' : MSString, + 'tinyint': MSTinyInteger, + 'tinytext': MSTinyText, + 'varbinary': MSVarBinary, + 'varchar': MSString, } def descriptor(): @@ -962,19 +1012,26 @@ def descriptor(): ('host',"Hostname", None), ]} + class MySQLExecutionContext(default.DefaultExecutionContext): + _my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)', + re.I | re.UNICODE) + def post_exec(self): if self.compiled.isinsert: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] + if (not len(self._last_inserted_ids) or + self._last_inserted_ids[0] is None): + self._last_inserted_ids = ([self.cursor.lastrowid] + + self._last_inserted_ids[1:]) def is_select(self): - return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None + return self._my_is_select.match(self.statement) is not None + class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) - self.per_connection = weakref.WeakKeyDictionary() + ansisql.ANSIDialect.__init__(self, default_paramstyle='format', + **kwargs) def dbapi(cls): import MySQLdb as mysql @@ -989,12 +1046,15 @@ class MySQLDialect(ansisql.ANSIDialect): util.coerce_kw_type(opts, 'connect_timeout', int) util.coerce_kw_type(opts, 'client_flag', int) util.coerce_kw_type(opts, 'local_infile', int) - # note: these two could break SA Unicode type + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. util.coerce_kw_type(opts, 'use_unicode', bool) util.coerce_kw_type(opts, 'charset', str) - # TODO: cursorclass and conv: support via query string or punt? + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. - # ssl ssl = {} for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: if key in opts: @@ -1004,7 +1064,7 @@ class MySQLDialect(ansisql.ANSIDialect): if len(ssl): opts['ssl'] = ssl - # FOUND_ROWS must be set in CLIENT_FLAGS for to enable + # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. client_flag = opts.get('client_flag', 0) if self.dbapi is not None: @@ -1041,7 +1101,8 @@ class MySQLDialect(ansisql.ANSIDialect): def preparer(self): return MySQLIdentifierPreparer(self) - def do_executemany(self, cursor, statement, parameters, context=None, **kwargs): + def do_executemany(self, cursor, statement, parameters, + context=None, **kwargs): rowcount = cursor.executemany(statement, parameters) if context is not None: context._rowcount = rowcount @@ -1060,28 +1121,31 @@ class MySQLDialect(ansisql.ANSIDialect): pass def do_begin_twophase(self, connection, xid): - connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA BEGIN %s", xid) def do_prepare_twophase(self, connection, xid): - connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) - connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA END %s", xid) + connection.execute("XA PREPARE %s", xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): if not is_prepared: - connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) - connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA END %s", xid) + connection.execute("XA ROLLBACK %s", xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): if not is_prepared: self.do_prepare_twophase(connection, xid) - connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA COMMIT %s", xid) def do_recover_twophase(self, connection): - resultset = connection.execute(sql.text("XA RECOVER")) + resultset = connection.execute("XA RECOVER") return [row['data'][0:row['gtrid_length']] for row in resultset] def is_disconnect(self, e): - return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) + return isinstance(e, self.dbapi.OperationalError) and \ + e.args[0] in (2006, 2013, 2014, 2045, 2055) def get_default_schema_name(self, connection): try: @@ -1102,7 +1166,7 @@ class MySQLDialect(ansisql.ANSIDialect): # on macosx (and maybe win?) with multibyte table names. # # TODO: if this is not a problem on win, make the strategy swappable - # based on platform. DESCRIBE is much slower. + # based on platform. DESCRIBE is slower. if schema is not None: st = "DESCRIBE `%s`.`%s`" % (schema, table_name) else: @@ -1118,12 +1182,14 @@ class MySQLDialect(ansisql.ANSIDialect): raise def get_version_info(self, connectable): + """A tuple of the database server version.""" + if hasattr(connectable, 'connect'): - con = connectable.connect().connection + dbapi_con = connectable.connect().connection else: - con = connectable + dbapi_con = connectable version = [] - for n in con.get_server_info().split('.'): + for n in dbapi_con.get_server_info().split('.'): try: version.append(int(n)) except ValueError: @@ -1140,8 +1206,9 @@ class MySQLDialect(ansisql.ANSIDialect): table.name = table.name.lower() table.metadata.tables[table.name]= table + table_name = '.'.join(self.identifier_preparer.format_table_seq(table)) try: - rp = connection.execute("DESCRIBE " + self._escape_table_name(table)) + rp = connection.execute("DESCRIBE " + table_name) except exceptions.SQLError, e: if e.orig.args[0] == 1146: raise exceptions.NoSuchTableError(table.fullname) @@ -1166,7 +1233,9 @@ class MySQLDialect(ansisql.ANSIDialect): try: coltype = ischema_names[col_type] except KeyError: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name))) + warnings.warn(RuntimeWarning( + "Did not recognize type '%s' of column '%s'" % + (col_type, name))) coltype = sqltypes.NULLTYPE kw = {} @@ -1194,24 +1263,28 @@ class MySQLDialect(ansisql.ANSIDialect): nullable=nullable, ))) - tabletype = self.moretableinfo(connection, table, decode_from) - table.kwargs['mysql_engine'] = tabletype + table_options = self.moretableinfo(connection, table, decode_from) + table.kwargs.update(table_options) def moretableinfo(self, connection, table, charset=None): """SHOW CREATE TABLE to get foreign key/table options.""" - rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {}) + table_name = '.'.join(self.identifier_preparer.format_table_seq(table)) + rp = connection.execute("SHOW CREATE TABLE " + table_name) row = _compat_fetchone(rp, charset=charset) if not row: raise exceptions.NoSuchTableError(table.fullname) desc = row[1].strip() + row.close() + + table_options = {} - tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) if lastparen: - match = re.search(r'\b(?:TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I) + match = re.search(r'\b(?P<spec>TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I) if match: - tabletype = match.group('ttype') + table_options["mysql_%s" % match.group('spec')] = \ + match.group('ttype') # \x27 == ' (single quote) (avoid xemacs syntax highlighting issue) fkpat = r'''CONSTRAINT [`"\x27](?P<name>.+?)[`"\x27] FOREIGN KEY \((?P<columns>.+?)\) REFERENCES [`"\x27](?P<reftable>.+?)[`"\x27] \((?P<refcols>.+?)\)''' @@ -1222,21 +1295,32 @@ class MySQLDialect(ansisql.ANSIDialect): constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name')) table.append_constraint(constraint) - return tabletype - - def _escape_table_name(self, table): - if table.schema is not None: - return '`%s`.`%s`' % (table.schema, table.name) - else: - return '`%s`' % table.name + return table_options def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" + # Allow user override, won't sniff if force_charset is set. + if 'force_charset' in connection.properties: + return connection.properties['force_charset'] + # Note: MySQL-python 1.2.1c7 seems to ignore changes made # on a connection via set_character_set() - - rs = connection.execute("show variables like 'character_set%%'") + if self.get_version_info(connection) < (4, 1, 0): + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support. + # a query is needed. + pass + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)]) if 'character_set_results' in opts: @@ -1244,11 +1328,14 @@ class MySQLDialect(ansisql.ANSIDialect): try: return connection.connection.character_set_name() except AttributeError: - # < 1.2.1 final MySQL-python drivers have no charset support + # Still no charset on < 1.2.1 final... if 'character_set' in opts: return opts['character_set'] else: - warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python. MySQL-python >= 1.2.2 is recommended. Assuming latin1.")) + warnings.warn(RuntimeWarning( + "Could not detect the connection character set with this " + "combination of MySQL server and MySQL-python. " + "MySQL-python >= 1.2.2 is recommended. Assuming latin1.")) return 'latin1' def _detect_case_sensitive(self, connection, charset=None): @@ -1257,25 +1344,41 @@ class MySQLDialect(ansisql.ANSIDialect): Cached per-connection. This value can not change without a server restart. """ + # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - _per_connection_mutex.acquire() try: - raw_connection = connection.connection.connection - cache = self.per_connection.get(raw_connection, {}) - if 'lower_case_table_names' not in cache: - row = _compat_fetchone(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) - if not row: - cs = True - else: - cs = row[1] in ('0', 'OFF' 'off') - cache['lower_case_table_names'] = cs - self.per_connection[raw_connection] = cache - return cache.get('lower_case_table_names') - finally: - _per_connection_mutex.release() + return connection.properties['lower_case_table_names'] + except KeyError: + row = _compat_fetchone(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = True + else: + cs = row[1] in ('0', 'OFF' 'off') + row.close() + connection.properties['lower_case_table_names'] = cs + return cs + + def _detect_collations(self, connection, charset=None): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + try: + return connection.properties['collations'] + except KeyError: + collations = {} + if self.get_version_info(connection) < (4, 1, 0): + pass + else: + rs = connection.execute('SHOW COLLATION') + for row in _compat_fetchall(rs, charset): + collations[row[0]] = row[1] + connection.properties['collations'] = collations + return collations def _compat_fetchall(rp, charset=None): """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" @@ -1291,6 +1394,10 @@ def _compat_fetchone(rp, charset=None): class _MySQLPythonRowProxy(object): """Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings.""" + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + def __init__(self, rowproxy, charset): self.rowproxy = rowproxy self.charset = charset @@ -1316,13 +1423,15 @@ class MySQLCompiler(ansisql.ANSICompiler): operators = ansisql.ANSICompiler.operators.copy() operators.update( { - sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y), - operator.mod : '%%' + sql.ColumnOperators.concat_op: \ + lambda x, y: "concat(%s, %s)" % (x, y), + operator.mod: '%%' } ) def visit_cast(self, cast, **kwargs): - if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): + if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, + sqltypes.DateTime)): return super(MySQLCompiler, self).visit_cast(cast, **kwargs) else: # so just skip the CAST altogether for now. @@ -1348,26 +1457,45 @@ class MySQLCompiler(ansisql.ANSICompiler): class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, first_pk=False): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() + def get_column_specification(self, column, override_pk=False, + first_pk=False): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + column.type.dialect_impl(self.dialect).get_col_spec()] + default = self.get_column_default_string(column) if default is not None: - colspec += " DEFAULT " + default + colspec.append('DEFAULT ' + default) if not column.nullable: - colspec += " NOT NULL" + colspec.append('NOT NULL') + + # FIXME: #649, also #612 with regard to SHOW CREATE if column.primary_key: - if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer): - colspec += " AUTO_INCREMENT" - return colspec + if (len(column.foreign_keys)==0 + and first_pk + and column.autoincrement + and isinstance(column.type, sqltypes.Integer)): + colspec.append('AUTO_INCREMENT') + + return ' '.join(colspec) def post_create_table(self, table): - args = "" + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] for k in table.kwargs: if k.startswith('mysql_'): - opt = k[6:] - args += " %s=%s" % (opt.upper(), table.kwargs[k]) - return args + opt = k[6:].upper() + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE'): + joiner = ' ' + + table_opts.append(joiner.join((opt, table.kwargs[k]))) + return ' '.join(table_opts) + class MySQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): @@ -1382,9 +1510,11 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): self.preparer.format_constraint(constraint))) self.execute() + class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): - super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`') + super(MySQLIdentifierPreparer, self).__init__(dialect, + initial_quote='`') def _reserved_words(self): return RESERVED_WORDS @@ -1393,7 +1523,69 @@ class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return value.replace('`', '``') def _fold_identifier_case(self, value): - #TODO: determine MySQL's case folding rules + # TODO: determine MySQL's case folding rules + # + # For compatability with sql.text() issued statements, maybe it's best + # to just leave things as-is. When lower_case_table_names > 0 it + # looks a good idea to lc everything, but more importantly the casing + # of all identifiers in an expression must be consistent. So for now, + # just leave everything as-is. return value + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + if use_schema and getattr(table, 'schema', None): + return (self.quote_identifier(table.schema), + self.format_table(table, use_schema=False)) + else: + return (self.format_table(table, use_schema=False), ) + + +class MySQLCharsetOnConnect(object): + """Use an alternate connection character set automatically.""" + + def __init__(self, charset, collation=None): + """Creates a pool listener that decorates new database connections. + + Sets the connection character set on MySQL connections. Strings + sent to and from the server will use this encoding, and if a collation + is provided it will be used as the default. + + There is also a MySQL-python 'charset' keyword for connections, + however that keyword has the side-effect of turning all strings into + Unicode. + + This class is a ``Pool`` listener. To use, pass an insstance to the + ``listeners`` argument to create_engine or Pool constructor, or + manually add it to a pool with ``add_listener()``. + + charset: + The character set to use + + collation: + Optional, use a non-default collation for the given charset + """ + + self.charset = charset + self.collation = collation + + def connect(self, dbapi_con, con_record): + cr = dbapi_con.cursor() + try: + if self.collation is None: + if hasattr(dbapi_con, 'set_character_set'): + dbapi_con.set_character_set(self.charset) + else: + cr.execute("SET NAMES %s" % self.charset) + else: + if hasattr(dbapi_con, 'set_character_set'): + dbapi_con.set_character_set(self.charset) + cr.execute("SET NAMES %s COLLATE %s" % (self.charset, + self.collation)) + # let SQL errors (1064 if SET NAMES is not supported) raise + finally: + cr.close() + + dialect = MySQLDialect |