summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES1
-rw-r--r--lib/sqlalchemy/orm/mapper.py7
-rw-r--r--lib/sqlalchemy/orm/properties.py5
-rw-r--r--lib/sqlalchemy/orm/query.py7
-rw-r--r--test/orm/compile.py4
-rw-r--r--test/orm/inheritance4.py220
6 files changed, 237 insertions, 7 deletions
diff --git a/CHANGES b/CHANGES
index d6e2e8463..34103bf0a 100644
--- a/CHANGES
+++ b/CHANGES
@@ -46,6 +46,7 @@ relationships to an inheriting mapper (which is also self-referential)
[ticket:244]
- added 'checkfirst' argument to table.create()/table.drop(), as
well as table.exists() [ticket:234]
+- some other ongoing fixes to inheritance [ticket:245]
0.2.5
- fixed endless loop bug in select_by(), if the traversal hit
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 8c8ce4cff..7cc99b225 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -76,6 +76,7 @@ class Mapper(object):
self.always_refresh = always_refresh
self.version_id_col = version_id_col
self.concrete = concrete
+ self.single = False
self.inherits = inherits
self.select_table = select_table
self.local_table = local_table
@@ -230,6 +231,7 @@ class Mapper(object):
# inherit_condition is optional.
if self.local_table is None:
self.local_table = self.inherits.local_table
+ self.single = True
if not self.local_table is self.inherits.local_table:
if self.concrete:
self._synchronizer= None
@@ -348,7 +350,7 @@ class Mapper(object):
self.inherits._inheriting_mappers.add(self)
for key, prop in self.inherits.__props.iteritems():
if not self.__props.has_key(key):
- p = prop.adapt_to_inherited(key, self)
+ prop.adapt_to_inherited(key, self)
# load properties from the main table object,
# not overriding those set up in the 'properties' argument
@@ -539,7 +541,7 @@ class Mapper(object):
prop.init(key, self)
for mapper in self._inheriting_mappers:
- p = prop.adapt_to_inherited(key, mapper)
+ prop.adapt_to_inherited(key, mapper)
def __str__(self):
return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + str(self.local_table)
@@ -1133,7 +1135,6 @@ class MapperProperty(object):
p.localparent = newparent
p.parent = self.parent
p.inherits = getattr(self, 'inherits', self)
- return p
def do_init(self):
"""template method for subclasses"""
pass
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 6b3bb7883..c34b0223a 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -48,6 +48,11 @@ class ColumnProperty(mapper.MapperProperty):
# set a scalar object instance directly on the object,
# bypassing SmartProperty event handlers.
instance.__dict__[self.key] = row[self.columns[0]]
+ def adapt_to_inherited(self, key, newparent):
+ if newparent.concrete:
+ return
+ else:
+ super(ColumnProperty, self).adapt_to_inherited(key, newparent)
def __repr__(self):
return "ColumnProperty(%s)" % repr([str(c) for c in self.columns])
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 8e87ac09a..268273990 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -18,7 +18,6 @@ class Query(object):
else:
self.mapper = class_or_mapper.compile()
self.mapper = self.mapper.get_select_mapper().compile()
-
self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
self.order_by = kwargs.pop('order_by', self.mapper.order_by)
self.extension = kwargs.pop('extension', self.mapper.extension)
@@ -317,7 +316,10 @@ class Query(object):
if order_by is False:
if self.table.default_order_by() is not None:
order_by = self.table.default_order_by()
-
+
+ if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
+ whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity)
+
if self._should_nest(**kwargs):
from_obj.append(self.table)
@@ -366,5 +368,6 @@ class Query(object):
# give all the attached properties a chance to modify the query
for key, value in self.mapper.props.iteritems():
value.setup(key, statement, **kwargs)
+
return statement
diff --git a/test/orm/compile.py b/test/orm/compile.py
index 3268b9230..96e56e597 100644
--- a/test/orm/compile.py
+++ b/test/orm/compile.py
@@ -8,7 +8,7 @@ class CompileTest(testbase.AssertMixin):
def testone(self):
global metadata, order, employee, product, tax, orderproduct
- metadata = BoundMetaData(engine)
+ metadata = BoundMetaData(testbase.db)
order = Table('orders', metadata,
Column('id', Integer, primary_key=True),
@@ -69,7 +69,7 @@ class CompileTest(testbase.AssertMixin):
def testtwo(self):
"""test that conflicting backrefs raises an exception"""
global metadata, order, employee, product, tax, orderproduct
- metadata = BoundMetaData(engine)
+ metadata = BoundMetaData(testbase.db)
order = Table('orders', metadata,
Column('id', Integer, primary_key=True),
diff --git a/test/orm/inheritance4.py b/test/orm/inheritance4.py
new file mode 100644
index 000000000..859d40f88
--- /dev/null
+++ b/test/orm/inheritance4.py
@@ -0,0 +1,220 @@
+# TODO: make unit tests out of all this
+
+####### multiple table
+
+from sqlalchemy import *
+
+#db = create_engine('sqlite:///', echo=True)
+db = create_engine('postgres://scott:tiger@127.0.0.1/test', echo=True)
+metadata = BoundMetaData(db)
+
+session = create_session()
+
+class Employee(object):
+ def __init__(self, name):
+ self.name = name
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name
+
+class Manager(Employee):
+ def __init__(self, name, manager_data):
+ self.name = name
+ self.manager_data = manager_data
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.manager_data
+
+class Engineer(Employee):
+ def __init__(self, name, engineer_info):
+ self.name = name
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
+
+
+people = Table('people', metadata,
+ Column('person_id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('type', String(30)))
+
+engineers = Table('engineers', metadata,
+ Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+ Column('engineer_info', String(50)),
+ )
+
+managers = Table('managers', metadata,
+ Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+ Column('manager_data', String(50)),
+ )
+
+people_managers = Table('people_managers', metadata,
+ Column('person_id', Integer, ForeignKey("people.person_id")),
+ Column('manager_id', Integer, ForeignKey("managers.person_id"))
+)
+
+person_join = polymorphic_union( {
+ 'engineer':people.join(engineers),
+ 'manager':people.join(managers),
+ 'person':people.select(people.c.type=='person'),
+ }, None, 'pjoin')
+
+
+
+
+person_mapper = mapper(Employee, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person',
+ properties = dict(managers = relation(Manager, secondary=people_managers, lazy=False))
+ )
+
+
+
+mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
+mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
+
+
+
+def create_some_employees():
+ people.create()
+ engineers.create()
+ managers.create()
+ people_managers.create()
+ session.save(Manager('Tom', 'knows how to manage things'))
+ session.save(Engineer('Kurt', 'knows how to hack'))
+ session.flush()
+ session.query(Manager).select()
+try:
+ create_some_employees()
+finally:
+ metadata.drop_all()
+
+
+####### concrete table
+
+from sqlalchemy import *
+
+db = create_engine("sqlite:///:memory:")
+
+metadata = BoundMetaData(db)
+session = create_session()
+
+class Employee(object):
+ def __init__(self, name):
+ self.name = name
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name
+
+class Manager(Employee):
+ def __init__(self, name, manager_data):
+ self.name = name
+ self.manager_data = manager_data
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.manager_data
+
+class Engineer(Employee):
+ def __init__(self, name, engineer_info):
+ self.name = name
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
+
+
+managers_table = Table('managers', metadata,
+ Column('employee_id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('manager_data', String(50)),
+).create()
+
+engineers_table = Table('engineers', metadata,
+ Column('employee_id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('engineer_info', String(50)),
+).create()
+
+
+
+pjoin = polymorphic_union({
+ 'manager':managers_table,
+ 'engineer':engineers_table
+}, 'type', 'pjoin')
+
+employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type)
+manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager')
+engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer')
+
+
+
+session.save(Manager('Tom', 'knows how to manage things'))
+session.save(Engineer('Kurt', 'knows how to hack'))
+session.flush()
+
+
+# this gives you [Engineer Kurt knows how to hack, Manager Tom knows how to manage things]
+# as it should be
+session.query(Employee).select()
+
+# this fails
+session.query(Engineer).select()
+
+# this fails
+session.query(Manager).select()
+
+
+############ single table
+from sqlalchemy import *
+
+db = create_engine("sqlite:///:memory:")
+
+metadata = BoundMetaData(db)
+session = create_session()
+
+class Employee(object):
+ def __init__(self, name):
+ self.name = name
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name
+
+class Manager(Employee):
+ def __init__(self, name, manager_data):
+ self.name = name
+ self.manager_data = manager_data
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.manager_data
+
+class Engineer(Employee):
+ def __init__(self, name, engineer_info):
+ self.name = name
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
+
+
+employees_table = Table('employees', metadata,
+ Column('employee_id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('manager_data', String(50)),
+ Column('engineer_info', String(50)),
+ Column('type', String(20))
+)
+
+employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
+manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager')
+engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer')
+
+
+employees_table.create()
+
+session.save(Manager('Tom', 'knows how to manage things'))
+session.save(Engineer('Kurt', 'knows how to hack'))
+session.flush()
+
+
+# this gives you [Engineer Kurt knows how to hack, Manager Tom knows how to manage things]
+# as it should be
+session.query(Employee).select()
+
+# this gives you [Engineer Kurt knows how to hack, Manager Tom knows how to manage things]
+# instead of [Engineer Kurt knows how to hack]
+session.query(Engineer).select()
+
+
+# this gives you [Engineer Kurt knows how to hack, Manager Tom knows how to manage things]
+# instead of [Manager Tom knows how to manage things]
+session.query(Manager).select()