summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/databases/mysql.py2
-rw-r--r--lib/sqlalchemy/databases/postgres.py2
-rw-r--r--lib/sqlalchemy/orm/properties.py14
-rw-r--r--lib/sqlalchemy/schema.py30
-rw-r--r--lib/sqlalchemy/sql.py16
5 files changed, 35 insertions, 29 deletions
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 2c29bbe2a..88efcb755 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -444,7 +444,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
if not column.nullable:
colspec += " NOT NULL"
if column.primary_key:
- if not column.foreign_key and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer):
+ if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer):
colspec += " AUTO_INCREMENT"
return colspec
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 6fe51ad9a..e052fe8c0 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -490,7 +490,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- if column.primary_key and not column.foreign_key and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index af3995039..2ad2c2b8c 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -216,7 +216,7 @@ class PropertyLoader(StrategizedProperty):
elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]):
return sync.MANYTOONE
else:
- raise exceptions.ArgumentError("Cant determine relation direction '%s', for '%s' in mapper '%s' with primary join\n '%s'" %(repr(self.foreignkey), self.key, str(self.mapper), str(self.primaryjoin)))
+ raise exceptions.ArgumentError("Cant determine relation direction for '%s' in mapper '%s' with primary join\n '%s'" %(self.key, str(self.mapper), str(self.primaryjoin)))
def _find_dependent(self):
"""searches through the primary join condition to determine which side
@@ -226,12 +226,16 @@ class PropertyLoader(StrategizedProperty):
def foo(binary):
if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
- if binary.left.foreign_key is not None and binary.left.foreign_key.references(binary.right.table):
- foreignkeys.add(binary.left)
- elif binary.right.foreign_key is not None and binary.right.foreign_key.references(binary.left.table):
- foreignkeys.add(binary.right)
+ for f in binary.left.foreign_keys:
+ if f.references(binary.right.table):
+ foreignkeys.add(binary.left)
+ for f in binary.right.foreign_keys:
+ if f.references(binary.left.table):
+ foreignkeys.add(binary.right)
visitor = mapperutil.BinaryVisitor(foo)
self.primaryjoin.accept_visitor(visitor)
+ if len(foreignkeys) == 0:
+ raise exceptions.ArgumentError("On relation '%s', can't figure out which side is the foreign key for join condition '%s'. Specify the 'foreignkey' argument to the relation." % (self.key, str(self.primaryjoin)))
self.foreignkey = foreignkeys
def get_join(self):
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 1d4209561..18d1d7b14 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -433,12 +433,12 @@ class Column(SchemaItem, sql.ColumnClause):
self.__originating_column = self
if self.index is not None and self.unique is not None:
raise exceptions.ArgumentError("Column may not define both index and unique")
- self._foreign_key = None
+ self._foreign_keys = util.Set()
if len(kwargs):
raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys()))
primary_key = util.SimpleProperty('_primary_key')
- foreign_key = util.SimpleProperty('_foreign_key')
+ foreign_keys = util.SimpleProperty('_foreign_keys')
columns = property(lambda self:[self])
def __str__(self):
@@ -459,7 +459,7 @@ class Column(SchemaItem, sql.ColumnClause):
def __repr__(self):
return "Column(%s)" % string.join(
[repr(self.name)] + [repr(self.type)] +
- [repr(x) for x in [self.foreign_key] if x is not None] +
+ [repr(x) for x in self.foreign_keys if x is not None] +
["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']]
, ',')
@@ -501,11 +501,8 @@ class Column(SchemaItem, sql.ColumnClause):
This is a copy of this Column referenced
by a different parent (such as an alias or select statement)"""
- if self.foreign_key is None:
- fk = None
- else:
- fk = self.foreign_key.copy()
- c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote)
+ fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
+ c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote, *fk)
c.table = selectable
c.orig_set = self.orig_set
c.__originating_column = self.__originating_column
@@ -513,8 +510,7 @@ class Column(SchemaItem, sql.ColumnClause):
selectable.columns[c.key] = c
if self.primary_key:
selectable.primary_key.append(c)
- if fk is not None:
- c._init_items(fk)
+ [c._init_items(f) for f in fk]
return c
def _case_sens(self):
@@ -530,8 +526,8 @@ class Column(SchemaItem, sql.ColumnClause):
self.default.accept_schema_visitor(visitor)
if self.onupdate is not None:
self.onupdate.accept_schema_visitor(visitor)
- if self.foreign_key is not None:
- self.foreign_key.accept_schema_visitor(visitor)
+ for f in self.foreign_keys:
+ f.accept_schema_visitor(visitor)
visitor.visit_column(self)
@@ -631,11 +627,11 @@ class ForeignKey(SchemaItem):
# if a foreign key was already set up for the parent column, replace it with
# this one
- if self.parent.foreign_key is not None:
- self.parent.table.foreign_keys.remove(self.parent.foreign_key)
- self.parent.foreign_key = self
- self.parent.table.foreign_keys.append(self)
-
+ #if self.parent.foreign_key is not None:
+ # self.parent.table.foreign_keys.remove(self.parent.foreign_key)
+ #self.parent.foreign_key = self
+ self.parent.foreign_keys.add(self)
+ self.parent.table.foreign_keys.add(self)
class DefaultGenerator(SchemaItem):
"""Base class for column "default" values."""
def __init__(self, for_update=False, metadata=None):
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index a07536bc9..c113edaa3 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -618,8 +618,14 @@ class ColumnElement(Selectable, CompareMixin):
may correspond to several TableClause-attached columns)."""
primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag. indicates if this Column represents part or whole of a primary key.")
- foreign_key = property(lambda self:getattr(self, '_foreign_key', False), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.")
+ foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.")
columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.")
+ def _one_fkey(self):
+ if len(self._foreign_keys):
+ return list(self._foreign_keys)[0]
+ else:
+ return None
+ foreign_key = property(_one_fkey)
def _get_orig_set(self):
try:
@@ -731,7 +737,7 @@ class FromClause(Selectable):
return
self._columns = util.OrderedProperties()
self._primary_key = []
- self._foreign_keys = []
+ self._foreign_keys = util.Set()
self._orig_cols = {}
export = self._exportable_columns()
for column in export:
@@ -1077,8 +1083,8 @@ class Join(FromClause):
self._columns[column._label] = column
if column.primary_key:
self._primary_key.append(column)
- if column.foreign_key:
- self._foreign_keys.append(column.foreign_key)
+ for f in column.foreign_keys:
+ self._foreign_keys.add(f)
return column
def _match_primaries(self, primary, secondary):
crit = []
@@ -1252,7 +1258,7 @@ class TableClause(FromClause):
super(TableClause, self).__init__(name)
self.name = self.fullname = name
self._columns = util.OrderedProperties()
- self._foreign_keys = []
+ self._foreign_keys = util.Set()
self._primary_key = []
for c in columns:
self.append_column(c)