diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 30 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 16 |
5 files changed, 35 insertions, 29 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 2c29bbe2a..88efcb755 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -444,7 +444,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" if column.primary_key: - if not column.foreign_key and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer): + if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer): colspec += " AUTO_INCREMENT" return colspec diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 6fe51ad9a..e052fe8c0 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -490,7 +490,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - if column.primary_key and not column.foreign_key and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.engine_impl(self.engine).get_col_spec() diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index af3995039..2ad2c2b8c 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -216,7 +216,7 @@ class PropertyLoader(StrategizedProperty): elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]): return sync.MANYTOONE else: - raise exceptions.ArgumentError("Cant determine relation direction '%s', for '%s' in mapper '%s' with primary join\n '%s'" %(repr(self.foreignkey), self.key, str(self.mapper), str(self.primaryjoin))) + raise exceptions.ArgumentError("Cant determine relation direction for '%s' in mapper '%s' with primary join\n '%s'" %(self.key, str(self.mapper), str(self.primaryjoin))) def _find_dependent(self): """searches through the primary join condition to determine which side @@ -226,12 +226,16 @@ class PropertyLoader(StrategizedProperty): def foo(binary): if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return - if binary.left.foreign_key is not None and binary.left.foreign_key.references(binary.right.table): - foreignkeys.add(binary.left) - elif binary.right.foreign_key is not None and binary.right.foreign_key.references(binary.left.table): - foreignkeys.add(binary.right) + for f in binary.left.foreign_keys: + if f.references(binary.right.table): + foreignkeys.add(binary.left) + for f in binary.right.foreign_keys: + if f.references(binary.left.table): + foreignkeys.add(binary.right) visitor = mapperutil.BinaryVisitor(foo) self.primaryjoin.accept_visitor(visitor) + if len(foreignkeys) == 0: + raise exceptions.ArgumentError("On relation '%s', can't figure out which side is the foreign key for join condition '%s'. Specify the 'foreignkey' argument to the relation." % (self.key, str(self.primaryjoin))) self.foreignkey = foreignkeys def get_join(self): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 1d4209561..18d1d7b14 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -433,12 +433,12 @@ class Column(SchemaItem, sql.ColumnClause): self.__originating_column = self if self.index is not None and self.unique is not None: raise exceptions.ArgumentError("Column may not define both index and unique") - self._foreign_key = None + self._foreign_keys = util.Set() if len(kwargs): raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) primary_key = util.SimpleProperty('_primary_key') - foreign_key = util.SimpleProperty('_foreign_key') + foreign_keys = util.SimpleProperty('_foreign_keys') columns = property(lambda self:[self]) def __str__(self): @@ -459,7 +459,7 @@ class Column(SchemaItem, sql.ColumnClause): def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + - [repr(x) for x in [self.foreign_key] if x is not None] + + [repr(x) for x in self.foreign_keys if x is not None] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']] , ',') @@ -501,11 +501,8 @@ class Column(SchemaItem, sql.ColumnClause): This is a copy of this Column referenced by a different parent (such as an alias or select statement)""" - if self.foreign_key is None: - fk = None - else: - fk = self.foreign_key.copy() - c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote) + fk = [ForeignKey(f._colspec) for f in self.foreign_keys] + c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote, *fk) c.table = selectable c.orig_set = self.orig_set c.__originating_column = self.__originating_column @@ -513,8 +510,7 @@ class Column(SchemaItem, sql.ColumnClause): selectable.columns[c.key] = c if self.primary_key: selectable.primary_key.append(c) - if fk is not None: - c._init_items(fk) + [c._init_items(f) for f in fk] return c def _case_sens(self): @@ -530,8 +526,8 @@ class Column(SchemaItem, sql.ColumnClause): self.default.accept_schema_visitor(visitor) if self.onupdate is not None: self.onupdate.accept_schema_visitor(visitor) - if self.foreign_key is not None: - self.foreign_key.accept_schema_visitor(visitor) + for f in self.foreign_keys: + f.accept_schema_visitor(visitor) visitor.visit_column(self) @@ -631,11 +627,11 @@ class ForeignKey(SchemaItem): # if a foreign key was already set up for the parent column, replace it with # this one - if self.parent.foreign_key is not None: - self.parent.table.foreign_keys.remove(self.parent.foreign_key) - self.parent.foreign_key = self - self.parent.table.foreign_keys.append(self) - + #if self.parent.foreign_key is not None: + # self.parent.table.foreign_keys.remove(self.parent.foreign_key) + #self.parent.foreign_key = self + self.parent.foreign_keys.add(self) + self.parent.table.foreign_keys.add(self) class DefaultGenerator(SchemaItem): """Base class for column "default" values.""" def __init__(self, for_update=False, metadata=None): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a07536bc9..c113edaa3 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -618,8 +618,14 @@ class ColumnElement(Selectable, CompareMixin): may correspond to several TableClause-attached columns).""" primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag. indicates if this Column represents part or whole of a primary key.") - foreign_key = property(lambda self:getattr(self, '_foreign_key', False), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.") + foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.") columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.") + def _one_fkey(self): + if len(self._foreign_keys): + return list(self._foreign_keys)[0] + else: + return None + foreign_key = property(_one_fkey) def _get_orig_set(self): try: @@ -731,7 +737,7 @@ class FromClause(Selectable): return self._columns = util.OrderedProperties() self._primary_key = [] - self._foreign_keys = [] + self._foreign_keys = util.Set() self._orig_cols = {} export = self._exportable_columns() for column in export: @@ -1077,8 +1083,8 @@ class Join(FromClause): self._columns[column._label] = column if column.primary_key: self._primary_key.append(column) - if column.foreign_key: - self._foreign_keys.append(column.foreign_key) + for f in column.foreign_keys: + self._foreign_keys.add(f) return column def _match_primaries(self, primary, secondary): crit = [] @@ -1252,7 +1258,7 @@ class TableClause(FromClause): super(TableClause, self).__init__(name) self.name = self.fullname = name self._columns = util.OrderedProperties() - self._foreign_keys = [] + self._foreign_keys = util.Set() self._primary_key = [] for c in columns: self.append_column(c) |