diff options
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r-- | lib/sqlalchemy/schema.py | 61 |
1 files changed, 31 insertions, 30 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 52324e63e..78c31e9ac 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -42,7 +42,11 @@ class SchemaItem(object): """Associate with this SchemaItem's parent object.""" raise NotImplementedError() - + + def get_children(self, **kwargs): + """used to allow SchemaVisitor access""" + return [] + def __repr__(self): return "%s()" % self.__class__.__name__ @@ -322,11 +326,14 @@ class Table(SchemaItem, sql.TableClause): metadata.tables[_get_table_key(self.name, self.schema)] = self self._metadata = metadata - def accept_schema_visitor(self, visitor, traverse=True): - if traverse: - for c in self.columns: - c.accept_schema_visitor(visitor, True) - return visitor.visit_table(self) + def get_children(self, column_collections=True, schema_visitor=False, **kwargs): + if not schema_visitor: + return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs) + else: + if column_collections: + return [c for c in self.columns] + else: + return [] def exists(self, connectable=None): """Return True if this table exists.""" @@ -604,20 +611,12 @@ class Column(SchemaItem, sql._ColumnClause): return self.__originating_column._get_case_sensitive() case_sensitive = property(_case_sens, lambda s,v:None) - def accept_schema_visitor(self, visitor, traverse=True): - """Traverse the given visitor to this ``Column``'s default and foreign key object, - then call `visit_column` on the visitor.""" - - if traverse: - if self.default is not None: - self.default.accept_schema_visitor(visitor, traverse=True) - if self.onupdate is not None: - self.onupdate.accept_schema_visitor(visitor, traverse=True) - for f in self.foreign_keys: - f.accept_schema_visitor(visitor, traverse=True) - for constraint in self.constraints: - constraint.accept_schema_visitor(visitor, traverse=True) - visitor.visit_column(self) + def get_children(self, schema_visitor=False, **kwargs): + if schema_visitor: + return [x for x in (self.default, self.onupdate) if x is not None] + \ + list(self.foreign_keys) + list(self.constraints) + else: + return sql._ColumnClause.get_children(self, **kwargs) class ForeignKey(SchemaItem): @@ -715,7 +714,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the `visit_foreign_key` method on the given visitor.""" visitor.visit_foreign_key(self) @@ -771,7 +770,7 @@ class PassiveDefault(DefaultGenerator): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): return visitor.visit_passive_default(self) def __repr__(self): @@ -788,7 +787,7 @@ class ColumnDefault(DefaultGenerator): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the visit_column_default method on the given visitor.""" if self.for_update: @@ -828,7 +827,7 @@ class Sequence(DefaultGenerator): def drop(self): self.get_engine().drop(self) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) @@ -871,7 +870,7 @@ class CheckConstraint(Constraint): super(CheckConstraint, self).__init__(name) self.sqltext = sqltext - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): if isinstance(self.parent, Table): visitor.visit_check_constraint(self) else: @@ -904,7 +903,7 @@ class ForeignKeyConstraint(Constraint): for (c, r) in zip(self.__colnames, self.__refcolnames): self.append_element(c,r) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_foreign_key_constraint(self) def append_element(self, col, refcol): @@ -930,7 +929,7 @@ class PrimaryKeyConstraint(Constraint): for c in self.__colnames: self.append_column(table.c[c]) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_primary_key_constraint(self) def add(self, col): @@ -964,7 +963,7 @@ class UniqueConstraint(Constraint): def append_column(self, col): self.columns.add(col) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_unique_constraint(self) def copy(self): @@ -1042,7 +1041,7 @@ class Index(SchemaItem): else: self.get_engine().drop(self) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_index(self) def __str__(self): @@ -1118,7 +1117,7 @@ class MetaData(SchemaItem): connectable = self.get_engine() connectable.drop(self, checkfirst=checkfirst, tables=tables) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_metadata(self) def _derived_metadata(self): @@ -1190,6 +1189,8 @@ class DynamicMetaData(MetaData): class SchemaVisitor(sql.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" + __traverse_options__ = {'schema_visitor':True} + def visit_schema(self, schema): """Visit a generic ``SchemaItem``.""" pass |