diff options
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r-- | lib/sqlalchemy/schema.py | 55 |
1 files changed, 41 insertions, 14 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 24392b3d9..acce555ab 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -23,8 +23,17 @@ import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] +class SchemaMeta(type): + """provides universal constructor arguments for all SchemaItems""" + def __call__(self, *args, **kwargs): + engine = kwargs.pop('engine', None) + obj = type.__call__(self, *args, **kwargs) + obj._engine = engine + return obj + class SchemaItem(object): """base class for items that define a database schema.""" + __metaclass__ = SchemaMeta def _init_items(self, *args): for item in args: if item is not None: @@ -34,7 +43,20 @@ class SchemaItem(object): raise NotImplementedError() def __repr__(self): return "%s()" % self.__class__.__name__ - + +class EngineMixin(object): + """a mixin for SchemaItems that provides an "engine" accessor.""" + def _derived_engine(self): + """subclasses override this method to return an AbstractEngine + bound to a parent item""" + return None + def _get_engine(self): + if self._engine is not None: + return self._engine + else: + return self._derived_engine() + engine = property(_get_engine) + def _get_table_key(engine, name, schema): if schema is not None and schema == engine.get_default_schema_name(): schema = None @@ -43,14 +65,12 @@ def _get_table_key(engine, name, schema): else: return schema + "." + name -class TableSingleton(type): +class TableSingleton(SchemaMeta): """a metaclass used by the Table object to provide singleton behavior.""" def __call__(self, name, engine=None, *args, **kwargs): try: - if not isinstance(engine, SchemaEngine): + if engine is not None and not isinstance(engine, SchemaEngine): args = [engine] + list(args) - engine = None - if engine is None: engine = default_engine name = str(name) # in case of incoming unicode schema = kwargs.get('schema', None) @@ -58,6 +78,10 @@ class TableSingleton(type): redefine = kwargs.pop('redefine', False) mustexist = kwargs.pop('mustexist', False) useexisting = kwargs.pop('useexisting', False) + if not engine: + table = type.__call__(self, name, engine, **kwargs) + table._init_items(*args) + return table key = _get_table_key(engine, name, schema) table = engine.tables[key] if len(args): @@ -440,15 +464,14 @@ class ForeignKey(SchemaItem): self.parent.foreign_key = self self.parent.table.foreign_keys.append(self) -class DefaultGenerator(SchemaItem): +class DefaultGenerator(SchemaItem, EngineMixin): """Base class for column "default" values.""" - def __init__(self, for_update=False, engine=None): + def __init__(self, for_update=False): self.for_update = for_update - self.engine = engine + def _derived_engine(self): + return self.column.table.engine def _set_parent(self, column): self.column = column - if self.engine is None: - self.engine = column.table.engine if self.for_update: self.column.onupdate = self else: @@ -509,7 +532,7 @@ class Sequence(DefaultGenerator): return visitor.visit_sequence(self) -class Index(SchemaItem): +class Index(SchemaItem, EngineMixin): """Represents an index of columns from a database table """ def __init__(self, name, *columns, **kw): @@ -530,7 +553,8 @@ class Index(SchemaItem): self.unique = kw.pop('unique', False) self._init_items(*columns) - engine = property(lambda s:s.table.engine) + def _derived_engine(self): + return self.table.engine def _init_items(self, *args): for column in args: self.append_column(column) @@ -570,18 +594,21 @@ class Index(SchemaItem): for c in self.columns]), (self.unique and ', unique=True') or '') -class SchemaEngine(object): +class SchemaEngine(sql.AbstractEngine): """a factory object used to create implementations for schema objects. This object is the ultimate base class for the engine.SQLEngine class.""" def __init__(self): # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} - def reflecttable(self, table): """given a table, will query the database and populate its Column and ForeignKey objects.""" raise NotImplementedError() + def schemagenerator(self, **params): + raise NotImplementedError() + def schemadropper(self, **params): + raise NotImplementedError() class SchemaVisitor(sql.ClauseVisitor): """defines the visiting for SchemaItem objects""" |