diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-10-25 19:44:21 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-10-25 19:44:21 +0000 |
commit | 76e8175971a5fa5abb7f7e1f85ccd3068f04aff5 (patch) | |
tree | a1284c646dbfe8fce1980be126205c7ee80eae3a | |
parent | 25e5157785ee9dd7a3bbb606b1f7642342936d18 (diff) | |
download | sqlalchemy-76e8175971a5fa5abb7f7e1f85ccd3068f04aff5.tar.gz |
- moved _FigureVisitName into visitiors.VisitorType, added Visitor base class to reduce dependencies
- implemented _generative decorator for select/update/insert/delete constructs
- other minutiae
-rw-r--r-- | lib/sqlalchemy/schema.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 200 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 57 | ||||
-rw-r--r-- | test/profiling/zoomark.py | 2 |
6 files changed, 138 insertions, 138 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 567b1d43a..9e8df78ac 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -40,10 +40,9 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'DefaulClause', 'FetchedValue', 'ColumnDefault', 'DDL'] -class SchemaItem(object): +class SchemaItem(visitors.Visitable): """Base class for items that define a database schema.""" - __metaclass__ = expression._FigureVisitName quote = None def _init_items(self, *args): @@ -87,7 +86,7 @@ def _get_table_key(name, schema): else: return schema + "." + name -class _TableSingleton(expression._FigureVisitName): +class _TableSingleton(visitors.VisitableType): """A metaclass used by the ``Table`` object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): @@ -2050,3 +2049,4 @@ def _bind_or_error(schemaitem): 'assign %s to enable implicit execution.') % (item, bindable) raise exc.UnboundExecutionError(msg) return bind + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b4069e6fd..63557f24b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -20,7 +20,7 @@ is otherwise internal to SQLAlchemy. import string, re from sqlalchemy import schema, engine, util, exc -from sqlalchemy.sql import operators, functions, util as sql_util +from sqlalchemy.sql import operators, functions, util as sql_util, visitors from sqlalchemy.sql import expression as sql RESERVED_WORDS = set([ @@ -110,10 +110,9 @@ FUNCTIONS = { } -class _CompileLabel(object): +class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression._Label.""" - __metaclass__ = sql._FigureVisitName __visit_name__ = 'label' __slots__ = 'element', 'name' diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index d937d0507..52291c487 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -29,7 +29,8 @@ import itertools, re from operator import attrgetter from sqlalchemy import util, exc -from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql import operators +from sqlalchemy.sql.visitors import Visitable, cloned_traverse from sqlalchemy import types as sqltypes functions, schema, sql_util = None, None, None @@ -876,7 +877,7 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) + return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__') def _from_objects(*elements, **kwargs): return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements]) @@ -890,7 +891,7 @@ def _labeled(element): def _literal_as_text(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() - elif not isinstance(element, ClauseElement): + elif not isinstance(element, Visitable): return _TextClause(unicode(element)) else: return element @@ -898,7 +899,7 @@ def _literal_as_text(element): def _literal_as_column(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() - elif not isinstance(element, ClauseElement): + elif not isinstance(element, Visitable): return literal_column(str(element)) else: return element @@ -906,7 +907,7 @@ def _literal_as_column(element): def _literal_as_binds(element, name=None, type_=None): if hasattr(element, '__clause_element__'): return element.__clause_element__() - elif not isinstance(element, ClauseElement): + elif not isinstance(element, Visitable): if element is None: return null() else: @@ -917,15 +918,18 @@ def _literal_as_binds(element, name=None, type_=None): def _no_literals(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() - elif not isinstance(element, ClauseElement): - raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + elif not isinstance(element, Visitable): + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function " + "to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) else: return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: - raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + raise exc.InvalidRequestError("Given column '%s', attached to table '%s', " + "failed to locate a corresponding column from table '%s'" + % (column, getattr(column, 'table', None), fromclause.description)) return c def _selectable(element): @@ -934,39 +938,15 @@ def _selectable(element): elif isinstance(element, Selectable): return element else: - raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element) def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) -class _FigureVisitName(type): - def __init__(cls, clsname, bases, dict): - if not '__visit_name__' in cls.__dict__: - m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) - x = m.group(1) - x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) - cls.__visit_name__ = x.lower() - - # set up an optimized visit dispatch function - # for use by the compiler - visit_name = cls.__dict__["__visit_name__"] - if isinstance(visit_name, str): - func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\ - " return visitor.visit_%s(self, **kw)" % visit_name - else: - func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\ - " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)" - env = locals().copy() - exec func_text in env - cls._compiler_dispatch = env['_compiler_dispatch'] - - super(_FigureVisitName, cls).__init__(clsname, bases, dict) - -class ClauseElement(object): +class ClauseElement(Visitable): """Base class for elements of a programmatically constructed SQL expression.""" - __metaclass__ = _FigureVisitName _annotations = {} supports_execution = False @@ -976,6 +956,7 @@ class ClauseElement(object): This method may be used by a generative API. Its also used as part of the "deep" copy afforded by a traversal that combines the _copy_internals() method. + """ c = self.__class__.__new__(self.__class__) c.__dict__ = self.__dict__.copy() @@ -1001,8 +982,8 @@ class ClauseElement(object): should be added to the ``FROM`` list of a query, when this ``ClauseElement`` is placed in the column clause of a ``Select`` statement. + """ - raise NotImplementedError(repr(self)) def _annotate(self, values): @@ -1049,7 +1030,7 @@ class ClauseElement(object): bind.value = kwargs[bind.key] if unique: bind._convert_to_unique() - return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam}) + return cloned_traverse(self, {}, {'bindparam':visit_bindparam}) def compare(self, other): """Compare this ClauseElement to the given ClauseElement. @@ -2637,14 +2618,13 @@ class _ColumnClause(_Immutable, ColumnElement): rules applied regardless of case sensitive settings. the ``literal_column()`` function is usually used to create such a ``_ColumnClause``. + """ - def __init__(self, text, selectable=None, type_=None, is_literal=False): ColumnElement.__init__(self) self.key = self.name = text self.table = selectable self.type = sqltypes.to_instance(type_) - self.__label = None self.is_literal = is_literal @util.memoized_property @@ -2655,23 +2635,25 @@ class _ColumnClause(_Immutable, ColumnElement): def _label(self): if self.is_literal: return None - if not self.__label: - if self.table and self.table.named_with_column: - if getattr(self.table, 'schema', None): - self.__label = self.table.schema + "_" + self.table.name + "_" + self.name - else: - self.__label = self.table.name + "_" + self.name - - if self.__label in self.table.c: - label = self.__label - counter = 1 - while label in self.table.c: - label = self.__label + "_" + str(counter) - counter += 1 - self.__label = label + + elif self.table and self.table.named_with_column: + if getattr(self.table, 'schema', None): + label = self.table.schema + "_" + self.table.name + "_" + self.name else: - self.__label = self.name - return self.__label + label = self.table.name + "_" + self.name + + if label in self.table.c: + # TODO: coverage does not seem to be present for this + _label = label + counter = 1 + while _label in self.table.c: + _label = label + "_" + str(counter) + counter += 1 + label = _label + return label + + else: + return self.name def label(self, name): if name is None: @@ -2723,7 +2705,7 @@ class TableClause(_Immutable, FromClause): def _export_columns(self): raise NotImplementedError() - @property + @util.memoized_property def description(self): return self.name.encode('ascii', 'backslashreplace') @@ -2756,6 +2738,14 @@ class TableClause(_Immutable, FromClause): def _get_from_objects(self, **modifiers): return [self] +@util.decorator +def _generative(fn, *args, **kw): + """Mark a method as generative.""" + + self = args[0]._generate() + fn(self, *args[1:], **kw) + return self + class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" @@ -2784,6 +2774,7 @@ class _SelectBaseMixin(object): """ return _ScalarSelect(self) + @_generative def apply_labels(self): """return a new selectable with the 'use_labels' flag set to True. @@ -2793,9 +2784,7 @@ class _SelectBaseMixin(object): among the individual FROM clauses. """ - s = self._generate() - s.use_labels = True - return s + self.use_labels = True def label(self, name): """return a 'scalar' representation of this selectable, embedded as a subquery @@ -2806,12 +2795,11 @@ class _SelectBaseMixin(object): """ return self.as_scalar().label(name) + @_generative def autocommit(self): """return a new selectable with the 'autocommit' flag set to True.""" - s = self._generate() - s._autocommit = True - return s + self._autocommit = True def _generate(self): s = self.__class__.__new__(self.__class__) @@ -2819,39 +2807,35 @@ class _SelectBaseMixin(object): s._reset_exported() return s + @_generative def limit(self, limit): """return a new selectable with the given LIMIT criterion applied.""" - s = self._generate() - s._limit = limit - return s + self._limit = limit + @_generative def offset(self, offset): """return a new selectable with the given OFFSET criterion applied.""" - s = self._generate() - s._offset = offset - return s + self._offset = offset + @_generative def order_by(self, *clauses): """return a new selectable with the given list of ORDER BY criterion applied. The criterion will be appended to any pre-existing ORDER BY criterion. """ - s = self._generate() - s.append_order_by(*clauses) - return s + self.append_order_by(*clauses) + @_generative def group_by(self, *clauses): """return a new selectable with the given list of GROUP BY criterion applied. The criterion will be appended to any pre-existing GROUP BY criterion. """ - s = self._generate() - s.append_group_by(*clauses) - return s + self.append_group_by(*clauses) def append_order_by(self, *clauses): """Append the given ORDER BY criterion applied to this selectable. @@ -3112,72 +3096,67 @@ class Select(_SelectBaseMixin, FromClause): self._raw_columns + list(self._froms) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + @_generative def column(self, column): """return a new select() construct with the given column expression added to its columns clause.""" - s = self._generate() column = _literal_as_column(column) if isinstance(column, _ScalarSelect): column = column.self_group(against=operators.comma_op) - s._raw_columns = s._raw_columns + [column] - s._froms = s._froms.union(_from_objects(column)) - return s + self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) + @_generative def with_only_columns(self, columns): """return a new select() construct with its columns clause replaced with the given columns.""" - s = self._generate() - s._raw_columns = [ + + self._raw_columns = [ isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c for c in [_literal_as_column(c) for c in columns] ] - return s + @_generative def where(self, whereclause): """return a new select() construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any.""" - s = self._generate() - s.append_whereclause(whereclause) - return s + self.append_whereclause(whereclause) + @_generative def having(self, having): """return a new select() construct with the given expression added to its HAVING clause, joined to the existing clause via AND, if any.""" - s = self._generate() - s.append_having(having) - return s + self.append_having(having) + @_generative def distinct(self): """return a new select() construct which will apply DISTINCT to its columns clause.""" - s = self._generate() - s._distinct = True - return s + self._distinct = True + @_generative def prefix_with(self, clause): """return a new select() construct which will apply the given expression to the start of its columns clause, not using any commas.""" - s = self._generate() clause = _literal_as_text(clause) - s._prefixes = s._prefixes + [clause] - return s + self._prefixes = self._prefixes + [clause] + @_generative def select_from(self, fromclause): """return a new select() construct with the given FROM expression applied to its list of FROM objects.""" - s = self._generate() if _is_literal(fromclause): fromclause = _TextClause(fromclause) - s._froms = s._froms.union([fromclause]) - return s + self._froms = self._froms.union([fromclause]) + @_generative def correlate(self, *fromclauses): """return a new select() construct which will correlate the given FROM clauses to that of an enclosing select(), if a match is found. @@ -3192,13 +3171,11 @@ class Select(_SelectBaseMixin, FromClause): If the fromclause is None, correlation is disabled for the returned select(). """ - s = self._generate() - s._should_correlate = False + self._should_correlate = False if fromclauses == (None,): - s._correlate = set() + self._correlate = set() else: - s._correlate = s._correlate.union(fromclauses) - return s + self._correlate = self._correlate.union(fromclauses) def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" @@ -3416,16 +3393,15 @@ class Insert(_ValuesBase): def _copy_internals(self, clone=_clone): self.parameters = self.parameters.copy() + @_generative def prefix_with(self, clause): """Add a word or expression between INSERT and INTO. Generative. If multiple prefixes are supplied, they will be separated with spaces. """ - gen = self._generate() clause = _literal_as_text(clause) - gen._prefixes = self._prefixes + [clause] - return gen + self._prefixes = self._prefixes + [clause] class Update(_ValuesBase): def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs): @@ -3450,16 +3426,15 @@ class Update(_ValuesBase): self._whereclause = clone(self._whereclause) self.parameters = self.parameters.copy() + @_generative def where(self, whereclause): """return a new update() construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any.""" - s = self._generate() - if s._whereclause is not None: - s._whereclause = and_(s._whereclause, _literal_as_text(whereclause)) + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: - s._whereclause = _literal_as_text(whereclause) - return s + self._whereclause = _literal_as_text(whereclause) class Delete(_UpdateBase): @@ -3479,16 +3454,15 @@ class Delete(_UpdateBase): else: return () + @_generative def where(self, whereclause): """return a new delete() construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any.""" - s = self._generate() - if s._whereclause is not None: - s._whereclause = and_(s._whereclause, _literal_as_text(whereclause)) + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: - s._whereclause = _literal_as_text(whereclause) - return s + self._whereclause = _literal_as_text(whereclause) def _copy_internals(self, clone=_clone): self._whereclause = clone(self._whereclause) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index c7a0f142d..87901b266 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,11 +1,11 @@ from sqlalchemy import types as sqltypes from sqlalchemy.sql.expression import ( - ClauseList, _FigureVisitName, _Function, _literal_as_binds, text + ClauseList, _Function, _literal_as_binds, text ) from sqlalchemy.sql import operators +from sqlalchemy.sql.visitors import VisitableType - -class _GenericMeta(_FigureVisitName): +class _GenericMeta(VisitableType): def __init__(cls, clsname, bases, dict): cls.__visit_name__ = 'function' type.__init__(cls, clsname, bases, dict) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 3f06535de..371c6ec4b 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,10 +1,39 @@ from collections import deque +import re +from sqlalchemy import util + +class VisitableType(type): + def __init__(cls, clsname, bases, dict): + if not '__visit_name__' in cls.__dict__: + m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) + x = m.group(1) + x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) + cls.__visit_name__ = x.lower() + + # set up an optimized visit dispatch function + # for use by the compiler + visit_name = cls.__dict__["__visit_name__"] + if isinstance(visit_name, str): + func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\ + " return visitor.visit_%s(self, **kw)" % visit_name + else: + func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\ + " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)" + + env = locals().copy() + exec func_text in env + cls._compiler_dispatch = env['_compiler_dispatch'] + + super(VisitableType, cls).__init__(clsname, bases, dict) + +class Visitable(object): + __metaclass__ = VisitableType class ClauseVisitor(object): __traverse_options__ = {} def traverse_single(self, obj): - for v in self._iterate_visitors: + for v in self._visitor_iterator: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: return meth(obj) @@ -17,29 +46,33 @@ class ClauseVisitor(object): def traverse(self, obj): """traverse and visit the given expression structure.""" + return traverse(obj, self.__traverse_options__, self._visitor_dict) + + @util.memoized_property + def _visitor_dict(self): visitors = {} for name in dir(self): if name.startswith('visit_'): visitors[name[6:]] = getattr(self, name) - - return traverse(obj, self.__traverse_options__, visitors) - - def _iterate_visitors(self): + return visitors + + @property + def _visitor_iterator(self): """iterate through this visitor and each 'chained' visitor.""" v = self while v: yield v v = getattr(v, '_next', None) - _iterate_visitors = property(_iterate_visitors) def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. the chained visitor will receive all visit events after this one. + """ - tail = list(self._iterate_visitors)[-1] + tail = list(self._visitor_iterator)[-1] tail._next = visitor return self @@ -52,13 +85,7 @@ class CloningVisitor(ClauseVisitor): def traverse(self, obj): """traverse and visit the given expression structure.""" - visitors = {} - - for name in dir(self): - if name.startswith('visit_'): - visitors[name[6:]] = getattr(self, name) - - return cloned_traverse(obj, self.__traverse_options__, visitors) + return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict) class ReplacingCloningVisitor(CloningVisitor): def replace(self, elem): @@ -74,7 +101,7 @@ class ReplacingCloningVisitor(CloningVisitor): """traverse and visit the given expression structure.""" def replace(elem): - for v in self._iterate_visitors: + for v in self._visitor_iterator: e = v.replace(elem) if e: return e diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py index 6cf0771c4..41f6cdc5d 100644 --- a/test/profiling/zoomark.py +++ b/test/profiling/zoomark.py @@ -324,7 +324,7 @@ class ZooMarkTest(TestBase): def test_profile_1_create_tables(self): self.test_baseline_1_create_tables() - @profiling.function_call_count(5726, {'2.4': 3844}) + @profiling.function_call_count(5726, {'2.4': 3650}) def test_profile_1a_populate(self): self.test_baseline_1a_populate() |