diff options
Diffstat (limited to 'lib/sqlalchemy/schema.py')
-rw-r--r-- | lib/sqlalchemy/schema.py | 417 |
1 files changed, 178 insertions, 239 deletions
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 05753e424..5728d7c37 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,7 +19,7 @@ import sqlalchemy import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', - 'PrimaryKeyConstraint', + 'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): @@ -99,36 +99,33 @@ def _get_table_key(name, schema): class TableSingleton(type): """a metaclass used by the Table object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): + if isinstance(metadata, sql.Engine): + # backwards compatibility - get a BoundSchema associated with the engine + engine = metadata + if not hasattr(engine, '_legacy_metadata'): + engine._legacy_metadata = BoundMetaData(engine) + metadata = engine._legacy_metadata + elif metadata is not None and not isinstance(metadata, MetaData): + # they left MetaData out, so assume its another SchemaItem, add it to *args + args = list(args) + args.insert(0, metadata) + metadata = None + + if metadata is None: + metadata = default_metadata + + name = str(name) # in case of incoming unicode + schema = kwargs.get('schema', None) + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', False) + mustexist = kwargs.pop('mustexist', False) + useexisting = kwargs.pop('useexisting', False) + key = _get_table_key(name, schema) try: - if isinstance(metadata, sql.Engine): - # backwards compatibility - get a BoundSchema associated with the engine - engine = metadata - if not hasattr(engine, '_legacy_metadata'): - engine._legacy_metadata = BoundMetaData(engine) - metadata = engine._legacy_metadata - elif metadata is not None and not isinstance(metadata, MetaData): - # they left MetaData out, so assume its another SchemaItem, add it to *args - args = list(args) - args.insert(0, metadata) - metadata = None - - if metadata is None: - metadata = default_metadata - - name = str(name) # in case of incoming unicode - schema = kwargs.get('schema', None) - autoload = kwargs.pop('autoload', False) - autoload_with = kwargs.pop('autoload_with', False) - redefine = kwargs.pop('redefine', False) - mustexist = kwargs.pop('mustexist', False) - useexisting = kwargs.pop('useexisting', False) - key = _get_table_key(name, schema) table = metadata.tables[key] if len(args): - if redefine: - table._reload_values(*args) - elif not useexisting: - raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) + if not useexisting: + raise exceptions.ArgumentError("Table '%s.%s' is already defined for this MetaData instance." % (schema, name)) return table except KeyError: if mustexist: @@ -145,7 +142,7 @@ class TableSingleton(type): else: metadata.get_engine().reflecttable(table) except exceptions.NoSuchTableError: - table.deregister() + del metadata.tables[key] raise # initialize all the column, etc. objects. done after # reflection to allow user-overrides @@ -210,8 +207,8 @@ class Table(SchemaItem, sql.TableClause): super(Table, self).__init__(name) self._metadata = metadata self.schema = kwargs.pop('schema', None) - self.indexes = util.OrderedProperties() - self.constraints = [] + self.indexes = util.Set() + self.constraints = util.Set() self.primary_key = PrimaryKeyConstraint() self.quote = kwargs.get('quote', False) self.quote_schema = kwargs.get('quote_schema', False) @@ -237,7 +234,7 @@ class Table(SchemaItem, sql.TableClause): if getattr(self, '_primary_key', None) in self.constraints: self.constraints.remove(self._primary_key) self._primary_key = pk - self.constraints.append(pk) + self.constraints.add(pk) primary_key = property(lambda s:s._primary_key, _set_primary_key) def _derived_metadata(self): @@ -251,93 +248,45 @@ class Table(SchemaItem, sql.TableClause): def __str__(self): return _get_table_key(self.name, self.schema) - - def _reload_values(self, *args): - """clear out the columns and other properties of this Table, and reload them from the - given argument list. This is used with the "redefine" keyword argument sent to the - metaclass constructor.""" - self._clear() - - self._init_items(*args) - def append_item(self, item): - """appends a Column item or other schema item to this Table.""" - self._init_items(item) - def append_column(self, column): - if not column.hidden: - self._columns[column.key] = column - if column.primary_key: - self.primary_key.append(column) - column.table = self + """append a Column to this Table.""" + column._set_parent(self) + def append_constraint(self, constraint): + """append a Constraint to this Table.""" + constraint._set_parent(self) - def append_index(self, index): - self.indexes[index.name] = index - def _get_parent(self): return self._metadata def _set_parent(self, metadata): metadata.tables[_get_table_key(self.name, self.schema)] = self self._metadata = metadata - def accept_schema_visitor(self, visitor): - """traverses the given visitor across the Column objects inside this Table, - then calls the visit_table method on the visitor.""" - for c in self.columns: - c.accept_schema_visitor(visitor) + def accept_schema_visitor(self, visitor, traverse=True): + if traverse: + for c in self.columns: + c.accept_schema_visitor(visitor, True) return visitor.visit_table(self) - def append_index_column(self, column, index=None, unique=None): - """Add an index or a column to an existing index of the same name. - """ - if index is not None and unique is not None: - raise ValueError("index and unique may not both be specified") - if index: - if index is True: - name = 'ix_%s' % column._label - else: - name = index - elif unique: - if unique is True: - name = 'ux_%s' % column._label - else: - name = unique - # find this index in self.indexes - # add this column to it if found - # otherwise create new - try: - index = self.indexes[name] - index.append_column(column) - except KeyError: - index = Index(name, column, unique=unique) - return index - - def deregister(self): - """remove this table from it's owning metadata. - - this does not issue a SQL DROP statement.""" - key = _get_table_key(self.name, self.schema) - del self.metadata.tables[key] - - def exists(self, engine=None): - if engine is None: - engine = self.get_engine() + def exists(self, connectable=None): + """return True if this table exists.""" + if connectable is None: + connectable = self.get_engine() def do(conn): e = conn.engine return e.dialect.has_table(conn, self.name) - return engine.run_callable(do) + return connectable.run_callable(do) def create(self, connectable=None, checkfirst=False): - if connectable is not None: - connectable.create(self, checkfirst=checkfirst) - else: - self.get_engine().create(self, checkfirst=checkfirst) - return self + """issue a CREATE statement for this table. + + see also metadata.create_all().""" + self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) def drop(self, connectable=None, checkfirst=False): - if connectable is not None: - connectable.drop(self, checkfirst=checkfirst) - else: - self.get_engine().drop(self, checkfirst=checkfirst) + """issue a DROP statement for this table. + + see also metadata.drop_all().""" + self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) def tometadata(self, metadata, schema=None): """return a copy of this Table associated with a different MetaData.""" try: @@ -389,17 +338,16 @@ class Column(SchemaItem, sql.ColumnClause): table's list of columns. Used for the "oid" column, which generally isnt in column lists. - index=None : True or index name. Indicates that this column is - indexed. Pass true to autogenerate the index name. Pass a string to - specify the index name. Multiple columns that specify the same index - name will all be included in the index, in the order of their - creation. + index=False : Indicates that this column is + indexed. The name of the index is autogenerated. + to specify indexes with explicit names or indexes that contain multiple + columns, use the Index construct instead. - unique=None : True or index name. Indicates that this column is - indexed in a unique index . Pass true to autogenerate the index - name. Pass a string to specify the index name. Multiple columns that - specify the same index name will all be included in the index, in the - order of their creation. + unique=False : Indicates that this column + contains a unique constraint, or if index=True as well, indicates + that the Index should be created with the unique flag. + To specify multiple columns in the constraint/index or to specify an + explicit name, use the UniqueConstraint or Index constructs instead. autoincrement=True : Indicates that integer-based primary key columns should have autoincrementing behavior, if supported by the underlying database. This will affect CREATE TABLE statements such that they will @@ -430,9 +378,8 @@ class Column(SchemaItem, sql.ColumnClause): self._set_casing_strategy(name, kwargs) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) + self.constraints = util.Set() self.__originating_column = self - if self.index is not None and self.unique is not None: - raise exceptions.ArgumentError("Column may not define both index and unique") self._foreign_keys = util.Set() if len(kwargs): raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) @@ -455,7 +402,10 @@ class Column(SchemaItem, sql.ColumnClause): return self.table.metadata def _get_engine(self): return self.table.engine - + + def append_foreign_key(self, fk): + fk._set_parent(self) + def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + @@ -463,33 +413,33 @@ class Column(SchemaItem, sql.ColumnClause): ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']] , ',') - def append_item(self, item): - self._init_items(item) - - def _set_primary_key(self): - if self.primary_key: - return - self.primary_key = True - self.nullable = False - self.table.primary_key.append(self) - def _get_parent(self): return self.table + def _set_parent(self, table): if getattr(self, 'table', None) is not None: raise exceptions.ArgumentError("this Column already has a table!") - table.append_column(self) - if self.index or self.unique: - table.append_index_column(self, index=self.index, - unique=self.unique) - + if not self.hidden: + table._columns.add(self) + if self.primary_key: + table.primary_key.add(self) + self.table = table + + if self.index: + if isinstance(self.index, str): + raise exceptions.ArgumentError("The 'index' keyword argument on Column is boolean only. To create indexes with a specific name, append an explicit Index object to the Table's list of elements.") + Index('ix_%s' % self._label, self, unique=self.unique) + elif self.unique: + if isinstance(self.unique, str): + raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only. To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint or Index object to the Table's list of elements.") + table.append_constraint(UniqueConstraint(self.key)) + + toinit = list(self.args) if self.default is not None: - self.default = ColumnDefault(self.default) - self._init_items(self.default) + toinit.append(ColumnDefault(self.default)) if self.onupdate is not None: - self.onupdate = ColumnDefault(self.onupdate, for_update=True) - self._init_items(self.onupdate) - self._init_items(*self.args) + toinit.append(ColumnDefault(self.onupdate, for_update=True)) + self._init_items(*toinit) self.args = None def copy(self): @@ -507,9 +457,9 @@ class Column(SchemaItem, sql.ColumnClause): c.orig_set = self.orig_set c.__originating_column = self.__originating_column if not c.hidden: - selectable.columns[c.key] = c + selectable.columns.add(c) if self.primary_key: - selectable.primary_key.append(c) + selectable.primary_key.add(c) [c._init_items(f) for f in fk] return c @@ -519,15 +469,18 @@ class Column(SchemaItem, sql.ColumnClause): return self.__originating_column._get_case_sensitive() case_sensitive = property(_case_sens) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """traverses the given visitor to this Column's default and foreign key object, then calls visit_column on the visitor.""" - if self.default is not None: - self.default.accept_schema_visitor(visitor) - if self.onupdate is not None: - self.onupdate.accept_schema_visitor(visitor) - for f in self.foreign_keys: - f.accept_schema_visitor(visitor) + if traverse: + if self.default is not None: + self.default.accept_schema_visitor(visitor, traverse=True) + if self.onupdate is not None: + self.onupdate.accept_schema_visitor(visitor, traverse=True) + for f in self.foreign_keys: + f.accept_schema_visitor(visitor, traverse=True) + for constraint in self.constraints: + constraint.accept_schema_visitor(visitor, traverse=True) visitor.visit_column(self) @@ -538,7 +491,7 @@ class ForeignKey(SchemaItem): One or more ForeignKey objects are used within a ForeignKeyConstraint object which represents the table-level constraint definition.""" - def __init__(self, column, constraint=None): + def __init__(self, column, constraint=None, use_alter=False): """Construct a new ForeignKey object. "column" can be a schema.Column object representing the relationship, @@ -553,6 +506,7 @@ class ForeignKey(SchemaItem): self._colspec = column self._column = None self.constraint = constraint + self.use_alter = use_alter def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) @@ -611,7 +565,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_foreign_key method on the given visitor.""" visitor.visit_foreign_key(self) @@ -621,17 +575,13 @@ class ForeignKey(SchemaItem): self.parent = column if self.constraint is None and isinstance(self.parent.table, Table): - self.constraint = ForeignKeyConstraint([],[]) - self.parent.table.append_item(self.constraint) + self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter) + self.parent.table.append_constraint(self.constraint) self.constraint._append_fk(self) - # if a foreign key was already set up for the parent column, replace it with - # this one - #if self.parent.foreign_key is not None: - # self.parent.table.foreign_keys.remove(self.parent.foreign_key) - #self.parent.foreign_key = self self.parent.foreign_keys.add(self) self.parent.table.foreign_keys.add(self) + class DefaultGenerator(SchemaItem): """Base class for column "default" values.""" def __init__(self, for_update=False, metadata=None): @@ -661,7 +611,7 @@ class PassiveDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): return visitor.visit_passive_default(self) def __repr__(self): return "PassiveDefault(%s)" % repr(self.arg) @@ -672,7 +622,7 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_column_default method on the given visitor.""" if self.for_update: return visitor.visit_column_onupdate(self) @@ -704,57 +654,66 @@ class Sequence(DefaultGenerator): return self def drop(self): self.get_engine().drop(self) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) class Constraint(SchemaItem): """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint. - Also follows list behavior with regards to the underlying set of columns.""" + Implements a hybrid of dict/setlike behavior with regards to the list of underying columns""" def __init__(self, name=None): self.name = name - self.columns = [] + self.columns = sql.ColumnCollection() def __contains__(self, x): return x in self.columns + def keys(self): + return self.columns.keys() def __add__(self, other): return self.columns + other def __iter__(self): return iter(self.columns) def __len__(self): return len(self.columns) - def __getitem__(self, index): - return self.columns[index] - def __setitem__(self, index, item): - self.columns[index] = item def copy(self): raise NotImplementedError() def _get_parent(self): return getattr(self, 'table', None) - + +class CheckConstraint(Constraint): + def __init__(self, sqltext, name=None): + super(CheckConstraint, self).__init__(name) + self.sqltext = sqltext + def accept_schema_visitor(self, visitor, traverse=True): + visitor.visit_check_constraint(self) + def _set_parent(self, parent): + self.parent = parent + parent.constraints.add(self) + class ForeignKeyConstraint(Constraint): """table-level foreign key constraint, represents a colleciton of ForeignKey objects.""" - def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None): + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False): super(ForeignKeyConstraint, self).__init__(name) self.__colnames = columns self.__refcolnames = refcolumns - self.elements = [] + self.elements = util.Set() self.onupdate = onupdate self.ondelete = ondelete + self.use_alter = use_alter def _set_parent(self, table): self.table = table - table.constraints.append(self) + table.constraints.add(self) for (c, r) in zip(self.__colnames, self.__refcolnames): - self.append(c,r) - def accept_schema_visitor(self, visitor): + self.append_element(c,r) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_foreign_key_constraint(self) - def append(self, col, refcol): + def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self) fk._set_parent(self.table.c[col]) self._append_fk(fk) def _append_fk(self, fk): - self.columns.append(self.table.c[fk.parent.key]) - self.elements.append(fk) + self.columns.add(self.table.c[fk.parent.key]) + self.elements.add(fk) def copy(self): return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) @@ -766,37 +725,37 @@ class PrimaryKeyConstraint(Constraint): self.table = table table.primary_key = self for c in self.__colnames: - self.append(table.c[c]) - def accept_schema_visitor(self, visitor): + self.append_column(table.c[c]) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_primary_key_constraint(self) - def append(self, col): - # TODO: change "columns" to a key-sensitive set ? - for c in self.columns: - if c.key == col.key: - self.columns.remove(c) - self.columns.append(col) + def add(self, col): + self.append_column(col) + def append_column(self, col): + self.columns.add(col) col.primary_key=True def copy(self): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) - + def __eq__(self, other): + return self.columns == other + class UniqueConstraint(Constraint): - def __init__(self, name=None, *columns): - super(Constraint, self).__init__(name) + def __init__(self, *columns, **kwargs): + super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None)) self.__colnames = list(columns) def _set_parent(self, table): self.table = table - table.constraints.append(self) + table.constraints.add(self) for c in self.__colnames: - self.append(table.c[c]) - def append(self, col): - self.columns.append(col) - def accept_schema_visitor(self, visitor): + self.append_column(table.c[c]) + def append_column(self, col): + self.columns.add(col) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_unique_constraint(self) class Index(SchemaItem): """Represents an index of columns from a database table """ - def __init__(self, name, *columns, **kw): + def __init__(self, name, *columns, **kwargs): """Constructs an index object. Arguments are: name : the name of the index @@ -811,7 +770,7 @@ class Index(SchemaItem): self.name = name self.columns = [] self.table = None - self.unique = kw.pop('unique', False) + self.unique = kwargs.pop('unique', False) self._init_items(*columns) def _derived_metadata(self): @@ -821,12 +780,15 @@ class Index(SchemaItem): self.append_column(column) def _get_parent(self): return self.table + def _set_parent(self, table): + self.table = table + table.indexes.add(self) + def append_column(self, column): # make sure all columns are from the same table # and no column is repeated if self.table is None: - self.table = column.table - self.table.append_index(self) + self._set_parent(column.table) elif column.table != self.table: # all columns muse be from same table raise exceptions.ArgumentError("All index columns must be from same table. " @@ -850,7 +812,7 @@ class Index(SchemaItem): connectable.drop(self) else: self.get_engine().drop(self) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_index(self) def __str__(self): return repr(self) @@ -863,7 +825,6 @@ class Index(SchemaItem): class MetaData(SchemaItem): """represents a collection of Tables and their associated schema constructs.""" def __init__(self, name=None, **kwargs): - # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} self.name = name self._set_casing_strategy(name, kwargs) @@ -871,11 +832,18 @@ class MetaData(SchemaItem): return False def clear(self): self.tables.clear() - def table_iterator(self, reverse=True): - return self._sort_tables(self.tables.values(), reverse=reverse) + + def table_iterator(self, reverse=True, tables=None): + import sqlalchemy.sql_util + if tables is None: + tables = self.tables.values() + else: + tables = util.Set(tables).intersection(self.tables.values()) + sorter = sqlalchemy.sql_util.TableCollection(list(tables)) + return iter(sorter.sort(reverse=reverse)) def _get_parent(self): return None - def create_all(self, connectable=None, tables=None, engine=None): + def create_all(self, connectable=None, tables=None, checkfirst=True): """create all tables stored in this metadata. This will conditionally create tables depending on if they do not yet @@ -884,28 +852,13 @@ class MetaData(SchemaItem): connectable - a Connectable used to access the database; or use the engine bound to this MetaData. - tables - optional list of tables to create - - engine - deprecated argument.""" - if not tables: - tables = self.tables.values() - - if connectable is None: - connectable = engine - + tables - optional list of tables, which is a subset of the total + tables in the MetaData (others are ignored)""" if connectable is None: connectable = self.get_engine() - - def do(conn): - e = conn.engine - ts = self._sort_tables( tables ) - for table in ts: - if e.dialect.has_table(conn, table.name): - continue - conn.create(table) - connectable.run_callable(do) + connectable.create(self, checkfirst=checkfirst, tables=tables) - def drop_all(self, connectable=None, tables=None, engine=None): + def drop_all(self, connectable=None, tables=None, checkfirst=True): """drop all tables stored in this metadata. This will conditionally drop tables depending on if they currently @@ -914,33 +867,17 @@ class MetaData(SchemaItem): connectable - a Connectable used to access the database; or use the engine bound to this MetaData. - tables - optional list of tables to drop - - engine - deprecated argument.""" - if not tables: - tables = self.tables.values() - - if connectable is None: - connectable = engine - + tables - optional list of tables, which is a subset of the total + tables in the MetaData (others are ignored) + """ if connectable is None: connectable = self.get_engine() - - def do(conn): - e = conn.engine - ts = self._sort_tables( tables, reverse=True ) - for table in ts: - if e.dialect.has_table(conn, table.name): - conn.drop(table) - connectable.run_callable(do) + connectable.drop(self, checkfirst=checkfirst, tables=tables) - def _sort_tables(self, tables, reverse=False): - import sqlalchemy.sql_util - sorter = sqlalchemy.sql_util.TableCollection() - for t in tables: - sorter.add(t) - return sorter.sort(reverse=reverse) - + + def accept_schema_visitor(self, visitor, traverse=True): + visitor.visit_metadata(self) + def _derived_metadata(self): return self def _get_engine(self): @@ -1029,6 +966,8 @@ class SchemaVisitor(sql.ClauseVisitor): pass def visit_unique_constraint(self, constraint): pass + def visit_check_constraint(self, constraint): + pass default_metadata = DynamicMetaData('default') |