diff options
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 248 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/dynamic.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 49 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 15 | ||||
-rw-r--r-- | test/orm/attributes.py | 61 | ||||
-rw-r--r-- | test/orm/collection.py | 11 | ||||
-rw-r--r-- | test/perf/masseagerload.py | 2 |
7 files changed, 216 insertions, 172 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index f369c5396..7290e2ac2 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -84,13 +84,13 @@ class InstrumentedAttribute(interfaces.PropComparator): return self.get(obj) def commit_to_state(self, state, obj, value=NO_VALUE): - """commit the a copy of thte value of 'obj' to the given CommittedState""" - + """commit the object's current state to its 'committed' state.""" + if value is NO_VALUE: if self.key in obj.__dict__: value = obj.__dict__[self.key] if value is not NO_VALUE: - state.data[self.key] = self.copy(value) + state.committed_state[self.key] = self.copy(value) def clause_element(self): return self.comparator.clause_element() @@ -119,7 +119,7 @@ class InstrumentedAttribute(interfaces.PropComparator): will also not have a `hasparent` flag. """ - return item._state.get(('hasparent', id(self)), optimistic) + return item._state.parents.get(id(self), optimistic) def sethasparent(self, item, value): """Set a boolean flag on the given item corresponding to @@ -127,7 +127,7 @@ class InstrumentedAttribute(interfaces.PropComparator): attribute represented by this ``InstrumentedAttribute``. """ - item._state[('hasparent', id(self))] = value + item._state.parents[id(self)] = value def get_history(self, obj, passive=False): """Return a new ``AttributeHistory`` object for the given object/this attribute's key. @@ -165,11 +165,11 @@ class InstrumentedAttribute(interfaces.PropComparator): if callable_ is None: self.initialize(obj) else: - obj._state[('callable', self)] = callable_ + obj._state.callables[self] = callable_ def _get_callable(self, obj): - if ('callable', self) in obj._state: - return obj._state[('callable', self)] + if self in obj._state.callables: + return obj._state.callables[self] elif self.callable_ is not None: return self.callable_(obj) else: @@ -183,7 +183,7 @@ class InstrumentedAttribute(interfaces.PropComparator): """ try: - del obj._state[('callable', self)] + del obj._state.callables[self] except KeyError: pass self.clear(obj) @@ -223,10 +223,8 @@ class InstrumentedAttribute(interfaces.PropComparator): state = obj._state # if an instance-wide "trigger" was set, call that # and start again - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() + if state.trigger: + state.call_trigger() return self.get(obj, passive=passive) callable_ = self._get_callable(obj) @@ -265,11 +263,10 @@ class InstrumentedAttribute(interfaces.PropComparator): """ state = obj._state - orig = state.get('original', None) - if orig is not None: - self.commit_to_state(orig, obj, value) + if state.committed_state is not None: + self.commit_to_state(state, obj, value) # remove per-instance callable, if any - state.pop(('callable', self), None) + state.callables.pop(self, None) obj.__dict__[self.key] = value return value @@ -278,21 +275,21 @@ class InstrumentedAttribute(interfaces.PropComparator): return value def fire_append_event(self, obj, value, initiator): - obj._state['modified'] = True + obj._state.modified = True if self.trackparent and value is not None: self.sethasparent(value, True) for ext in self.extensions: ext.append(obj, value, initiator or self) def fire_remove_event(self, obj, value, initiator): - obj._state['modified'] = True + obj._state.modified = True if self.trackparent and value is not None: self.sethasparent(value, False) for ext in self.extensions: ext.remove(obj, value, initiator or self) def fire_replace_event(self, obj, value, previous, initiator): - obj._state['modified'] = True + obj._state.modified = True if self.trackparent: if value is not None: self.sethasparent(value, True) @@ -334,7 +331,7 @@ class InstrumentedScalarAttribute(InstrumentedAttribute): if self.mutable_scalars: h = self.get_history(obj, passive=True) if h is not None and h.is_modified(): - obj._state['modified'] = True + obj._state.modified = True return True else: return False @@ -354,10 +351,8 @@ class InstrumentedScalarAttribute(InstrumentedAttribute): state = obj._state # if an instance-wide "trigger" was set, call that - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() + if state.trigger: + state.call_trigger() old = self.get(obj) obj.__dict__[self.key] = value @@ -415,7 +410,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): if self.key not in obj.__dict__: return - obj._state['modified'] = True + obj._state.modified = True collection = self.get_collection(obj) collection.clear_with_event() @@ -453,10 +448,8 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): state = obj._state # if an instance-wide "trigger" was set, call that - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() + if state.trigger: + state.call_trigger() old = self.get(obj) old_collection = self.get_collection(obj, old) @@ -466,7 +459,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): collection=new_collection) obj.__dict__[self.key] = user_data - state['modified'] = True + state.modified = True # mark all the old elements as detached from the parent if old_collection: @@ -477,17 +470,16 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): """Set an attribute value on the given instance and 'commit' it.""" state = obj._state - orig = state.get('original', None) collection, user_data = self._build_collection(obj) self._load_collection(obj, value or [], emit_events=False, collection=collection) value = user_data - if orig is not None: - self.commit_to_state(orig, obj, value) + if state.committed_state is not None: + self.commit_to_state(state, obj, value) # remove per-instance callable, if any - state.pop(('callable', self), None) + state.callables.pop(self, None) obj.__dict__[self.key] = value return value @@ -543,38 +535,57 @@ class GenericBackrefExtension(interfaces.AttributeExtension): def remove(self, obj, child, initiator): getattr(child.__class__, self.key).remove(child, obj, initiator) -class CommittedState(object): - """Store the original state of an object when the ``commit()` - method on the attribute manager is called. - """ - - - def __init__(self, manager, obj): - self.data = {} +class InstanceState(object): + """tracks state information at the instance level.""" + + def __init__(self, obj): + self.committed_state = None + self.modified = False + self.trigger = None + self.callables = {} + self.parents = {} + + def __getstate__(self): + return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified} + + def __setstate__(self, state): + self.committed_state = state['committed_state'] + self.parents = state['parents'] + self.modified = state['modified'] + self.callables = {} + self.trigger = None + + def call_trigger(self): + trig = self.trigger + self.trigger = None + trig() + + def commit(self, manager, obj): + self.committed_state = {} + self.modified = False for attr in manager.managed_attributes(obj.__class__): attr.commit_to_state(self, obj) def rollback(self, manager, obj): - for attr in manager.managed_attributes(obj.__class__): - if attr.key in self.data: - if not hasattr(attr, 'get_collection'): - obj.__dict__[attr.key] = self.data[attr.key] + if not self.committed_state: + manager._clear(obj) + else: + for attr in manager.managed_attributes(obj.__class__): + if attr.key in self.committed_state: + if not hasattr(attr, 'get_collection'): + obj.__dict__[attr.key] = self.committed_state[attr.key] + else: + collection = attr.get_collection(obj) + collection.clear_without_event() + for item in self.committed_state[attr.key]: + collection.append_without_event(item) else: - collection = attr.get_collection(obj) - collection.clear_without_event() - for item in self.data[attr.key]: - collection.append_without_event(item) - else: - if attr.key in obj.__dict__: - del obj.__dict__[attr.key] - - def __repr__(self): - return "CommittedState: %s" % repr(self.data) + if attr.key in obj.__dict__: + del obj.__dict__[attr.key] class AttributeHistory(object): """Calculate the *history* of a particular attribute on a - particular instance, based on the ``CommittedState`` associated - with the instance, if any. + particular instance. """ def __init__(self, attr, obj, current, passive=False): @@ -583,9 +594,8 @@ class AttributeHistory(object): # get the "original" value. if a lazy load was fired when we got # the 'current' value, this "original" was also populated just # now as well (therefore we have to get it second) - orig = obj._state.get('original', None) - if orig is not None: - original = orig.data.get(attr.key) + if obj._state.committed_state: + original = obj._state.committed_state.get(attr.key, None) else: original = None @@ -652,11 +662,7 @@ class AttributeManager(object): """ for o in obj: - orig = o._state.get('original') - if orig is not None: - orig.rollback(self, o) - else: - self._clear(o) + o._state.rollback(self, o) def _clear(self, obj): for attr in self.managed_attributes(obj.__class__): @@ -664,19 +670,12 @@ class AttributeManager(object): del obj.__dict__[attr.key] except KeyError: pass - + def commit(self, *obj): - """Create a ``CommittedState`` instance for each object in the given list, representing - its *unchanged* state, and associates it with the instance. - - ``AttributeHistory`` objects will indicate the modified state of - instance attributes as compared to its value in this - ``CommittedState`` object. - """ + """Establish the "committed state" for each object in the given list.""" for o in obj: - o._state['original'] = CommittedState(self, o) - o._state['modified'] = False + o._state.commit(self, o) def managed_attributes(self, class_): """Return a list of all ``InstrumentedAttribute`` objects @@ -706,7 +705,7 @@ class AttributeManager(object): for attr in self.managed_attributes(object.__class__): if attr.check_mutable_modified(object): return True - return object._state.get('modified', False) + return object._state.modified def get_history(self, obj, key, **kwargs): """Return a new ``AttributeHistory`` object for the given @@ -743,12 +742,10 @@ class AttributeManager(object): removed. """ + s = obj._state self._clear(obj) - try: - del obj._state['original'] - except KeyError: - pass - obj._state['trigger'] = callable + s.committed_state = None + s.trigger = callable def untrigger_history(self, obj): """Remove a trigger function set by trigger_history. @@ -756,14 +753,14 @@ class AttributeManager(object): Does not restore the previous state of the object. """ - del obj._state['trigger'] + obj._state.trigger = None def has_trigger(self, obj): """Return True if the given object has a trigger function set by ``trigger_history()``. """ - return 'trigger' in obj._state + return obj._state.trigger is not None def reset_instance_attribute(self, obj, key): """Remove any per-instance callable functions corresponding to @@ -774,16 +771,6 @@ class AttributeManager(object): attr = getattr(obj.__class__, key) attr.reset(obj) - def reset_class_managed(self, class_): - """Remove all ``InstrumentedAttribute`` property objects from - the given class. - """ - - for attr in self.noninherited_managed_attributes(class_): - delattr(class_, attr.key) - self._inherited_attribute_cache.pop(class_,None) - self._noninherited_attribute_cache.pop(class_,None) - def is_class_managed(self, class_, key): """Return True if the given `key` correponds to an instrumented property on the given class. @@ -826,7 +813,71 @@ class AttributeManager(object): return getattr(obj_or_cls, key) else: return getattr(obj_or_cls.__class__, key) + + def manage(self, obj): + if not hasattr(obj, '_state'): + obj._state = InstanceState(obj) + + def new_instance(self, class_): + """create a new instance of class_ without its __init__() method being called.""" + + s = class_.__new__(class_) + s._state = InstanceState(s) + return s + + def register_class(self, class_, extra_init=None, on_exception=None): + """decorate the constructor of the given class to establish attribute + management on new instances.""" + oldinit = None + doinit = False + + def init(instance, *args, **kwargs): + instance._state = InstanceState(instance) + + if extra_init: + extra_init(class_, oldinit, instance, args, kwargs) + + if doinit: + try: + oldinit(instance, *args, **kwargs) + except: + if on_exception: + on_exception(class_, oldinit, instance, args, kwargs) + raise + + # override oldinit + oldinit = class_.__init__ + if oldinit is None or not hasattr(oldinit, '_oldinit'): + init._oldinit = oldinit + class_.__init__ = init + # if oldinit is already one of our 'init' methods, replace it + elif hasattr(oldinit, '_oldinit'): + init._oldinit = oldinit._oldinit + class_.__init = init + oldinit = oldinit._oldinit + + if oldinit is not None: + doinit = oldinit is not object.__init__ + try: + init.__name__ = oldinit.__name__ + init.__doc__ = oldinit.__doc__ + except: + # cant set __name__ in py 2.3 ! + pass + + def unregister_class(self, class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + + for attr in self.noninherited_managed_attributes(class_): + delattr(class_, attr.key) + self._inherited_attribute_cache.pop(class_,None) + self._noninherited_attribute_cache.pop(class_,None) + def register_attribute(self, class_, key, uselist, callable_=None, **kwargs): """Register an attribute at the class level to be instrumented for all instances of the class. @@ -837,13 +888,6 @@ class AttributeManager(object): self._inherited_attribute_cache.pop(class_, None) self._noninherited_attribute_cache.pop(class_, None) - if not hasattr(class_, '_state'): - def _get_state(self): - if not hasattr(self, '_sa_attr_state'): - self._sa_attr_state = {} - return self._sa_attr_state - class_._state = property(_get_state) - typecallable = kwargs.pop('typecallable', None) if isinstance(typecallable, InstrumentedAttribute): typecallable = None diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 1d4b5f6c9..aa5105150 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -34,7 +34,7 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute): old_collection = self.get(obj).assign(value) # TODO: emit events ??? - state['modified'] = True + state.modified = True def delete(self, *args, **kwargs): raise NotImplementedError() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 960282255..5d495d7a9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -198,14 +198,9 @@ class Mapper(object): def dispose(self): # disaable any attribute-based compilation self.__props_init = True - attribute_manager.reset_class_managed(self.class_) if hasattr(self.class_, 'c'): del self.class_.c - if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'): - if self.class_.__init__._oldinit is not None: - self.class_.__init__ = self.class_.__init__._oldinit - else: - delattr(self.class_, '__init__') + attribute_manager.unregister_class(self.class_) def compile(self): """Compile this mapper into its final internal format. @@ -664,34 +659,14 @@ class Mapper(object): if not self.non_primary and (self.class_key in mapper_registry): raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper, or to create a new primary mapper, remove this mapper first via sqlalchemy.orm.clear_mapper(mapper), or preferably sqlalchemy.orm.clear_mappers() to clear all mappers." % (self.class_, self.entity_name)) - attribute_manager.reset_class_managed(self.class_) - - oldinit = self.class_.__init__ - doinit = oldinit is not None and oldinit is not object.__init__ - - def init(instance, *args, **kwargs): + def extra_init(class_, oldinit, instance, args, kwargs): self.compile() - self.extension.init_instance(self, self.class_, oldinit, instance, args, kwargs) - - if doinit: - try: - oldinit(instance, *args, **kwargs) - except: - # call init_failed but suppress exceptions into warnings so that original __init__ - # exception is raised - util.warn_exception(self.extension.init_failed, self, self.class_, oldinit, instance, args, kwargs) - raise - - # override oldinit, ensuring that its not already a Mapper-decorated init method - if oldinit is None or not hasattr(oldinit, '_oldinit'): - try: - init.__name__ = oldinit.__name__ - init.__doc__ = oldinit.__doc__ - except: - # cant set __name__ in py 2.3 ! - pass - init._oldinit = oldinit - self.class_.__init__ = init + self.extension.init_instance(self, class_, oldinit, instance, args, kwargs) + + def on_exception(class_, oldinit, instance, args, kwargs): + util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs) + + attribute_manager.register_class(self.class_, extra_init=extra_init, on_exception=on_exception) _COMPILE_MUTEX.acquire() try: @@ -1436,7 +1411,7 @@ class Mapper(object): # plugin point instance = extension.create_instance(self, context, row, self.class_) if instance is EXT_CONTINUE: - instance = self._create_instance(context.session) + instance = attribute_manager.new_instance(self.class_) instance._entity_name = self.entity_name if self.__should_log_debug: self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) @@ -1459,12 +1434,6 @@ class Mapper(object): return instance - def _create_instance(self, session): - obj = self.class_.__new__(self.class_) - obj._entity_name = self.entity_name - - return obj - def _deferred_inheritance_condition(self, needs_tables): cond = self.inherit_condition diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ebd4bd3d3..b616570ab 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -847,7 +847,7 @@ class Session(object): try: key = getattr(object, '_instance_key', None) if key is None: - merged = mapper.class_.__new__(mapper.class_) + merged = attribute_manager.new_instance(mapper.class_) else: if key in self.identity_map: merged = self.identity_map[key] @@ -940,16 +940,9 @@ class Session(object): "or is already persistent in a " "different Session" % repr(obj)) else: - m = _class_mapper(obj.__class__, entity_name=kwargs.get('entity_name', None)) - - # this would be a nice exception to raise...however this is incompatible with a contextual - # session which puts all objects into the session upon construction. - #if m._is_orphan(object): - # raise exceptions.InvalidRequestError("Instance '%s' is an orphan, " - # "and must be attached to a parent " - # "object to be saved" % (repr(object))) - - m._assign_entity_name(obj) + # TODO: consolidate the steps here + attribute_manager.manage(obj) + obj._entity_name = kwargs.get('entity_name', None) self._attach(obj) self.uow.register_new(obj) diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 8ca2d1b8e..6314656b9 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -5,14 +5,17 @@ from sqlalchemy.orm.collections import collection from sqlalchemy import exceptions from testlib import * +# these test classes defined at the module +# level to support pickling class MyTest(object):pass class MyTest2(object):pass - + class AttributesTest(PersistTest): """tests for the attributes.py module, which deals with tracking attribute changes on an object.""" - def testbasic(self): + def test_basic(self): class User(object):pass manager = attributes.AttributeManager() + manager.register_class(User) manager.register_attribute(User, 'user_id', uselist = False) manager.register_attribute(User, 'user_name', uselist = False) manager.register_attribute(User, 'email_address', uselist = False) @@ -39,8 +42,11 @@ class AttributesTest(PersistTest): print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - def testpickleness(self): + def test_pickleness(self): + manager = attributes.AttributeManager() + manager.register_class(MyTest) + manager.register_class(MyTest2) manager.register_attribute(MyTest, 'user_id', uselist = False) manager.register_attribute(MyTest, 'user_name', uselist = False) manager.register_attribute(MyTest, 'email_address', uselist = False) @@ -97,10 +103,12 @@ class AttributesTest(PersistTest): self.assert_(o4.mt2[0].a == 'abcde') self.assert_(o4.mt2[0].b is None) - def testlist(self): + def test_list(self): class User(object):pass class Address(object):pass manager = attributes.AttributeManager() + manager.register_class(User) + manager.register_class(Address) manager.register_attribute(User, 'user_id', uselist = False) manager.register_attribute(User, 'user_name', uselist = False) manager.register_attribute(User, 'addresses', uselist = True) @@ -138,10 +146,12 @@ class AttributesTest(PersistTest): self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1) - def testbackref(self): + def test_backref(self): class Student(object):pass class Course(object):pass manager = attributes.AttributeManager() + manager.register_class(Student) + manager.register_class(Course) manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students')) manager.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses')) @@ -166,7 +176,9 @@ class AttributesTest(PersistTest): self.assert_(c.students == [s2,s3]) class Post(object):pass class Blog(object):pass - + + manager.register_class(Post) + manager.register_class(Blog) manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True) manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True) b = Blog() @@ -190,6 +202,8 @@ class AttributesTest(PersistTest): class Port(object):pass class Jack(object):pass + manager.register_class(Port) + manager.register_class(Jack) manager.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port')) manager.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack')) p = Port() @@ -201,13 +215,15 @@ class AttributesTest(PersistTest): j.port = None self.assert_(p.jack is None) - def testlazytrackparent(self): + def test_lazytrackparent(self): """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" manager = attributes.AttributeManager() class Post(object):pass class Blog(object):pass - + manager.register_class(Post) + manager.register_class(Blog) + # set up instrumented attributes with backrefs manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True) manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True) @@ -234,12 +250,14 @@ class AttributesTest(PersistTest): assert getattr(Blog, 'posts').hasparent(p2) assert getattr(Post, 'blog').hasparent(b2) - def testinheritance(self): + def test_inheritance(self): """tests that attributes are polymorphic""" class Foo(object):pass class Bar(Foo):pass manager = attributes.AttributeManager() + manager.register_class(Foo) + manager.register_class(Bar) def func1(): print "func1" @@ -261,12 +279,14 @@ class AttributesTest(PersistTest): assert x.element2 == 'this is the shared attr' assert y.element2 == 'this is the shared attr' - def testinheritance2(self): + def test_inheritance2(self): """test that the attribute manager can properly traverse the managed attributes of an object, if the object is of a descendant class with managed attributes in the parent class""" class Foo(object):pass class Bar(Foo):pass manager = attributes.AttributeManager() + manager.register_class(Foo) + manager.register_class(Bar) manager.register_attribute(Foo, 'element', uselist=False) x = Bar() x.element = 'this is the element' @@ -277,7 +297,7 @@ class AttributesTest(PersistTest): assert hist.added_items() == [] assert hist.unchanged_items() == ['this is the element'] - def testlazyhistory(self): + def test_lazyhistory(self): """tests that history functions work with lazy-loading attributes""" class Foo(object):pass class Bar(object): @@ -287,6 +307,8 @@ class AttributesTest(PersistTest): return "Bar: id %d" % self.id manager = attributes.AttributeManager() + manager.register_class(Foo) + manager.register_class(Bar) def func1(): return "this is func 1" @@ -305,11 +327,13 @@ class AttributesTest(PersistTest): print h.unchanged_items() - def testparenttrack(self): + def test_parenttrack(self): class Foo(object):pass class Bar(object):pass manager = attributes.AttributeManager() + manager.register_class(Foo) + manager.register_class(Bar) manager.register_attribute(Foo, 'element', uselist=False, trackparent=True) manager.register_attribute(Bar, 'element', uselist=False, trackparent=True) @@ -330,10 +354,11 @@ class AttributesTest(PersistTest): b2.element = None assert not getattr(Bar, 'element').hasparent(f2) - def testmutablescalars(self): + def test_mutablescalars(self): """test detection of changes on mutable scalar items""" class Foo(object):pass manager = attributes.AttributeManager() + manager.register_class(Foo) manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True) x = Foo() x.element = ['one', 'two', 'three'] @@ -341,8 +366,9 @@ class AttributesTest(PersistTest): x.element[1] = 'five' assert manager.is_modified(x) - manager.reset_class_managed(Foo) + manager.unregister_class(Foo) manager = attributes.AttributeManager() + manager.register_class(Foo) manager.register_attribute(Foo, 'element', uselist=False) x = Foo() x.element = ['one', 'two', 'three'] @@ -350,7 +376,7 @@ class AttributesTest(PersistTest): x.element[1] = 'five' assert not manager.is_modified(x) - def testdescriptorattributes(self): + def test_descriptorattributes(self): """changeset: 1633 broke ability to use ORM to map classes with unusual descriptor attributes (for example, classes that inherit from ones implementing zope.interface.Interface). @@ -363,11 +389,12 @@ class AttributesTest(PersistTest): A = des() manager = attributes.AttributeManager() - manager.reset_class_managed(Foo) + manager.unregister_class(Foo) - def testcollectionclasses(self): + def test_collectionclasses(self): manager = attributes.AttributeManager() class Foo(object):pass + manager.register_class(Foo) manager.register_attribute(Foo, "collection", uselist=True, typecallable=set) assert isinstance(Foo().collection, set) diff --git a/test/orm/collection.py b/test/orm/collection.py index 0cc8cf7e0..9d5ae7ab9 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -36,6 +36,7 @@ class Entity(object): return str((id(self), self.a, self.b, self.c)) manager = attributes.AttributeManager() +manager.register_class(Entity) _id = 1 def entity_maker(): @@ -55,6 +56,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -92,6 +94,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -233,6 +236,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -341,6 +345,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -473,6 +478,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -577,6 +583,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -694,6 +701,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -868,6 +876,7 @@ class CollectionsTest(PersistTest): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -1001,6 +1010,7 @@ class CollectionsTest(PersistTest): class Foo(object): pass canary = Canary() + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=Custom) @@ -1070,6 +1080,7 @@ class CollectionsTest(PersistTest): canary = Canary() creator = entity_maker + manager.register_class(Foo) manager.register_attribute(Foo, 'attr', True, extension=canary) obj = Foo() diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py index ad438c1fa..38696e85b 100644 --- a/test/perf/masseagerload.py +++ b/test/perf/masseagerload.py @@ -35,7 +35,7 @@ def load(): #print l subitems.insert().execute(*l) -@profiling.profiled('masseagerload', always=True) +@profiling.profiled('masseagerload', always=True, sort=['cumulative']) def masseagerload(session): query = session.query(Item) l = query.select() |