summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/attributes.py248
-rw-r--r--lib/sqlalchemy/orm/dynamic.py2
-rw-r--r--lib/sqlalchemy/orm/mapper.py49
-rw-r--r--lib/sqlalchemy/orm/session.py15
-rw-r--r--test/orm/attributes.py61
-rw-r--r--test/orm/collection.py11
-rw-r--r--test/perf/masseagerload.py2
7 files changed, 216 insertions, 172 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index f369c5396..7290e2ac2 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -84,13 +84,13 @@ class InstrumentedAttribute(interfaces.PropComparator):
return self.get(obj)
def commit_to_state(self, state, obj, value=NO_VALUE):
- """commit the a copy of thte value of 'obj' to the given CommittedState"""
-
+ """commit the object's current state to its 'committed' state."""
+
if value is NO_VALUE:
if self.key in obj.__dict__:
value = obj.__dict__[self.key]
if value is not NO_VALUE:
- state.data[self.key] = self.copy(value)
+ state.committed_state[self.key] = self.copy(value)
def clause_element(self):
return self.comparator.clause_element()
@@ -119,7 +119,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
will also not have a `hasparent` flag.
"""
- return item._state.get(('hasparent', id(self)), optimistic)
+ return item._state.parents.get(id(self), optimistic)
def sethasparent(self, item, value):
"""Set a boolean flag on the given item corresponding to
@@ -127,7 +127,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
attribute represented by this ``InstrumentedAttribute``.
"""
- item._state[('hasparent', id(self))] = value
+ item._state.parents[id(self)] = value
def get_history(self, obj, passive=False):
"""Return a new ``AttributeHistory`` object for the given object/this attribute's key.
@@ -165,11 +165,11 @@ class InstrumentedAttribute(interfaces.PropComparator):
if callable_ is None:
self.initialize(obj)
else:
- obj._state[('callable', self)] = callable_
+ obj._state.callables[self] = callable_
def _get_callable(self, obj):
- if ('callable', self) in obj._state:
- return obj._state[('callable', self)]
+ if self in obj._state.callables:
+ return obj._state.callables[self]
elif self.callable_ is not None:
return self.callable_(obj)
else:
@@ -183,7 +183,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
"""
try:
- del obj._state[('callable', self)]
+ del obj._state.callables[self]
except KeyError:
pass
self.clear(obj)
@@ -223,10 +223,8 @@ class InstrumentedAttribute(interfaces.PropComparator):
state = obj._state
# if an instance-wide "trigger" was set, call that
# and start again
- if 'trigger' in state:
- trig = state['trigger']
- del state['trigger']
- trig()
+ if state.trigger:
+ state.call_trigger()
return self.get(obj, passive=passive)
callable_ = self._get_callable(obj)
@@ -265,11 +263,10 @@ class InstrumentedAttribute(interfaces.PropComparator):
"""
state = obj._state
- orig = state.get('original', None)
- if orig is not None:
- self.commit_to_state(orig, obj, value)
+ if state.committed_state is not None:
+ self.commit_to_state(state, obj, value)
# remove per-instance callable, if any
- state.pop(('callable', self), None)
+ state.callables.pop(self, None)
obj.__dict__[self.key] = value
return value
@@ -278,21 +275,21 @@ class InstrumentedAttribute(interfaces.PropComparator):
return value
def fire_append_event(self, obj, value, initiator):
- obj._state['modified'] = True
+ obj._state.modified = True
if self.trackparent and value is not None:
self.sethasparent(value, True)
for ext in self.extensions:
ext.append(obj, value, initiator or self)
def fire_remove_event(self, obj, value, initiator):
- obj._state['modified'] = True
+ obj._state.modified = True
if self.trackparent and value is not None:
self.sethasparent(value, False)
for ext in self.extensions:
ext.remove(obj, value, initiator or self)
def fire_replace_event(self, obj, value, previous, initiator):
- obj._state['modified'] = True
+ obj._state.modified = True
if self.trackparent:
if value is not None:
self.sethasparent(value, True)
@@ -334,7 +331,7 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
if self.mutable_scalars:
h = self.get_history(obj, passive=True)
if h is not None and h.is_modified():
- obj._state['modified'] = True
+ obj._state.modified = True
return True
else:
return False
@@ -354,10 +351,8 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
state = obj._state
# if an instance-wide "trigger" was set, call that
- if 'trigger' in state:
- trig = state['trigger']
- del state['trigger']
- trig()
+ if state.trigger:
+ state.call_trigger()
old = self.get(obj)
obj.__dict__[self.key] = value
@@ -415,7 +410,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
if self.key not in obj.__dict__:
return
- obj._state['modified'] = True
+ obj._state.modified = True
collection = self.get_collection(obj)
collection.clear_with_event()
@@ -453,10 +448,8 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
state = obj._state
# if an instance-wide "trigger" was set, call that
- if 'trigger' in state:
- trig = state['trigger']
- del state['trigger']
- trig()
+ if state.trigger:
+ state.call_trigger()
old = self.get(obj)
old_collection = self.get_collection(obj, old)
@@ -466,7 +459,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
collection=new_collection)
obj.__dict__[self.key] = user_data
- state['modified'] = True
+ state.modified = True
# mark all the old elements as detached from the parent
if old_collection:
@@ -477,17 +470,16 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
"""Set an attribute value on the given instance and 'commit' it."""
state = obj._state
- orig = state.get('original', None)
collection, user_data = self._build_collection(obj)
self._load_collection(obj, value or [], emit_events=False,
collection=collection)
value = user_data
- if orig is not None:
- self.commit_to_state(orig, obj, value)
+ if state.committed_state is not None:
+ self.commit_to_state(state, obj, value)
# remove per-instance callable, if any
- state.pop(('callable', self), None)
+ state.callables.pop(self, None)
obj.__dict__[self.key] = value
return value
@@ -543,38 +535,57 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
def remove(self, obj, child, initiator):
getattr(child.__class__, self.key).remove(child, obj, initiator)
-class CommittedState(object):
- """Store the original state of an object when the ``commit()`
- method on the attribute manager is called.
- """
-
-
- def __init__(self, manager, obj):
- self.data = {}
+class InstanceState(object):
+ """tracks state information at the instance level."""
+
+ def __init__(self, obj):
+ self.committed_state = None
+ self.modified = False
+ self.trigger = None
+ self.callables = {}
+ self.parents = {}
+
+ def __getstate__(self):
+ return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified}
+
+ def __setstate__(self, state):
+ self.committed_state = state['committed_state']
+ self.parents = state['parents']
+ self.modified = state['modified']
+ self.callables = {}
+ self.trigger = None
+
+ def call_trigger(self):
+ trig = self.trigger
+ self.trigger = None
+ trig()
+
+ def commit(self, manager, obj):
+ self.committed_state = {}
+ self.modified = False
for attr in manager.managed_attributes(obj.__class__):
attr.commit_to_state(self, obj)
def rollback(self, manager, obj):
- for attr in manager.managed_attributes(obj.__class__):
- if attr.key in self.data:
- if not hasattr(attr, 'get_collection'):
- obj.__dict__[attr.key] = self.data[attr.key]
+ if not self.committed_state:
+ manager._clear(obj)
+ else:
+ for attr in manager.managed_attributes(obj.__class__):
+ if attr.key in self.committed_state:
+ if not hasattr(attr, 'get_collection'):
+ obj.__dict__[attr.key] = self.committed_state[attr.key]
+ else:
+ collection = attr.get_collection(obj)
+ collection.clear_without_event()
+ for item in self.committed_state[attr.key]:
+ collection.append_without_event(item)
else:
- collection = attr.get_collection(obj)
- collection.clear_without_event()
- for item in self.data[attr.key]:
- collection.append_without_event(item)
- else:
- if attr.key in obj.__dict__:
- del obj.__dict__[attr.key]
-
- def __repr__(self):
- return "CommittedState: %s" % repr(self.data)
+ if attr.key in obj.__dict__:
+ del obj.__dict__[attr.key]
class AttributeHistory(object):
"""Calculate the *history* of a particular attribute on a
- particular instance, based on the ``CommittedState`` associated
- with the instance, if any.
+ particular instance.
"""
def __init__(self, attr, obj, current, passive=False):
@@ -583,9 +594,8 @@ class AttributeHistory(object):
# get the "original" value. if a lazy load was fired when we got
# the 'current' value, this "original" was also populated just
# now as well (therefore we have to get it second)
- orig = obj._state.get('original', None)
- if orig is not None:
- original = orig.data.get(attr.key)
+ if obj._state.committed_state:
+ original = obj._state.committed_state.get(attr.key, None)
else:
original = None
@@ -652,11 +662,7 @@ class AttributeManager(object):
"""
for o in obj:
- orig = o._state.get('original')
- if orig is not None:
- orig.rollback(self, o)
- else:
- self._clear(o)
+ o._state.rollback(self, o)
def _clear(self, obj):
for attr in self.managed_attributes(obj.__class__):
@@ -664,19 +670,12 @@ class AttributeManager(object):
del obj.__dict__[attr.key]
except KeyError:
pass
-
+
def commit(self, *obj):
- """Create a ``CommittedState`` instance for each object in the given list, representing
- its *unchanged* state, and associates it with the instance.
-
- ``AttributeHistory`` objects will indicate the modified state of
- instance attributes as compared to its value in this
- ``CommittedState`` object.
- """
+ """Establish the "committed state" for each object in the given list."""
for o in obj:
- o._state['original'] = CommittedState(self, o)
- o._state['modified'] = False
+ o._state.commit(self, o)
def managed_attributes(self, class_):
"""Return a list of all ``InstrumentedAttribute`` objects
@@ -706,7 +705,7 @@ class AttributeManager(object):
for attr in self.managed_attributes(object.__class__):
if attr.check_mutable_modified(object):
return True
- return object._state.get('modified', False)
+ return object._state.modified
def get_history(self, obj, key, **kwargs):
"""Return a new ``AttributeHistory`` object for the given
@@ -743,12 +742,10 @@ class AttributeManager(object):
removed.
"""
+ s = obj._state
self._clear(obj)
- try:
- del obj._state['original']
- except KeyError:
- pass
- obj._state['trigger'] = callable
+ s.committed_state = None
+ s.trigger = callable
def untrigger_history(self, obj):
"""Remove a trigger function set by trigger_history.
@@ -756,14 +753,14 @@ class AttributeManager(object):
Does not restore the previous state of the object.
"""
- del obj._state['trigger']
+ obj._state.trigger = None
def has_trigger(self, obj):
"""Return True if the given object has a trigger function set
by ``trigger_history()``.
"""
- return 'trigger' in obj._state
+ return obj._state.trigger is not None
def reset_instance_attribute(self, obj, key):
"""Remove any per-instance callable functions corresponding to
@@ -774,16 +771,6 @@ class AttributeManager(object):
attr = getattr(obj.__class__, key)
attr.reset(obj)
- def reset_class_managed(self, class_):
- """Remove all ``InstrumentedAttribute`` property objects from
- the given class.
- """
-
- for attr in self.noninherited_managed_attributes(class_):
- delattr(class_, attr.key)
- self._inherited_attribute_cache.pop(class_,None)
- self._noninherited_attribute_cache.pop(class_,None)
-
def is_class_managed(self, class_, key):
"""Return True if the given `key` correponds to an
instrumented property on the given class.
@@ -826,7 +813,71 @@ class AttributeManager(object):
return getattr(obj_or_cls, key)
else:
return getattr(obj_or_cls.__class__, key)
+
+ def manage(self, obj):
+ if not hasattr(obj, '_state'):
+ obj._state = InstanceState(obj)
+
+ def new_instance(self, class_):
+ """create a new instance of class_ without its __init__() method being called."""
+
+ s = class_.__new__(class_)
+ s._state = InstanceState(s)
+ return s
+
+ def register_class(self, class_, extra_init=None, on_exception=None):
+ """decorate the constructor of the given class to establish attribute
+ management on new instances."""
+ oldinit = None
+ doinit = False
+
+ def init(instance, *args, **kwargs):
+ instance._state = InstanceState(instance)
+
+ if extra_init:
+ extra_init(class_, oldinit, instance, args, kwargs)
+
+ if doinit:
+ try:
+ oldinit(instance, *args, **kwargs)
+ except:
+ if on_exception:
+ on_exception(class_, oldinit, instance, args, kwargs)
+ raise
+
+ # override oldinit
+ oldinit = class_.__init__
+ if oldinit is None or not hasattr(oldinit, '_oldinit'):
+ init._oldinit = oldinit
+ class_.__init__ = init
+ # if oldinit is already one of our 'init' methods, replace it
+ elif hasattr(oldinit, '_oldinit'):
+ init._oldinit = oldinit._oldinit
+ class_.__init = init
+ oldinit = oldinit._oldinit
+
+ if oldinit is not None:
+ doinit = oldinit is not object.__init__
+ try:
+ init.__name__ = oldinit.__name__
+ init.__doc__ = oldinit.__doc__
+ except:
+ # cant set __name__ in py 2.3 !
+ pass
+
+ def unregister_class(self, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+
+ for attr in self.noninherited_managed_attributes(class_):
+ delattr(class_, attr.key)
+ self._inherited_attribute_cache.pop(class_,None)
+ self._noninherited_attribute_cache.pop(class_,None)
+
def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
"""Register an attribute at the class level to be instrumented
for all instances of the class.
@@ -837,13 +888,6 @@ class AttributeManager(object):
self._inherited_attribute_cache.pop(class_, None)
self._noninherited_attribute_cache.pop(class_, None)
- if not hasattr(class_, '_state'):
- def _get_state(self):
- if not hasattr(self, '_sa_attr_state'):
- self._sa_attr_state = {}
- return self._sa_attr_state
- class_._state = property(_get_state)
-
typecallable = kwargs.pop('typecallable', None)
if isinstance(typecallable, InstrumentedAttribute):
typecallable = None
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index 1d4b5f6c9..aa5105150 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -34,7 +34,7 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
old_collection = self.get(obj).assign(value)
# TODO: emit events ???
- state['modified'] = True
+ state.modified = True
def delete(self, *args, **kwargs):
raise NotImplementedError()
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 960282255..5d495d7a9 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -198,14 +198,9 @@ class Mapper(object):
def dispose(self):
# disaable any attribute-based compilation
self.__props_init = True
- attribute_manager.reset_class_managed(self.class_)
if hasattr(self.class_, 'c'):
del self.class_.c
- if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'):
- if self.class_.__init__._oldinit is not None:
- self.class_.__init__ = self.class_.__init__._oldinit
- else:
- delattr(self.class_, '__init__')
+ attribute_manager.unregister_class(self.class_)
def compile(self):
"""Compile this mapper into its final internal format.
@@ -664,34 +659,14 @@ class Mapper(object):
if not self.non_primary and (self.class_key in mapper_registry):
raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper, or to create a new primary mapper, remove this mapper first via sqlalchemy.orm.clear_mapper(mapper), or preferably sqlalchemy.orm.clear_mappers() to clear all mappers." % (self.class_, self.entity_name))
- attribute_manager.reset_class_managed(self.class_)
-
- oldinit = self.class_.__init__
- doinit = oldinit is not None and oldinit is not object.__init__
-
- def init(instance, *args, **kwargs):
+ def extra_init(class_, oldinit, instance, args, kwargs):
self.compile()
- self.extension.init_instance(self, self.class_, oldinit, instance, args, kwargs)
-
- if doinit:
- try:
- oldinit(instance, *args, **kwargs)
- except:
- # call init_failed but suppress exceptions into warnings so that original __init__
- # exception is raised
- util.warn_exception(self.extension.init_failed, self, self.class_, oldinit, instance, args, kwargs)
- raise
-
- # override oldinit, ensuring that its not already a Mapper-decorated init method
- if oldinit is None or not hasattr(oldinit, '_oldinit'):
- try:
- init.__name__ = oldinit.__name__
- init.__doc__ = oldinit.__doc__
- except:
- # cant set __name__ in py 2.3 !
- pass
- init._oldinit = oldinit
- self.class_.__init__ = init
+ self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+
+ def on_exception(class_, oldinit, instance, args, kwargs):
+ util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
+
+ attribute_manager.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
_COMPILE_MUTEX.acquire()
try:
@@ -1436,7 +1411,7 @@ class Mapper(object):
# plugin point
instance = extension.create_instance(self, context, row, self.class_)
if instance is EXT_CONTINUE:
- instance = self._create_instance(context.session)
+ instance = attribute_manager.new_instance(self.class_)
instance._entity_name = self.entity_name
if self.__should_log_debug:
self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
@@ -1459,12 +1434,6 @@ class Mapper(object):
return instance
- def _create_instance(self, session):
- obj = self.class_.__new__(self.class_)
- obj._entity_name = self.entity_name
-
- return obj
-
def _deferred_inheritance_condition(self, needs_tables):
cond = self.inherit_condition
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index ebd4bd3d3..b616570ab 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -847,7 +847,7 @@ class Session(object):
try:
key = getattr(object, '_instance_key', None)
if key is None:
- merged = mapper.class_.__new__(mapper.class_)
+ merged = attribute_manager.new_instance(mapper.class_)
else:
if key in self.identity_map:
merged = self.identity_map[key]
@@ -940,16 +940,9 @@ class Session(object):
"or is already persistent in a "
"different Session" % repr(obj))
else:
- m = _class_mapper(obj.__class__, entity_name=kwargs.get('entity_name', None))
-
- # this would be a nice exception to raise...however this is incompatible with a contextual
- # session which puts all objects into the session upon construction.
- #if m._is_orphan(object):
- # raise exceptions.InvalidRequestError("Instance '%s' is an orphan, "
- # "and must be attached to a parent "
- # "object to be saved" % (repr(object)))
-
- m._assign_entity_name(obj)
+ # TODO: consolidate the steps here
+ attribute_manager.manage(obj)
+ obj._entity_name = kwargs.get('entity_name', None)
self._attach(obj)
self.uow.register_new(obj)
diff --git a/test/orm/attributes.py b/test/orm/attributes.py
index 8ca2d1b8e..6314656b9 100644
--- a/test/orm/attributes.py
+++ b/test/orm/attributes.py
@@ -5,14 +5,17 @@ from sqlalchemy.orm.collections import collection
from sqlalchemy import exceptions
from testlib import *
+# these test classes defined at the module
+# level to support pickling
class MyTest(object):pass
class MyTest2(object):pass
-
+
class AttributesTest(PersistTest):
"""tests for the attributes.py module, which deals with tracking attribute changes on an object."""
- def testbasic(self):
+ def test_basic(self):
class User(object):pass
manager = attributes.AttributeManager()
+ manager.register_class(User)
manager.register_attribute(User, 'user_id', uselist = False)
manager.register_attribute(User, 'user_name', uselist = False)
manager.register_attribute(User, 'email_address', uselist = False)
@@ -39,8 +42,11 @@ class AttributesTest(PersistTest):
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
- def testpickleness(self):
+ def test_pickleness(self):
+
manager = attributes.AttributeManager()
+ manager.register_class(MyTest)
+ manager.register_class(MyTest2)
manager.register_attribute(MyTest, 'user_id', uselist = False)
manager.register_attribute(MyTest, 'user_name', uselist = False)
manager.register_attribute(MyTest, 'email_address', uselist = False)
@@ -97,10 +103,12 @@ class AttributesTest(PersistTest):
self.assert_(o4.mt2[0].a == 'abcde')
self.assert_(o4.mt2[0].b is None)
- def testlist(self):
+ def test_list(self):
class User(object):pass
class Address(object):pass
manager = attributes.AttributeManager()
+ manager.register_class(User)
+ manager.register_class(Address)
manager.register_attribute(User, 'user_id', uselist = False)
manager.register_attribute(User, 'user_name', uselist = False)
manager.register_attribute(User, 'addresses', uselist = True)
@@ -138,10 +146,12 @@ class AttributesTest(PersistTest):
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1)
- def testbackref(self):
+ def test_backref(self):
class Student(object):pass
class Course(object):pass
manager = attributes.AttributeManager()
+ manager.register_class(Student)
+ manager.register_class(Course)
manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'))
manager.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'))
@@ -166,7 +176,9 @@ class AttributesTest(PersistTest):
self.assert_(c.students == [s2,s3])
class Post(object):pass
class Blog(object):pass
-
+
+ manager.register_class(Post)
+ manager.register_class(Blog)
manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True)
manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True)
b = Blog()
@@ -190,6 +202,8 @@ class AttributesTest(PersistTest):
class Port(object):pass
class Jack(object):pass
+ manager.register_class(Port)
+ manager.register_class(Jack)
manager.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'))
manager.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'))
p = Port()
@@ -201,13 +215,15 @@ class AttributesTest(PersistTest):
j.port = None
self.assert_(p.jack is None)
- def testlazytrackparent(self):
+ def test_lazytrackparent(self):
"""test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
manager = attributes.AttributeManager()
class Post(object):pass
class Blog(object):pass
-
+ manager.register_class(Post)
+ manager.register_class(Blog)
+
# set up instrumented attributes with backrefs
manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True)
manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True)
@@ -234,12 +250,14 @@ class AttributesTest(PersistTest):
assert getattr(Blog, 'posts').hasparent(p2)
assert getattr(Post, 'blog').hasparent(b2)
- def testinheritance(self):
+ def test_inheritance(self):
"""tests that attributes are polymorphic"""
class Foo(object):pass
class Bar(Foo):pass
manager = attributes.AttributeManager()
+ manager.register_class(Foo)
+ manager.register_class(Bar)
def func1():
print "func1"
@@ -261,12 +279,14 @@ class AttributesTest(PersistTest):
assert x.element2 == 'this is the shared attr'
assert y.element2 == 'this is the shared attr'
- def testinheritance2(self):
+ def test_inheritance2(self):
"""test that the attribute manager can properly traverse the managed attributes of an object,
if the object is of a descendant class with managed attributes in the parent class"""
class Foo(object):pass
class Bar(Foo):pass
manager = attributes.AttributeManager()
+ manager.register_class(Foo)
+ manager.register_class(Bar)
manager.register_attribute(Foo, 'element', uselist=False)
x = Bar()
x.element = 'this is the element'
@@ -277,7 +297,7 @@ class AttributesTest(PersistTest):
assert hist.added_items() == []
assert hist.unchanged_items() == ['this is the element']
- def testlazyhistory(self):
+ def test_lazyhistory(self):
"""tests that history functions work with lazy-loading attributes"""
class Foo(object):pass
class Bar(object):
@@ -287,6 +307,8 @@ class AttributesTest(PersistTest):
return "Bar: id %d" % self.id
manager = attributes.AttributeManager()
+ manager.register_class(Foo)
+ manager.register_class(Bar)
def func1():
return "this is func 1"
@@ -305,11 +327,13 @@ class AttributesTest(PersistTest):
print h.unchanged_items()
- def testparenttrack(self):
+ def test_parenttrack(self):
class Foo(object):pass
class Bar(object):pass
manager = attributes.AttributeManager()
+ manager.register_class(Foo)
+ manager.register_class(Bar)
manager.register_attribute(Foo, 'element', uselist=False, trackparent=True)
manager.register_attribute(Bar, 'element', uselist=False, trackparent=True)
@@ -330,10 +354,11 @@ class AttributesTest(PersistTest):
b2.element = None
assert not getattr(Bar, 'element').hasparent(f2)
- def testmutablescalars(self):
+ def test_mutablescalars(self):
"""test detection of changes on mutable scalar items"""
class Foo(object):pass
manager = attributes.AttributeManager()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True)
x = Foo()
x.element = ['one', 'two', 'three']
@@ -341,8 +366,9 @@ class AttributesTest(PersistTest):
x.element[1] = 'five'
assert manager.is_modified(x)
- manager.reset_class_managed(Foo)
+ manager.unregister_class(Foo)
manager = attributes.AttributeManager()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'element', uselist=False)
x = Foo()
x.element = ['one', 'two', 'three']
@@ -350,7 +376,7 @@ class AttributesTest(PersistTest):
x.element[1] = 'five'
assert not manager.is_modified(x)
- def testdescriptorattributes(self):
+ def test_descriptorattributes(self):
"""changeset: 1633 broke ability to use ORM to map classes with unusual
descriptor attributes (for example, classes that inherit from ones
implementing zope.interface.Interface).
@@ -363,11 +389,12 @@ class AttributesTest(PersistTest):
A = des()
manager = attributes.AttributeManager()
- manager.reset_class_managed(Foo)
+ manager.unregister_class(Foo)
- def testcollectionclasses(self):
+ def test_collectionclasses(self):
manager = attributes.AttributeManager()
class Foo(object):pass
+ manager.register_class(Foo)
manager.register_attribute(Foo, "collection", uselist=True, typecallable=set)
assert isinstance(Foo().collection, set)
diff --git a/test/orm/collection.py b/test/orm/collection.py
index 0cc8cf7e0..9d5ae7ab9 100644
--- a/test/orm/collection.py
+++ b/test/orm/collection.py
@@ -36,6 +36,7 @@ class Entity(object):
return str((id(self), self.a, self.b, self.c))
manager = attributes.AttributeManager()
+manager.register_class(Entity)
_id = 1
def entity_maker():
@@ -55,6 +56,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -92,6 +94,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -233,6 +236,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -341,6 +345,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -473,6 +478,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -577,6 +583,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -694,6 +701,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -868,6 +876,7 @@ class CollectionsTest(PersistTest):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
@@ -1001,6 +1010,7 @@ class CollectionsTest(PersistTest):
class Foo(object):
pass
canary = Canary()
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=Custom)
@@ -1070,6 +1080,7 @@ class CollectionsTest(PersistTest):
canary = Canary()
creator = entity_maker
+ manager.register_class(Foo)
manager.register_attribute(Foo, 'attr', True, extension=canary)
obj = Foo()
diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py
index ad438c1fa..38696e85b 100644
--- a/test/perf/masseagerload.py
+++ b/test/perf/masseagerload.py
@@ -35,7 +35,7 @@ def load():
#print l
subitems.insert().execute(*l)
-@profiling.profiled('masseagerload', always=True)
+@profiling.profiled('masseagerload', always=True, sort=['cumulative'])
def masseagerload(session):
query = session.query(Item)
l = query.select()