summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r--lib/sqlalchemy/schema.py61
1 files changed, 31 insertions, 30 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 52324e63e..78c31e9ac 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -42,7 +42,11 @@ class SchemaItem(object):
"""Associate with this SchemaItem's parent object."""
raise NotImplementedError()
-
+
+ def get_children(self, **kwargs):
+ """used to allow SchemaVisitor access"""
+ return []
+
def __repr__(self):
return "%s()" % self.__class__.__name__
@@ -322,11 +326,14 @@ class Table(SchemaItem, sql.TableClause):
metadata.tables[_get_table_key(self.name, self.schema)] = self
self._metadata = metadata
- def accept_schema_visitor(self, visitor, traverse=True):
- if traverse:
- for c in self.columns:
- c.accept_schema_visitor(visitor, True)
- return visitor.visit_table(self)
+ def get_children(self, column_collections=True, schema_visitor=False, **kwargs):
+ if not schema_visitor:
+ return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs)
+ else:
+ if column_collections:
+ return [c for c in self.columns]
+ else:
+ return []
def exists(self, connectable=None):
"""Return True if this table exists."""
@@ -604,20 +611,12 @@ class Column(SchemaItem, sql._ColumnClause):
return self.__originating_column._get_case_sensitive()
case_sensitive = property(_case_sens, lambda s,v:None)
- def accept_schema_visitor(self, visitor, traverse=True):
- """Traverse the given visitor to this ``Column``'s default and foreign key object,
- then call `visit_column` on the visitor."""
-
- if traverse:
- if self.default is not None:
- self.default.accept_schema_visitor(visitor, traverse=True)
- if self.onupdate is not None:
- self.onupdate.accept_schema_visitor(visitor, traverse=True)
- for f in self.foreign_keys:
- f.accept_schema_visitor(visitor, traverse=True)
- for constraint in self.constraints:
- constraint.accept_schema_visitor(visitor, traverse=True)
- visitor.visit_column(self)
+ def get_children(self, schema_visitor=False, **kwargs):
+ if schema_visitor:
+ return [x for x in (self.default, self.onupdate) if x is not None] + \
+ list(self.foreign_keys) + list(self.constraints)
+ else:
+ return sql._ColumnClause.get_children(self, **kwargs)
class ForeignKey(SchemaItem):
@@ -715,7 +714,7 @@ class ForeignKey(SchemaItem):
column = property(lambda s: s._init_column())
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
"""Call the `visit_foreign_key` method on the given visitor."""
visitor.visit_foreign_key(self)
@@ -771,7 +770,7 @@ class PassiveDefault(DefaultGenerator):
super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
return visitor.visit_passive_default(self)
def __repr__(self):
@@ -788,7 +787,7 @@ class ColumnDefault(DefaultGenerator):
super(ColumnDefault, self).__init__(**kwargs)
self.arg = arg
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
"""Call the visit_column_default method on the given visitor."""
if self.for_update:
@@ -828,7 +827,7 @@ class Sequence(DefaultGenerator):
def drop(self):
self.get_engine().drop(self)
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
"""Call the visit_seauence method on the given visitor."""
return visitor.visit_sequence(self)
@@ -871,7 +870,7 @@ class CheckConstraint(Constraint):
super(CheckConstraint, self).__init__(name)
self.sqltext = sqltext
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
if isinstance(self.parent, Table):
visitor.visit_check_constraint(self)
else:
@@ -904,7 +903,7 @@ class ForeignKeyConstraint(Constraint):
for (c, r) in zip(self.__colnames, self.__refcolnames):
self.append_element(c,r)
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
visitor.visit_foreign_key_constraint(self)
def append_element(self, col, refcol):
@@ -930,7 +929,7 @@ class PrimaryKeyConstraint(Constraint):
for c in self.__colnames:
self.append_column(table.c[c])
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
visitor.visit_primary_key_constraint(self)
def add(self, col):
@@ -964,7 +963,7 @@ class UniqueConstraint(Constraint):
def append_column(self, col):
self.columns.add(col)
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
visitor.visit_unique_constraint(self)
def copy(self):
@@ -1042,7 +1041,7 @@ class Index(SchemaItem):
else:
self.get_engine().drop(self)
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
visitor.visit_index(self)
def __str__(self):
@@ -1118,7 +1117,7 @@ class MetaData(SchemaItem):
connectable = self.get_engine()
connectable.drop(self, checkfirst=checkfirst, tables=tables)
- def accept_schema_visitor(self, visitor, traverse=True):
+ def accept_visitor(self, visitor):
visitor.visit_metadata(self)
def _derived_metadata(self):
@@ -1190,6 +1189,8 @@ class DynamicMetaData(MetaData):
class SchemaVisitor(sql.ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
+ __traverse_options__ = {'schema_visitor':True}
+
def visit_schema(self, schema):
"""Visit a generic ``SchemaItem``."""
pass