summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-10-25 19:44:21 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-10-25 19:44:21 +0000
commit76e8175971a5fa5abb7f7e1f85ccd3068f04aff5 (patch)
treea1284c646dbfe8fce1980be126205c7ee80eae3a
parent25e5157785ee9dd7a3bbb606b1f7642342936d18 (diff)
downloadsqlalchemy-76e8175971a5fa5abb7f7e1f85ccd3068f04aff5.tar.gz
- moved _FigureVisitName into visitiors.VisitorType, added Visitor base class to reduce dependencies
- implemented _generative decorator for select/update/insert/delete constructs - other minutiae
-rw-r--r--lib/sqlalchemy/schema.py6
-rw-r--r--lib/sqlalchemy/sql/compiler.py5
-rw-r--r--lib/sqlalchemy/sql/expression.py200
-rw-r--r--lib/sqlalchemy/sql/functions.py6
-rw-r--r--lib/sqlalchemy/sql/visitors.py57
-rw-r--r--test/profiling/zoomark.py2
6 files changed, 138 insertions, 138 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 567b1d43a..9e8df78ac 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -40,10 +40,9 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
'DefaulClause', 'FetchedValue', 'ColumnDefault', 'DDL']
-class SchemaItem(object):
+class SchemaItem(visitors.Visitable):
"""Base class for items that define a database schema."""
- __metaclass__ = expression._FigureVisitName
quote = None
def _init_items(self, *args):
@@ -87,7 +86,7 @@ def _get_table_key(name, schema):
else:
return schema + "." + name
-class _TableSingleton(expression._FigureVisitName):
+class _TableSingleton(visitors.VisitableType):
"""A metaclass used by the ``Table`` object to provide singleton behavior."""
def __call__(self, name, metadata, *args, **kwargs):
@@ -2050,3 +2049,4 @@ def _bind_or_error(schemaitem):
'assign %s to enable implicit execution.') % (item, bindable)
raise exc.UnboundExecutionError(msg)
return bind
+
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b4069e6fd..63557f24b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -20,7 +20,7 @@ is otherwise internal to SQLAlchemy.
import string, re
from sqlalchemy import schema, engine, util, exc
-from sqlalchemy.sql import operators, functions, util as sql_util
+from sqlalchemy.sql import operators, functions, util as sql_util, visitors
from sqlalchemy.sql import expression as sql
RESERVED_WORDS = set([
@@ -110,10 +110,9 @@ FUNCTIONS = {
}
-class _CompileLabel(object):
+class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression._Label."""
- __metaclass__ = sql._FigureVisitName
__visit_name__ = 'label'
__slots__ = 'element', 'name'
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d937d0507..52291c487 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -29,7 +29,8 @@ import itertools, re
from operator import attrgetter
from sqlalchemy import util, exc
-from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import operators
+from sqlalchemy.sql.visitors import Visitable, cloned_traverse
from sqlalchemy import types as sqltypes
functions, schema, sql_util = None, None, None
@@ -876,7 +877,7 @@ def _compound_select(keyword, *selects, **kwargs):
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, ClauseElement)
+ return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
def _from_objects(*elements, **kwargs):
return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
@@ -890,7 +891,7 @@ def _labeled(element):
def _literal_as_text(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
return _TextClause(unicode(element))
else:
return element
@@ -898,7 +899,7 @@ def _literal_as_text(element):
def _literal_as_column(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
return literal_column(str(element))
else:
return element
@@ -906,7 +907,7 @@ def _literal_as_column(element):
def _literal_as_binds(element, name=None, type_=None):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
if element is None:
return null()
else:
@@ -917,15 +918,18 @@ def _literal_as_binds(element, name=None, type_=None):
def _no_literals(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
- raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+ elif not isinstance(element, Visitable):
+ raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function "
+ "to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
- raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+ raise exc.InvalidRequestError("Given column '%s', attached to table '%s', "
+ "failed to locate a corresponding column from table '%s'"
+ % (column, getattr(column, 'table', None), fromclause.description))
return c
def _selectable(element):
@@ -934,39 +938,15 @@ def _selectable(element):
elif isinstance(element, Selectable):
return element
else:
- raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+ raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element)
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
-class _FigureVisitName(type):
- def __init__(cls, clsname, bases, dict):
- if not '__visit_name__' in cls.__dict__:
- m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
- x = m.group(1)
- x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
- cls.__visit_name__ = x.lower()
-
- # set up an optimized visit dispatch function
- # for use by the compiler
- visit_name = cls.__dict__["__visit_name__"]
- if isinstance(visit_name, str):
- func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
- " return visitor.visit_%s(self, **kw)" % visit_name
- else:
- func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
- " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
- env = locals().copy()
- exec func_text in env
- cls._compiler_dispatch = env['_compiler_dispatch']
-
- super(_FigureVisitName, cls).__init__(clsname, bases, dict)
-
-class ClauseElement(object):
+class ClauseElement(Visitable):
"""Base class for elements of a programmatically constructed SQL expression."""
- __metaclass__ = _FigureVisitName
_annotations = {}
supports_execution = False
@@ -976,6 +956,7 @@ class ClauseElement(object):
This method may be used by a generative API. Its also used as
part of the "deep" copy afforded by a traversal that combines
the _copy_internals() method.
+
"""
c = self.__class__.__new__(self.__class__)
c.__dict__ = self.__dict__.copy()
@@ -1001,8 +982,8 @@ class ClauseElement(object):
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
``Select`` statement.
+
"""
-
raise NotImplementedError(repr(self))
def _annotate(self, values):
@@ -1049,7 +1030,7 @@ class ClauseElement(object):
bind.value = kwargs[bind.key]
if unique:
bind._convert_to_unique()
- return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
+ return cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
"""Compare this ClauseElement to the given ClauseElement.
@@ -2637,14 +2618,13 @@ class _ColumnClause(_Immutable, ColumnElement):
rules applied regardless of case sensitive settings. the
``literal_column()`` function is usually used to create such a
``_ColumnClause``.
+
"""
-
def __init__(self, text, selectable=None, type_=None, is_literal=False):
ColumnElement.__init__(self)
self.key = self.name = text
self.table = selectable
self.type = sqltypes.to_instance(type_)
- self.__label = None
self.is_literal = is_literal
@util.memoized_property
@@ -2655,23 +2635,25 @@ class _ColumnClause(_Immutable, ColumnElement):
def _label(self):
if self.is_literal:
return None
- if not self.__label:
- if self.table and self.table.named_with_column:
- if getattr(self.table, 'schema', None):
- self.__label = self.table.schema + "_" + self.table.name + "_" + self.name
- else:
- self.__label = self.table.name + "_" + self.name
-
- if self.__label in self.table.c:
- label = self.__label
- counter = 1
- while label in self.table.c:
- label = self.__label + "_" + str(counter)
- counter += 1
- self.__label = label
+
+ elif self.table and self.table.named_with_column:
+ if getattr(self.table, 'schema', None):
+ label = self.table.schema + "_" + self.table.name + "_" + self.name
else:
- self.__label = self.name
- return self.__label
+ label = self.table.name + "_" + self.name
+
+ if label in self.table.c:
+ # TODO: coverage does not seem to be present for this
+ _label = label
+ counter = 1
+ while _label in self.table.c:
+ _label = label + "_" + str(counter)
+ counter += 1
+ label = _label
+ return label
+
+ else:
+ return self.name
def label(self, name):
if name is None:
@@ -2723,7 +2705,7 @@ class TableClause(_Immutable, FromClause):
def _export_columns(self):
raise NotImplementedError()
- @property
+ @util.memoized_property
def description(self):
return self.name.encode('ascii', 'backslashreplace')
@@ -2756,6 +2738,14 @@ class TableClause(_Immutable, FromClause):
def _get_from_objects(self, **modifiers):
return [self]
+@util.decorator
+def _generative(fn, *args, **kw):
+ """Mark a method as generative."""
+
+ self = args[0]._generate()
+ fn(self, *args[1:], **kw)
+ return self
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
@@ -2784,6 +2774,7 @@ class _SelectBaseMixin(object):
"""
return _ScalarSelect(self)
+ @_generative
def apply_labels(self):
"""return a new selectable with the 'use_labels' flag set to True.
@@ -2793,9 +2784,7 @@ class _SelectBaseMixin(object):
among the individual FROM clauses.
"""
- s = self._generate()
- s.use_labels = True
- return s
+ self.use_labels = True
def label(self, name):
"""return a 'scalar' representation of this selectable, embedded as a subquery
@@ -2806,12 +2795,11 @@ class _SelectBaseMixin(object):
"""
return self.as_scalar().label(name)
+ @_generative
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to True."""
- s = self._generate()
- s._autocommit = True
- return s
+ self._autocommit = True
def _generate(self):
s = self.__class__.__new__(self.__class__)
@@ -2819,39 +2807,35 @@ class _SelectBaseMixin(object):
s._reset_exported()
return s
+ @_generative
def limit(self, limit):
"""return a new selectable with the given LIMIT criterion applied."""
- s = self._generate()
- s._limit = limit
- return s
+ self._limit = limit
+ @_generative
def offset(self, offset):
"""return a new selectable with the given OFFSET criterion applied."""
- s = self._generate()
- s._offset = offset
- return s
+ self._offset = offset
+ @_generative
def order_by(self, *clauses):
"""return a new selectable with the given list of ORDER BY criterion applied.
The criterion will be appended to any pre-existing ORDER BY criterion.
"""
- s = self._generate()
- s.append_order_by(*clauses)
- return s
+ self.append_order_by(*clauses)
+ @_generative
def group_by(self, *clauses):
"""return a new selectable with the given list of GROUP BY criterion applied.
The criterion will be appended to any pre-existing GROUP BY criterion.
"""
- s = self._generate()
- s.append_group_by(*clauses)
- return s
+ self.append_group_by(*clauses)
def append_order_by(self, *clauses):
"""Append the given ORDER BY criterion applied to this selectable.
@@ -3112,72 +3096,67 @@ class Select(_SelectBaseMixin, FromClause):
self._raw_columns + list(self._froms) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+ @_generative
def column(self, column):
"""return a new select() construct with the given column expression added to its columns clause."""
- s = self._generate()
column = _literal_as_column(column)
if isinstance(column, _ScalarSelect):
column = column.self_group(against=operators.comma_op)
- s._raw_columns = s._raw_columns + [column]
- s._froms = s._froms.union(_from_objects(column))
- return s
+ self._raw_columns = self._raw_columns + [column]
+ self._froms = self._froms.union(_from_objects(column))
+ @_generative
def with_only_columns(self, columns):
"""return a new select() construct with its columns clause replaced with the given columns."""
- s = self._generate()
- s._raw_columns = [
+
+ self._raw_columns = [
isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
for c in
[_literal_as_column(c) for c in columns]
]
- return s
+ @_generative
def where(self, whereclause):
"""return a new select() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- s.append_whereclause(whereclause)
- return s
+ self.append_whereclause(whereclause)
+ @_generative
def having(self, having):
"""return a new select() construct with the given expression added to its HAVING clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- s.append_having(having)
- return s
+ self.append_having(having)
+ @_generative
def distinct(self):
"""return a new select() construct which will apply DISTINCT to its columns clause."""
- s = self._generate()
- s._distinct = True
- return s
+ self._distinct = True
+ @_generative
def prefix_with(self, clause):
"""return a new select() construct which will apply the given expression to the start of its
columns clause, not using any commas."""
- s = self._generate()
clause = _literal_as_text(clause)
- s._prefixes = s._prefixes + [clause]
- return s
+ self._prefixes = self._prefixes + [clause]
+ @_generative
def select_from(self, fromclause):
"""return a new select() construct with the given FROM expression applied to its list of
FROM objects."""
- s = self._generate()
if _is_literal(fromclause):
fromclause = _TextClause(fromclause)
- s._froms = s._froms.union([fromclause])
- return s
+ self._froms = self._froms.union([fromclause])
+ @_generative
def correlate(self, *fromclauses):
"""return a new select() construct which will correlate the given FROM clauses to that
of an enclosing select(), if a match is found.
@@ -3192,13 +3171,11 @@ class Select(_SelectBaseMixin, FromClause):
If the fromclause is None, correlation is disabled for the returned select().
"""
- s = self._generate()
- s._should_correlate = False
+ self._should_correlate = False
if fromclauses == (None,):
- s._correlate = set()
+ self._correlate = set()
else:
- s._correlate = s._correlate.union(fromclauses)
- return s
+ self._correlate = self._correlate.union(fromclauses)
def append_correlation(self, fromclause):
"""append the given correlation expression to this select() construct."""
@@ -3416,16 +3393,15 @@ class Insert(_ValuesBase):
def _copy_internals(self, clone=_clone):
self.parameters = self.parameters.copy()
+ @_generative
def prefix_with(self, clause):
"""Add a word or expression between INSERT and INTO. Generative.
If multiple prefixes are supplied, they will be separated with
spaces.
"""
- gen = self._generate()
clause = _literal_as_text(clause)
- gen._prefixes = self._prefixes + [clause]
- return gen
+ self._prefixes = self._prefixes + [clause]
class Update(_ValuesBase):
def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
@@ -3450,16 +3426,15 @@ class Update(_ValuesBase):
self._whereclause = clone(self._whereclause)
self.parameters = self.parameters.copy()
+ @_generative
def where(self, whereclause):
"""return a new update() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- if s._whereclause is not None:
- s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- s._whereclause = _literal_as_text(whereclause)
- return s
+ self._whereclause = _literal_as_text(whereclause)
class Delete(_UpdateBase):
@@ -3479,16 +3454,15 @@ class Delete(_UpdateBase):
else:
return ()
+ @_generative
def where(self, whereclause):
"""return a new delete() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- if s._whereclause is not None:
- s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- s._whereclause = _literal_as_text(whereclause)
- return s
+ self._whereclause = _literal_as_text(whereclause)
def _copy_internals(self, clone=_clone):
self._whereclause = clone(self._whereclause)
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index c7a0f142d..87901b266 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -1,11 +1,11 @@
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.expression import (
- ClauseList, _FigureVisitName, _Function, _literal_as_binds, text
+ ClauseList, _Function, _literal_as_binds, text
)
from sqlalchemy.sql import operators
+from sqlalchemy.sql.visitors import VisitableType
-
-class _GenericMeta(_FigureVisitName):
+class _GenericMeta(VisitableType):
def __init__(cls, clsname, bases, dict):
cls.__visit_name__ = 'function'
type.__init__(cls, clsname, bases, dict)
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 3f06535de..371c6ec4b 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -1,10 +1,39 @@
from collections import deque
+import re
+from sqlalchemy import util
+
+class VisitableType(type):
+ def __init__(cls, clsname, bases, dict):
+ if not '__visit_name__' in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+
+ # set up an optimized visit dispatch function
+ # for use by the compiler
+ visit_name = cls.__dict__["__visit_name__"]
+ if isinstance(visit_name, str):
+ func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+ " return visitor.visit_%s(self, **kw)" % visit_name
+ else:
+ func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+ " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
+
+ env = locals().copy()
+ exec func_text in env
+ cls._compiler_dispatch = env['_compiler_dispatch']
+
+ super(VisitableType, cls).__init__(clsname, bases, dict)
+
+class Visitable(object):
+ __metaclass__ = VisitableType
class ClauseVisitor(object):
__traverse_options__ = {}
def traverse_single(self, obj):
- for v in self._iterate_visitors:
+ for v in self._visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
return meth(obj)
@@ -17,29 +46,33 @@ class ClauseVisitor(object):
def traverse(self, obj):
"""traverse and visit the given expression structure."""
+ return traverse(obj, self.__traverse_options__, self._visitor_dict)
+
+ @util.memoized_property
+ def _visitor_dict(self):
visitors = {}
for name in dir(self):
if name.startswith('visit_'):
visitors[name[6:]] = getattr(self, name)
-
- return traverse(obj, self.__traverse_options__, visitors)
-
- def _iterate_visitors(self):
+ return visitors
+
+ @property
+ def _visitor_iterator(self):
"""iterate through this visitor and each 'chained' visitor."""
v = self
while v:
yield v
v = getattr(v, '_next', None)
- _iterate_visitors = property(_iterate_visitors)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
the chained visitor will receive all visit events after this one.
+
"""
- tail = list(self._iterate_visitors)[-1]
+ tail = list(self._visitor_iterator)[-1]
tail._next = visitor
return self
@@ -52,13 +85,7 @@ class CloningVisitor(ClauseVisitor):
def traverse(self, obj):
"""traverse and visit the given expression structure."""
- visitors = {}
-
- for name in dir(self):
- if name.startswith('visit_'):
- visitors[name[6:]] = getattr(self, name)
-
- return cloned_traverse(obj, self.__traverse_options__, visitors)
+ return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
def replace(self, elem):
@@ -74,7 +101,7 @@ class ReplacingCloningVisitor(CloningVisitor):
"""traverse and visit the given expression structure."""
def replace(elem):
- for v in self._iterate_visitors:
+ for v in self._visitor_iterator:
e = v.replace(elem)
if e:
return e
diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py
index 6cf0771c4..41f6cdc5d 100644
--- a/test/profiling/zoomark.py
+++ b/test/profiling/zoomark.py
@@ -324,7 +324,7 @@ class ZooMarkTest(TestBase):
def test_profile_1_create_tables(self):
self.test_baseline_1_create_tables()
- @profiling.function_call_count(5726, {'2.4': 3844})
+ @profiling.function_call_count(5726, {'2.4': 3650})
def test_profile_1a_populate(self):
self.test_baseline_1a_populate()