diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-03-11 20:52:02 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-03-11 20:52:02 +0000 |
commit | 6a3c374b955299f0065356ef1de6cc0920d5382e (patch) | |
tree | 1ec2c2fddcc2d3c8b8f350fb42f86a84918c6fe1 /lib/sqlalchemy/schema.py | |
parent | 320cb9b75f763355ed798c80d245998ce57e21cc (diff) | |
download | sqlalchemy-6a3c374b955299f0065356ef1de6cc0920d5382e.tar.gz |
- for hackers, refactored the "visitor" system of ClauseElement and
SchemaItem so that the traversal of items is controlled by the
ClauseVisitor itself, using the method visitor.traverse(item).
accept_visitor() methods can still be called directly but will
not do any traversal of child items. ClauseElement/SchemaItem now
have a configurable get_children() method to return the collection
of child elements for each parent object. This allows the full
traversal of items to be clear and unambiguous (as well as loggable),
with an easy method of limiting a traversal (just pass flags which
are picked up by appropriate get_children() methods). [ticket:501]
- accept_schema_visitor() methods removed, replaced with
get_children(schema_visitor=True)
- various docstring/changelog cleanup/reformatting
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r-- | lib/sqlalchemy/schema.py | 61 |
1 files changed, 31 insertions, 30 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 52324e63e..78c31e9ac 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -42,7 +42,11 @@ class SchemaItem(object): """Associate with this SchemaItem's parent object.""" raise NotImplementedError() - + + def get_children(self, **kwargs): + """used to allow SchemaVisitor access""" + return [] + def __repr__(self): return "%s()" % self.__class__.__name__ @@ -322,11 +326,14 @@ class Table(SchemaItem, sql.TableClause): metadata.tables[_get_table_key(self.name, self.schema)] = self self._metadata = metadata - def accept_schema_visitor(self, visitor, traverse=True): - if traverse: - for c in self.columns: - c.accept_schema_visitor(visitor, True) - return visitor.visit_table(self) + def get_children(self, column_collections=True, schema_visitor=False, **kwargs): + if not schema_visitor: + return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs) + else: + if column_collections: + return [c for c in self.columns] + else: + return [] def exists(self, connectable=None): """Return True if this table exists.""" @@ -604,20 +611,12 @@ class Column(SchemaItem, sql._ColumnClause): return self.__originating_column._get_case_sensitive() case_sensitive = property(_case_sens, lambda s,v:None) - def accept_schema_visitor(self, visitor, traverse=True): - """Traverse the given visitor to this ``Column``'s default and foreign key object, - then call `visit_column` on the visitor.""" - - if traverse: - if self.default is not None: - self.default.accept_schema_visitor(visitor, traverse=True) - if self.onupdate is not None: - self.onupdate.accept_schema_visitor(visitor, traverse=True) - for f in self.foreign_keys: - f.accept_schema_visitor(visitor, traverse=True) - for constraint in self.constraints: - constraint.accept_schema_visitor(visitor, traverse=True) - visitor.visit_column(self) + def get_children(self, schema_visitor=False, **kwargs): + if schema_visitor: + return [x for x in (self.default, self.onupdate) if x is not None] + \ + list(self.foreign_keys) + list(self.constraints) + else: + return sql._ColumnClause.get_children(self, **kwargs) class ForeignKey(SchemaItem): @@ -715,7 +714,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the `visit_foreign_key` method on the given visitor.""" visitor.visit_foreign_key(self) @@ -771,7 +770,7 @@ class PassiveDefault(DefaultGenerator): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): return visitor.visit_passive_default(self) def __repr__(self): @@ -788,7 +787,7 @@ class ColumnDefault(DefaultGenerator): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the visit_column_default method on the given visitor.""" if self.for_update: @@ -828,7 +827,7 @@ class Sequence(DefaultGenerator): def drop(self): self.get_engine().drop(self) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) @@ -871,7 +870,7 @@ class CheckConstraint(Constraint): super(CheckConstraint, self).__init__(name) self.sqltext = sqltext - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): if isinstance(self.parent, Table): visitor.visit_check_constraint(self) else: @@ -904,7 +903,7 @@ class ForeignKeyConstraint(Constraint): for (c, r) in zip(self.__colnames, self.__refcolnames): self.append_element(c,r) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_foreign_key_constraint(self) def append_element(self, col, refcol): @@ -930,7 +929,7 @@ class PrimaryKeyConstraint(Constraint): for c in self.__colnames: self.append_column(table.c[c]) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_primary_key_constraint(self) def add(self, col): @@ -964,7 +963,7 @@ class UniqueConstraint(Constraint): def append_column(self, col): self.columns.add(col) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_unique_constraint(self) def copy(self): @@ -1042,7 +1041,7 @@ class Index(SchemaItem): else: self.get_engine().drop(self) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_index(self) def __str__(self): @@ -1118,7 +1117,7 @@ class MetaData(SchemaItem): connectable = self.get_engine() connectable.drop(self, checkfirst=checkfirst, tables=tables) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_metadata(self) def _derived_metadata(self): @@ -1190,6 +1189,8 @@ class DynamicMetaData(MetaData): class SchemaVisitor(sql.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" + __traverse_options__ = {'schema_visitor':True} + def visit_schema(self, schema): """Visit a generic ``SchemaItem``.""" pass |