diff options
-rw-r--r-- | CHANGES | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 102 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/pool.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 26 | ||||
-rw-r--r-- | test/dialect/mysql.py | 43 |
8 files changed, 149 insertions, 77 deletions
@@ -253,6 +253,10 @@ CHANGES - Improvements to pyodbc + Unix. If you couldn't get that combination to work before, please try again. +- mysql + - The connection.info keys the dialect uses to cache server + settings have changed and are now namespaced. + 0.4.4 ------ - sql diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 18b236d1c..c4b8af684 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -157,7 +157,6 @@ import datetime, inspect, re, sys from array import array as _array from sqlalchemy import exceptions, logging, schema, sql, util -from sqlalchemy.pool import connection_cache_decorator from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql import functions as sql_functions from sqlalchemy.sql import compiler @@ -224,6 +223,10 @@ AUTOCOMMIT_RE = re.compile( SELECT_RE = re.compile( r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)', re.I | re.UNICODE) +SET_RE = re.compile( + r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', + re.I | re.UNICODE) + class _NumericType(object): """Base for MySQL numeric types.""" @@ -1396,6 +1399,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids[0] is None): self._last_inserted_ids = ([self.cursor.lastrowid] + self._last_inserted_ids[1:]) + elif (not self.isupdate and not self.should_autocommit and + self.statement and SET_RE.match(self.statement)): + # This misses if a user forces autocommit on text('SET NAMES'), + # which is probably a programming error anyhow. + self.connection.info.pop(('mysql', 'charset'), None) def returns_rows_text(self, statement): return SELECT_RE.match(statement) @@ -1544,8 +1552,9 @@ class MySQLDialect(default.DefaultDialect): def get_default_schema_name(self, connection): return connection.execute('SELECT DATABASE()').scalar() - get_default_schema_name = connection_cache_decorator(get_default_schema_name) - + get_default_schema_name = engine_base.connection_memoize( + ('dialect', 'default_schema_name'))(get_default_schema_name) + def table_names(self, connection, schema): """Return a Unicode SHOW TABLES from a given schema.""" @@ -1599,12 +1608,9 @@ class MySQLDialect(default.DefaultDialect): cached per-Connection. """ - try: - return connection.info['_mysql_server_version_info'] - except KeyError: - version = connection.info['_mysql_server_version_info'] = \ - self._server_version_info(connection.connection.connection) - return version + return self._server_version_info(connection.connection.connection) + server_version_info = engine_base.connection_memoize( + ('mysql', 'server_version_info'))(server_version_info) def _server_version_info(self, dbapi_con): """Convert a MySQL-python server_info string into a tuple.""" @@ -1654,7 +1660,7 @@ class MySQLDialect(default.DefaultDialect): columns = self._describe_table(connection, table, charset) sql = reflector._describe_to_create(table, columns) - self._adjust_casing(connection, table, charset) + self._adjust_casing(connection, table) return reflector.reflect(connection, table, sql, charset, only=include_columns) @@ -1662,10 +1668,7 @@ class MySQLDialect(default.DefaultDialect): def _adjust_casing(self, connection, table, charset=None): """Adjust Table name to the server case sensitivity, if needed.""" - if charset is None: - charset = self._detect_charset(connection) - - casing = self._detect_casing(connection, charset) + casing = self._detect_casing(connection) # For winxx database hosts. TODO: is this really needed? if casing == 1 and table.name != table.name.lower(): @@ -1678,8 +1681,8 @@ class MySQLDialect(default.DefaultDialect): """Sniff out the character set in use for connection results.""" # Allow user override, won't sniff if force_charset is set. - if 'force_charset' in connection.info: - return connection.info['force_charset'] + if ('mysql', 'force_charset') in connection.info: + return connection.info[('mysql', 'force_charset')] # Note: MySQL-python 1.2.1c7 seems to ignore changes made # on a connection via set_character_set() @@ -1714,55 +1717,56 @@ class MySQLDialect(default.DefaultDialect): "combination of MySQL server and MySQL-python. " "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") return 'latin1' + _detect_charset = engine_base.connection_memoize( + ('mysql', 'charset'))(_detect_charset) + - def _detect_casing(self, connection, charset=None): + def _detect_casing(self, connection): """Sniff out identifier case sensitivity. Cached per-connection. This value can not change without a server restart. - """ + """ # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - try: - return connection.info['lower_case_table_names'] - except KeyError: - row = _compat_fetchone(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) - if not row: + charset = self._detect_charset(connection) + row = _compat_fetchone(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if row[1] == 'OFF': cs = 0 + elif row[1] == 'ON': + cs = 1 else: - # 4.0.15 returns OFF or ON according to [ticket:489] - # 3.23 doesn't, 4.0.27 doesn't.. - if row[1] == 'OFF': - cs = 0 - elif row[1] == 'ON': - cs = 1 - else: - cs = int(row[1]) - row.close() - connection.info['lower_case_table_names'] = cs - return cs + cs = int(row[1]) + row.close() + return cs + _detect_casing = engine_base.connection_memoize( + ('mysql', 'lower_case_table_names'))(_detect_casing) - def _detect_collations(self, connection, charset=None): + def _detect_collations(self, connection): """Pull the active COLLATIONS list from the server. Cached per-connection. """ - try: - return connection.info['collations'] - except KeyError: - collations = {} - if self.server_version_info(connection) < (4, 1, 0): - pass - else: - rs = connection.execute('SHOW COLLATION') - for row in _compat_fetchall(rs, charset): - collations[row[0]] = row[1] - connection.info['collations'] = collations - return collations + collations = {} + if self.server_version_info(connection) < (4, 1, 0): + pass + else: + charset = self._detect_charset(connection) + rs = connection.execute('SHOW COLLATION') + for row in _compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + _detect_collations = engine_base.connection_memoize( + ('mysql', 'collations'))(_detect_collations) def use_ansiquotes(self, useansi): self._use_ansiquotes = useansi diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index b23ec06e5..2faeab651 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -12,7 +12,6 @@ from sqlalchemy.engine import default, base from sqlalchemy.sql import compiler, visitors from sqlalchemy.sql import operators as sql_operators, functions as sql_functions from sqlalchemy import types as sqltypes -from sqlalchemy.pool import connection_cache_decorator class OracleNumeric(sqltypes.Numeric): @@ -379,9 +378,10 @@ class OracleDialect(default.DefaultDialect): else: return name.encode(self.encoding) - def get_default_schema_name(self,connection): + def get_default_schema_name(self, connection): return connection.execute('SELECT USER FROM DUAL').scalar() - get_default_schema_name = connection_cache_decorator(get_default_schema_name) + get_default_schema_name = base.connection_memoize( + ('dialect', 'default_schema_name'))(get_default_schema_name) def table_names(self, connection, schema): # note that table_names() isnt loading DBLINKed or synonym'ed tables diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index abae27eb1..94ad7d2e4 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -26,7 +26,6 @@ from sqlalchemy.engine import base, default from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes -from sqlalchemy.pool import connection_cache_decorator class PGInet(sqltypes.TypeEngine): @@ -369,7 +368,8 @@ class PGDialect(default.DefaultDialect): def get_default_schema_name(self, connection): return connection.scalar("select current_schema()", None) - get_default_schema_name = connection_cache_decorator(get_default_schema_name) + get_default_schema_name = base.connection_memoize( + ('dialect', 'default_schema_name'))(get_default_schema_name) def last_inserted_ids(self): if self.context.last_inserted_ids is None: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 3d7e65198..cd662ac92 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -12,7 +12,7 @@ higher-level statement-construction, connection-management, execution and result contexts. """ -import StringIO, sys +import inspect, StringIO, sys from sqlalchemy import exceptions, schema, util, types, logging from sqlalchemy.sql import expression @@ -1864,3 +1864,27 @@ class DefaultRunner(schema.SchemaVisitor): return default.arg(self.context) else: return default.arg + + +def connection_memoize(key): + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + + """ + def decorate(fn): + spec = inspect.getargspec(fn) + assert len(spec[0]) == 2 + assert spec[0][1] == 'connection' + assert spec[1:3] == (None, None) + + def decorated(self, connection): + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return util.function_named(decorated, fn.__name__) + return decorate diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index e22d1d8d3..94d9127f0 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -58,21 +58,6 @@ def clear_managers(): manager.close() proxies.clear() -def connection_cache_decorator(func): - """apply caching to the return value of a function, using - the 'info' collection on its given connection.""" - - name = func.__name__ - - def do_with_cache(self, connection): - try: - return connection.info[name] - except KeyError: - value = func(self, connection) - connection.info[name] = value - return value - return do_with_cache - class Pool(object): """Base class for connection pools. diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 26b6dbe9a..101ef1462 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect, itertools, sets, sys, warnings, weakref +import inspect, itertools, new, sets, sys, warnings, weakref import __builtin__ types = __import__('types') @@ -1055,7 +1055,22 @@ class symbol(object): return sym finally: symbol._lock.release() - + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + """ + try: + fn.__name__ = name + except TypeError: + fn = new.function(fn.func_code, fn.func_globals, name, + fn.func_defaults, fn.func_closure) + return fn + def conditional_cache_decorator(func): """apply conditional caching to the return value of a function.""" @@ -1166,8 +1181,5 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None): func_with_warning.__doc__ = doc func_with_warning.__dict__.update(func.__dict__) - try: - func_with_warning.__name__ = func.__name__ - except TypeError: - pass - return func_with_warning + + return function_named(func_with_warning, func.__name__) diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index e1bd47d29..00478908e 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -927,6 +927,49 @@ class SQLTest(TestBase, AssertsCompiledSQL): self.assert_compile(cast(t.c.col, type_), expected) +class ExecutionTest(TestBase): + """Various MySQL execution special cases.""" + + __only_on__ = 'mysql' + + def test_charset_caching(self): + engine = engines.testing_engine() + + cx = engine.connect() + meta = MetaData() + + assert ('mysql', 'charset') not in cx.info + assert ('mysql', 'force_charset') not in cx.info + + cx.execute(text("SELECT 1")).fetchall() + assert ('mysql', 'charset') not in cx.info + + meta.reflect(cx) + assert ('mysql', 'charset') in cx.info + + cx.execute(text("SET @squiznart=123")) + assert ('mysql', 'charset') in cx.info + + # the charset invalidation is very conservative + cx.execute(text("SET TIMESTAMP = DEFAULT")) + assert ('mysql', 'charset') not in cx.info + + cx.info[('mysql', 'force_charset')] = 'latin1' + + assert engine.dialect._detect_charset(cx) == 'latin1' + assert cx.info[('mysql', 'charset')] == 'latin1' + + del cx.info[('mysql', 'force_charset')] + del cx.info[('mysql', 'charset')] + + meta.reflect(cx) + assert ('mysql', 'charset') in cx.info + + # String execution doesn't go through the detector. + cx.execute("SET TIMESTAMP = DEFAULT") + assert ('mysql', 'charset') in cx.info + + def colspec(c): return testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None).get_column_specification(c) |