diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-07-14 20:06:09 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-07-14 20:06:09 +0000 |
commit | bc6fbfa84ab6e1e9639e00cc23b3c41ab1d30dc1 (patch) | |
tree | 41cbfd1293b4413890d372b76f31209b1c793d09 /lib | |
parent | e58578cb4b5e96c2c99e84f6f67a773d168b8bd1 (diff) | |
download | sqlalchemy-bc6fbfa84ab6e1e9639e00cc23b3c41ab1d30dc1.tar.gz |
overhaul to schema, addition of ForeignKeyConstraint/
PrimaryKeyConstraint objects (also UniqueConstraint not
completed yet). table creation and reflection modified
to be more oriented towards these new table-level objects.
reflection for sqlite/postgres/mysql supports composite
foreign keys; oracle/mssql/firebird not converted yet.
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 34 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 56 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 31 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 39 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 171 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 8 |
10 files changed, 226 insertions, 139 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 78017bc91..5d01e275c 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -602,7 +602,7 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(engine.SchemaIterator): - def get_column_specification(self, column, override_pk=False, first_pk=False): + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() def visit_table(self, table): @@ -614,19 +614,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): separator = "\n" # if only one primary key, specify it along with the column - pks = table.primary_key first_pk = False for column in table.columns: self.append(separator) separator = ", \n" - self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1, first_pk=column.primary_key and not first_pk)) + self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) if column.primary_key: first_pk = True - # if multiple primary keys, specify it at the bottom - if len(pks) > 1: - self.append(", \n") - self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', ')) - + + for constraint in table.constraints: + constraint.accept_schema_visitor(self) + self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): @@ -650,6 +648,26 @@ class ANSISchemaGenerator(engine.SchemaIterator): compiler.compile() return compiler + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return + self.append(", \n") + self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', ')) + + def visit_foreign_key_constraint(self, constraint): + self.append(", \n\t ") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + string.join([f.parent.name for f in constraint.elements], ', '), + list(constraint.elements)[0].column.table.name, + string.join([f.column.name for f in constraint.elements], ', ') + )) + if constraint.ondelete is not None: + self.append(" ON DELETE %s" % constraint.ondelete) + if constraint.onupdate is not None: + self.append(" ON UPDATE %s" % constraint.onupdate) + def visit_column(self, column): pass diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 0039333d5..085d8cf44 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -293,7 +293,7 @@ class FBCompiler(ansisql.ANSICompiler): return "" class FBSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) @@ -302,10 +302,6 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 08236f799..296db2de5 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -54,6 +54,7 @@ pg_key_constraints = schema.Table("key_column_usage", ischema, Column("table_name", String), Column("column_name", String), Column("constraint_name", String), + Column("ordinal_position", Integer), schema="information_schema") #mysql_key_constraints = schema.Table("key_column_usage", ischema, @@ -100,13 +101,9 @@ class ISchema(object): return self.cache[name] -def reflecttable(connection, table, ischema_names, use_mysql=False): +def reflecttable(connection, table, ischema_names): - if use_mysql: - # no idea which INFORMATION_SCHEMA spec is correct, mysql or postgres - key_constraints = mysql_key_constraints - else: - key_constraints = pg_key_constraints + key_constraints = pg_key_constraints if table.schema is not None: current_schema = table.schema @@ -152,39 +149,50 @@ def reflecttable(connection, table, ischema_names, use_mysql=False): if not found_table: raise exceptions.NoSuchTableError(table.name) - s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)]) - if not use_mysql: - s.append_column(column_constraints) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name] - else: - # this doesnt seem to pick up any foreign keys with mysql - s.append_whereclause(key_constraints.c.table_name==constraints.c.table_name) - s.append_whereclause(key_constraints.c.table_schema==constraints.c.table_schema) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, key_constraints.c.referenced_table_schema, key_constraints.c.referenced_table_name, key_constraints.c.referenced_column_name] + # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns + # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys + # wont reflect properly. dont see a way around this based on whats available from information_schema + s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position]) + s.append_column(column_constraints) + s.append_whereclause(constraints.c.table_name==table.name) + s.append_whereclause(constraints.c.table_schema==current_schema) + colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position] c = connection.execute(s) + fks = {} while True: row = c.fetchone() if row is None: break -# continue - (type, constrained_column, referred_schema, referred_table, referred_column) = ( + (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = ( row[colmap[0]], row[colmap[1]], row[colmap[2]], row[colmap[3]], - row[colmap[4]] + row[colmap[4]], + row[colmap[5]], + row[colmap[6]] ) #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) if type=='PRIMARY KEY': table.c[constrained_column]._set_primary_key() elif type=='FOREIGN KEY': + try: + fk = fks[constraint_name] + except KeyError: + fk = ([],[]) + fks[constraint_name] = fk if current_schema == referred_schema: referred_schema = table.schema - remotetable = Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema) - table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column])) + if referred_schema is not None: + refspec = ".".join([referred_schema, referred_table, referred_column]) + else: + refspec = ".".join([referred_table, referred_column]) + if constrained_column not in fk[0]: + fk[0].append(constrained_column) + if refspec not in fk[1]: + fk[1].append(refspec) + + for name, value in fks.iteritems(): + table.append_item(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index c297195ca..9d51d535d 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -511,7 +511,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, first_pk=False): + def get_column_specification(self, column, **kwargs): colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column @@ -528,12 +528,6 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default - - if column.primary_key: - if not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) return colspec diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 997010f1c..1d587ff7c 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -309,8 +309,6 @@ class MySQLDialect(ansisql.ANSIDialect): break #print "row! " + repr(row) if not found_table: - tabletype, foreignkeyD = self.moretableinfo(connection, table=table) - table.kwargs['mysql_engine'] = tabletype found_table = True (name, type, nullable, primary_key, default) = (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4]) @@ -338,16 +336,15 @@ class MySQLDialect(ansisql.ANSIDialect): argslist = re.findall(r'(\d+)', args) coltype = coltype(*[int(a) for a in argslist], **kw) - arglist = [] - fkey = foreignkeyD.get(name) - if fkey is not None: - arglist.append(schema.ForeignKey(fkey)) - - table.append_item(schema.Column(name, coltype, *arglist, + table.append_item(schema.Column(name, coltype, **dict(primary_key=primary_key, nullable=nullable, default=default ))) + + tabletype = self.moretableinfo(connection, table=table) + table.kwargs['mysql_engine'] = tabletype + if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -368,15 +365,15 @@ class MySQLDialect(ansisql.ANSIDialect): match = re.search(r'\b(?:TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I) if match: tabletype = match.group('ttype') - foreignkeyD = {} - fkpat = (r'FOREIGN KEY\s*\(`?(?P<name>.+?)`?\)' - r'\s*REFERENCES\s*`?(?P<reftable>.+?)`?' - r'\s*\(`?(?P<refcol>.+?)`?\)' - ) + + fkpat = r'CONSTRAINT `(?P<name>.+?)` FOREIGN KEY \((?P<columns>.+?)\) REFERENCES `(?P<reftable>.+?)` \((?P<refcols>.+?)\)' for match in re.finditer(fkpat, desc): - foreignkeyD[match.group('name')] = match.group('reftable') + '.' + match.group('refcol') + columns = re.findall(r'`(.+?)`', match.group('columns')) + refcols = [match.group('reftable') + "." + x for x in re.findall(r'`(.+?)`', match.group('refcols'))] + constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name')) + table.append_item(constraint) - return (tabletype, foreignkeyD) + return tabletype class MySQLCompiler(ansisql.ANSICompiler): @@ -411,12 +408,8 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" if column.primary_key: - if not override_pk: - colspec += " PRIMARY KEY" if not column.foreign_key and first_pk and isinstance(column.type, sqltypes.Integer): colspec += " AUTO_INCREMENT" - if column.foreign_key: - colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def post_create_table(self, table): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index bf6c1fd8d..d184291fd 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -320,7 +320,7 @@ class OracleCompiler(ansisql.ANSICompiler): return "" class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) @@ -329,10 +329,6 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index de21bd570..decccba58 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -329,7 +329,7 @@ class PGCompiler(ansisql.ANSICompiler): class PGSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name if column.primary_key and not column.foreign_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" @@ -341,10 +341,6 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index c07952ff2..c703cd81e 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -257,7 +257,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.binary_operator_string(self, binary) class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: @@ -265,34 +265,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec - def visit_table(self, table): - """sqlite is going to create multi-primary keys with just a UNIQUE index.""" - self.append("\nCREATE TABLE " + table.fullname + "(") - - separator = "\n" - - have_pk = False - use_pks = len(table.primary_key) == 1 - for column in table.columns: - self.append(separator) - separator = ", \n" - self.append("\t" + self.get_column_specification(column, override_pk=not use_pks)) - - if len(table.primary_key) > 1: - self.append(", \n") - # put all PRIMARY KEYS in a UNIQUE index - self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_key],', ')) - - self.append("\n)\n\n") - self.execute() - if hasattr(table, 'indexes'): - for index in table.indexes: - self.visit_index(index) + # this doesnt seem to be needed, although i suspect older versions of sqlite might still + # not directly support composite primary keys + #def visit_primary_key_constraint(self, constraint): + # if len(constraint) > 1: + # self.append(", \n") + # # put all PRIMARY KEYS in a UNIQUE index + # self.append("\tUNIQUE (%s)" % string.join([c.name for c in constraint],', ')) + # else: + # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) + dialect = SQLiteDialect poolclass = pool.SingletonThreadPool diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 1df2d3005..dcd023fe9 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -18,23 +18,24 @@ from sqlalchemy import sql, types, exceptions,util import sqlalchemy import copy, re, string -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', + 'PrimaryKeyConstraint', 'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): """base class for items that define a database schema.""" def _init_items(self, *args): + """initialize the list of child items for this SchemaItem""" for item in args: if item is not None: item._set_parent(self) def _set_parent(self, parent): - """a child item attaches itself to its parent via this method.""" + """associate with this SchemaItem's parent object.""" raise NotImplementedError() def __repr__(self): return "%s()" % self.__class__.__name__ def _derived_metadata(self): - """subclasses override this method to return a the MetaData - to which this item is bound""" + """return the the MetaData to which this item is bound""" return None def _get_engine(self): return self._derived_metadata().engine @@ -77,7 +78,7 @@ class TableSingleton(type): table = metadata.tables[key] if len(args): if redefine: - table.reload_values(*args) + table._reload_values(*args) elif not useexisting: raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) return table @@ -109,7 +110,9 @@ class Table(SchemaItem, sql.TableClause): __metaclass__ = TableSingleton def __init__(self, name, metadata, **kwargs): - """Table objects can be constructed directly. The init method is actually called via + """Construct a Table. + + Table objects can be constructed directly. The init method is actually called via the TableSingleton metaclass. Arguments are: name : the name of this table, exactly as it appears, or will appear, in the database. @@ -141,11 +144,23 @@ class Table(SchemaItem, sql.TableClause): super(Table, self).__init__(name) self._metadata = metadata self.schema = kwargs.pop('schema', None) + self.indexes = util.OrderedProperties() + self.constraints = [] + self.primary_key = PrimaryKeyConstraint() + if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name self.kwargs = kwargs + + def _set_primary_key(self, pk): + if getattr(self, '_primary_key', None) in self.constraints: + self.constraints.remove(self._primary_key) + self._primary_key = pk + self.constraints.append(pk) + primary_key = property(lambda s:s._primary_key, _set_primary_key) + def _derived_metadata(self): return self._metadata def __repr__(self): @@ -158,8 +173,8 @@ class Table(SchemaItem, sql.TableClause): def __str__(self): return _get_table_key(self.name, self.schema) - def reload_values(self, *args): - """clears out the columns and other properties of this Table, and reloads them from the + def _reload_values(self, *args): + """clear out the columns and other properties of this Table, and reload them from the given argument list. This is used with the "redefine" keyword argument sent to the metaclass constructor.""" self._clear() @@ -216,8 +231,9 @@ class Table(SchemaItem, sql.TableClause): return index def deregister(self): - """removes this table from it's metadata. this does not - issue a SQL DROP statement.""" + """remove this table from it's owning metadata. + + this does not issue a SQL DROP statement.""" key = _get_table_key(self.name, self.schema) del self.metadata.tables[key] def create(self, connectable=None): @@ -232,7 +248,7 @@ class Table(SchemaItem, sql.TableClause): else: self.engine.drop(self) def tometadata(self, metadata, schema=None): - """returns a singleton instance of this Table with a different Schema""" + """return a copy of this Table associated with a different MetaData.""" try: if schema is None: schema = self.schema @@ -242,6 +258,8 @@ class Table(SchemaItem, sql.TableClause): args = [] for c in self.columns: args.append(c.copy()) + for c in self.constraints: + args.append(c.copy()) return Table(self.name, metadata, schema=schema, *args) class Column(SchemaItem, sql.ColumnClause): @@ -362,13 +380,9 @@ class Column(SchemaItem, sql.ColumnClause): self._init_items(*self.args) self.args = None - def copy(self): + def copy(self): """creates a copy of this Column, unitialized""" - if self.foreign_key is None: - fk = None - else: - fk = self.foreign_key.copy() - return Column(self.name, self.type, fk, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden) + return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden) def _make_proxy(self, selectable, name = None): """creates a copy of this Column, initialized the way this Column is""" @@ -401,23 +415,33 @@ class Column(SchemaItem, sql.ColumnClause): class ForeignKey(SchemaItem): - """defines a ForeignKey constraint between two columns. ForeignKey is - specified as an argument to a Column object.""" - def __init__(self, column): - """Constructs a new ForeignKey object. "column" can be a schema.Column - object representing the relationship, or just its string name given as - "tablename.columnname". schema can be specified as - "schema.tablename.columnname" """ + """defines a column-level ForeignKey constraint between two columns. + + ForeignKey is specified as an argument to a Column object. + + One or more ForeignKey objects are used within a ForeignKeyConstraint + object which represents the table-level constraint definition.""" + def __init__(self, column, constraint=None): + """Construct a new ForeignKey object. + + "column" can be a schema.Column object representing the relationship, + or just its string name given as "tablename.columnname". schema can be + specified as "schema.tablename.columnname" + + "constraint" is the owning ForeignKeyConstraint object, if any. if not given, + then a ForeignKeyConstraint will be automatically created and added to the parent table. + """ if isinstance(column, unicode): column = str(column) self._colspec = column self._column = None - + self.constraint = constraint + def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) def copy(self): - """produces a copy of this ForeignKey object.""" + """produce a copy of this ForeignKey object.""" return ForeignKey(self._get_colspec()) def _get_colspec(self): @@ -462,6 +486,7 @@ class ForeignKey(SchemaItem): self._column = table.c[colname] else: self._column = self._colspec + return self._column column = property(lambda s: s._init_column()) @@ -472,8 +497,14 @@ class ForeignKey(SchemaItem): def _set_parent(self, column): self.parent = column - # if a foreign key was already set up for this, replace it with - # this one, including removing from the parent + + if self.constraint is None and isinstance(self.parent.table, Table): + self.constraint = ForeignKeyConstraint([],[]) + self.parent.table.append_item(self.constraint) + self.constraint._append_fk(self) + + # if a foreign key was already set up for the parent column, replace it with + # this one if self.parent.foreign_key is not None: self.parent.table.foreign_keys.remove(self.parent.foreign_key) self.parent.foreign_key = self @@ -551,7 +582,81 @@ class Sequence(DefaultGenerator): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) - +class Constraint(SchemaItem): + """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint. + + Also follows list behavior with regards to the underlying set of columns.""" + def __init__(self, name=None): + self.name = name + self.columns = [] + def __contains__(self, x): + return x in self.columns + def __add__(self, other): + return self.columns + other + def __iter__(self): + return iter(self.columns) + def __len__(self): + return len(self.columns) + def __getitem__(self, index): + return self.columns[index] + def copy(self): + raise NotImplementedError() + +class ForeignKeyConstraint(Constraint): + """table-level foreign key constraint, represents a colleciton of ForeignKey objects.""" + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None): + super(ForeignKeyConstraint, self).__init__(name) + self.__colnames = columns + self.__refcolnames = refcolumns + self.elements = [] + self.onupdate = onupdate + self.ondelete = ondelete + def _set_parent(self, table): + self.table = table + table.constraints.append(self) + for (c, r) in zip(self.__colnames, self.__refcolnames): + self.append(c,r) + def accept_schema_visitor(self, visitor): + visitor.visit_foreign_key_constraint(self) + def append(self, col, refcol): + fk = ForeignKey(refcol, constraint=self) + fk._set_parent(self.table.c[col]) + self._append_fk(fk) + def _append_fk(self, fk): + self.columns.append(self.table.c[fk.parent.name]) + self.elements.append(fk) + def copy(self): + return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) + +class PrimaryKeyConstraint(Constraint): + def __init__(self, *columns, **kwargs): + super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None)) + self.__colnames = list(columns) + def _set_parent(self, table): + table.primary_key = self + for c in self.__colnames: + self.append(table.c[c]) + def accept_schema_visitor(self, visitor): + visitor.visit_primary_key_constraint(self) + def append(self, col): + self.columns.append(col) + col.primary_key=True + def copy(self): + return PrimaryKeyConstraint(name=self.name, *[c.name for c in self]) + +class UniqueConstraint(Constraint): + def __init__(self, name=None, *columns): + super(Constraint, self).__init__(name) + self.__colnames = list(columns) + def _set_parent(self, table): + table.constraints.append(self) + for c in self.__colnames: + self.append(table.c[c]) + def append(self, col): + self.columns.append(col) + def accept_schema_visitor(self, visitor): + visitor.visit_unique_constraint(self) + class Index(SchemaItem): """Represents an index of columns from a database table """ @@ -746,7 +851,13 @@ class SchemaVisitor(sql.ClauseVisitor): def visit_sequence(self, sequence): """visit a Sequence.""" pass - + def visit_primary_key_constraint(self, constraint): + pass + def visit_foreign_key_constraint(self, constraint): + pass + def visit_unique_constraint(self, constraint): + pass + default_metadata = DynamicMetaData('default') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 7b17927f0..8109d8cd5 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1225,15 +1225,12 @@ class TableClause(FromClause): super(TableClause, self).__init__(name) self.name = self.fullname = name self._columns = util.OrderedProperties() - self._indexes = util.OrderedProperties() self._foreign_keys = [] self._primary_key = [] for c in columns: self.append_column(c) self._oid_column = ColumnClause('oid', self, hidden=True) - indexes = property(lambda s:s._indexes) - def named_with_column(self): return True def append_column(self, c): @@ -1250,16 +1247,11 @@ class TableClause(FromClause): for ci in c.orig_set: self._orig_cols[ci] = c return self._orig_cols - columns = property(lambda s:s._columns) - c = property(lambda s:s._columns) - primary_key = property(lambda s:s._primary_key) - foreign_keys = property(lambda s:s._foreign_keys) original_columns = property(_orig_columns) def _clear(self): """clears all attributes on this TableClause so that new items can be added again""" self.columns.clear() - self.indexes.clear() self.foreign_keys[:] = [] self.primary_key[:] = [] try: |