summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2008-04-02 11:39:26 +0000
committerJason Kirtland <jek@discorporate.us>2008-04-02 11:39:26 +0000
commitf12969a4d8dd8a4f91a9da4cc9d855457a5a0060 (patch)
tree274125cdcd73dbcea79575b7a82810032ae0be36 /lib/sqlalchemy
parenta0000075437370ae23d0db48c706a4237f4e43f3 (diff)
downloadsqlalchemy-f12969a4d8dd8a4f91a9da4cc9d855457a5a0060.tar.gz
- Revamped the Connection memoize decorator a bit, moved to engine
- MySQL character set caching is more aggressive but will invalidate the cache if a SET is issued. - MySQL connection memos are namespaced: info[('mysql', 'server_variable')]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/databases/mysql.py102
-rw-r--r--lib/sqlalchemy/databases/oracle.py6
-rw-r--r--lib/sqlalchemy/databases/postgres.py4
-rw-r--r--lib/sqlalchemy/engine/base.py26
-rw-r--r--lib/sqlalchemy/pool.py15
-rw-r--r--lib/sqlalchemy/util.py26
6 files changed, 102 insertions, 77 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 18b236d1c..c4b8af684 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -157,7 +157,6 @@ import datetime, inspect, re, sys
from array import array as _array
from sqlalchemy import exceptions, logging, schema, sql, util
-from sqlalchemy.pool import connection_cache_decorator
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy.sql import functions as sql_functions
from sqlalchemy.sql import compiler
@@ -224,6 +223,10 @@ AUTOCOMMIT_RE = re.compile(
SELECT_RE = re.compile(
r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)',
re.I | re.UNICODE)
+SET_RE = re.compile(
+ r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w',
+ re.I | re.UNICODE)
+
class _NumericType(object):
"""Base for MySQL numeric types."""
@@ -1396,6 +1399,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
self._last_inserted_ids[0] is None):
self._last_inserted_ids = ([self.cursor.lastrowid] +
self._last_inserted_ids[1:])
+ elif (not self.isupdate and not self.should_autocommit and
+ self.statement and SET_RE.match(self.statement)):
+ # This misses if a user forces autocommit on text('SET NAMES'),
+ # which is probably a programming error anyhow.
+ self.connection.info.pop(('mysql', 'charset'), None)
def returns_rows_text(self, statement):
return SELECT_RE.match(statement)
@@ -1544,8 +1552,9 @@ class MySQLDialect(default.DefaultDialect):
def get_default_schema_name(self, connection):
return connection.execute('SELECT DATABASE()').scalar()
- get_default_schema_name = connection_cache_decorator(get_default_schema_name)
-
+ get_default_schema_name = engine_base.connection_memoize(
+ ('dialect', 'default_schema_name'))(get_default_schema_name)
+
def table_names(self, connection, schema):
"""Return a Unicode SHOW TABLES from a given schema."""
@@ -1599,12 +1608,9 @@ class MySQLDialect(default.DefaultDialect):
cached per-Connection.
"""
- try:
- return connection.info['_mysql_server_version_info']
- except KeyError:
- version = connection.info['_mysql_server_version_info'] = \
- self._server_version_info(connection.connection.connection)
- return version
+ return self._server_version_info(connection.connection.connection)
+ server_version_info = engine_base.connection_memoize(
+ ('mysql', 'server_version_info'))(server_version_info)
def _server_version_info(self, dbapi_con):
"""Convert a MySQL-python server_info string into a tuple."""
@@ -1654,7 +1660,7 @@ class MySQLDialect(default.DefaultDialect):
columns = self._describe_table(connection, table, charset)
sql = reflector._describe_to_create(table, columns)
- self._adjust_casing(connection, table, charset)
+ self._adjust_casing(connection, table)
return reflector.reflect(connection, table, sql, charset,
only=include_columns)
@@ -1662,10 +1668,7 @@ class MySQLDialect(default.DefaultDialect):
def _adjust_casing(self, connection, table, charset=None):
"""Adjust Table name to the server case sensitivity, if needed."""
- if charset is None:
- charset = self._detect_charset(connection)
-
- casing = self._detect_casing(connection, charset)
+ casing = self._detect_casing(connection)
# For winxx database hosts. TODO: is this really needed?
if casing == 1 and table.name != table.name.lower():
@@ -1678,8 +1681,8 @@ class MySQLDialect(default.DefaultDialect):
"""Sniff out the character set in use for connection results."""
# Allow user override, won't sniff if force_charset is set.
- if 'force_charset' in connection.info:
- return connection.info['force_charset']
+ if ('mysql', 'force_charset') in connection.info:
+ return connection.info[('mysql', 'force_charset')]
# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
@@ -1714,55 +1717,56 @@ class MySQLDialect(default.DefaultDialect):
"combination of MySQL server and MySQL-python. "
"MySQL-python >= 1.2.2 is recommended. Assuming latin1.")
return 'latin1'
+ _detect_charset = engine_base.connection_memoize(
+ ('mysql', 'charset'))(_detect_charset)
+
- def _detect_casing(self, connection, charset=None):
+ def _detect_casing(self, connection):
"""Sniff out identifier case sensitivity.
Cached per-connection. This value can not change without a server
restart.
- """
+ """
# http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
- try:
- return connection.info['lower_case_table_names']
- except KeyError:
- row = _compat_fetchone(connection.execute(
- "SHOW VARIABLES LIKE 'lower_case_table_names'"),
- charset=charset)
- if not row:
+ charset = self._detect_charset(connection)
+ row = _compat_fetchone(connection.execute(
+ "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+ charset=charset)
+ if not row:
+ cs = 0
+ else:
+ # 4.0.15 returns OFF or ON according to [ticket:489]
+ # 3.23 doesn't, 4.0.27 doesn't..
+ if row[1] == 'OFF':
cs = 0
+ elif row[1] == 'ON':
+ cs = 1
else:
- # 4.0.15 returns OFF or ON according to [ticket:489]
- # 3.23 doesn't, 4.0.27 doesn't..
- if row[1] == 'OFF':
- cs = 0
- elif row[1] == 'ON':
- cs = 1
- else:
- cs = int(row[1])
- row.close()
- connection.info['lower_case_table_names'] = cs
- return cs
+ cs = int(row[1])
+ row.close()
+ return cs
+ _detect_casing = engine_base.connection_memoize(
+ ('mysql', 'lower_case_table_names'))(_detect_casing)
- def _detect_collations(self, connection, charset=None):
+ def _detect_collations(self, connection):
"""Pull the active COLLATIONS list from the server.
Cached per-connection.
"""
- try:
- return connection.info['collations']
- except KeyError:
- collations = {}
- if self.server_version_info(connection) < (4, 1, 0):
- pass
- else:
- rs = connection.execute('SHOW COLLATION')
- for row in _compat_fetchall(rs, charset):
- collations[row[0]] = row[1]
- connection.info['collations'] = collations
- return collations
+ collations = {}
+ if self.server_version_info(connection) < (4, 1, 0):
+ pass
+ else:
+ charset = self._detect_charset(connection)
+ rs = connection.execute('SHOW COLLATION')
+ for row in _compat_fetchall(rs, charset):
+ collations[row[0]] = row[1]
+ return collations
+ _detect_collations = engine_base.connection_memoize(
+ ('mysql', 'collations'))(_detect_collations)
def use_ansiquotes(self, useansi):
self._use_ansiquotes = useansi
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index b23ec06e5..2faeab651 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -12,7 +12,6 @@ from sqlalchemy.engine import default, base
from sqlalchemy.sql import compiler, visitors
from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
from sqlalchemy import types as sqltypes
-from sqlalchemy.pool import connection_cache_decorator
class OracleNumeric(sqltypes.Numeric):
@@ -379,9 +378,10 @@ class OracleDialect(default.DefaultDialect):
else:
return name.encode(self.encoding)
- def get_default_schema_name(self,connection):
+ def get_default_schema_name(self, connection):
return connection.execute('SELECT USER FROM DUAL').scalar()
- get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+ get_default_schema_name = base.connection_memoize(
+ ('dialect', 'default_schema_name'))(get_default_schema_name)
def table_names(self, connection, schema):
# note that table_names() isnt loading DBLINKed or synonym'ed tables
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index abae27eb1..94ad7d2e4 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -26,7 +26,6 @@ from sqlalchemy.engine import base, default
from sqlalchemy.sql import compiler, expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
-from sqlalchemy.pool import connection_cache_decorator
class PGInet(sqltypes.TypeEngine):
@@ -369,7 +368,8 @@ class PGDialect(default.DefaultDialect):
def get_default_schema_name(self, connection):
return connection.scalar("select current_schema()", None)
- get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+ get_default_schema_name = base.connection_memoize(
+ ('dialect', 'default_schema_name'))(get_default_schema_name)
def last_inserted_ids(self):
if self.context.last_inserted_ids is None:
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 3d7e65198..cd662ac92 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -12,7 +12,7 @@ higher-level statement-construction, connection-management, execution
and result contexts.
"""
-import StringIO, sys
+import inspect, StringIO, sys
from sqlalchemy import exceptions, schema, util, types, logging
from sqlalchemy.sql import expression
@@ -1864,3 +1864,27 @@ class DefaultRunner(schema.SchemaVisitor):
return default.arg(self.context)
else:
return default.arg
+
+
+def connection_memoize(key):
+ """Decorator, memoize a function in a connection.info stash.
+
+ Only applicable to functions which take no arguments other than a
+ connection. The memo will be stored in ``connection.info[key]``.
+
+ """
+ def decorate(fn):
+ spec = inspect.getargspec(fn)
+ assert len(spec[0]) == 2
+ assert spec[0][1] == 'connection'
+ assert spec[1:3] == (None, None)
+
+ def decorated(self, connection):
+ try:
+ return connection.info[key]
+ except KeyError:
+ connection.info[key] = val = fn(self, connection)
+ return val
+
+ return util.function_named(decorated, fn.__name__)
+ return decorate
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index e22d1d8d3..94d9127f0 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -58,21 +58,6 @@ def clear_managers():
manager.close()
proxies.clear()
-def connection_cache_decorator(func):
- """apply caching to the return value of a function, using
- the 'info' collection on its given connection."""
-
- name = func.__name__
-
- def do_with_cache(self, connection):
- try:
- return connection.info[name]
- except KeyError:
- value = func(self, connection)
- connection.info[name] = value
- return value
- return do_with_cache
-
class Pool(object):
"""Base class for connection pools.
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 26b6dbe9a..101ef1462 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import inspect, itertools, sets, sys, warnings, weakref
+import inspect, itertools, new, sets, sys, warnings, weakref
import __builtin__
types = __import__('types')
@@ -1055,7 +1055,22 @@ class symbol(object):
return sym
finally:
symbol._lock.release()
-
+
+
+def function_named(fn, name):
+ """Return a function with a given __name__.
+
+ Will assign to __name__ and return the original function if possible on
+ the Python implementation, otherwise a new function will be constructed.
+
+ """
+ try:
+ fn.__name__ = name
+ except TypeError:
+ fn = new.function(fn.func_code, fn.func_globals, name,
+ fn.func_defaults, fn.func_closure)
+ return fn
+
def conditional_cache_decorator(func):
"""apply conditional caching to the return value of a function."""
@@ -1166,8 +1181,5 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None):
func_with_warning.__doc__ = doc
func_with_warning.__dict__.update(func.__dict__)
- try:
- func_with_warning.__name__ = func.__name__
- except TypeError:
- pass
- return func_with_warning
+
+ return function_named(func_with_warning, func.__name__)