summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-09-05 17:25:32 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-09-05 17:25:32 +0000
commit9717c7170917856e07331796d8439d30ac93006e (patch)
tree8e4da9bec57d973a3ba00496e807272acae5ba37 /lib/sqlalchemy
parent43a927f6e03e1c87c4ea9f1ffe6dce6d794ccdda (diff)
downloadsqlalchemy-9717c7170917856e07331796d8439d30ac93006e.tar.gz
merged current entity_management brach r3457-r3462. cleans up
'_state' mamangement in attributes, moves __init__() instrumntation into attributes.py, and reduces method call overhead by removing '_state' property. future enhancements may include _state maintaining a weakref to the instance and a strong ref to its __dict__ so that garbage-collected instances can get added to 'dirty', when weak-referenced identity map is used.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/attributes.py248
-rw-r--r--lib/sqlalchemy/orm/dynamic.py2
-rw-r--r--lib/sqlalchemy/orm/mapper.py49
-rw-r--r--lib/sqlalchemy/orm/session.py15
4 files changed, 160 insertions, 154 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index f369c5396..7290e2ac2 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -84,13 +84,13 @@ class InstrumentedAttribute(interfaces.PropComparator):
return self.get(obj)
def commit_to_state(self, state, obj, value=NO_VALUE):
- """commit the a copy of thte value of 'obj' to the given CommittedState"""
-
+ """commit the object's current state to its 'committed' state."""
+
if value is NO_VALUE:
if self.key in obj.__dict__:
value = obj.__dict__[self.key]
if value is not NO_VALUE:
- state.data[self.key] = self.copy(value)
+ state.committed_state[self.key] = self.copy(value)
def clause_element(self):
return self.comparator.clause_element()
@@ -119,7 +119,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
will also not have a `hasparent` flag.
"""
- return item._state.get(('hasparent', id(self)), optimistic)
+ return item._state.parents.get(id(self), optimistic)
def sethasparent(self, item, value):
"""Set a boolean flag on the given item corresponding to
@@ -127,7 +127,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
attribute represented by this ``InstrumentedAttribute``.
"""
- item._state[('hasparent', id(self))] = value
+ item._state.parents[id(self)] = value
def get_history(self, obj, passive=False):
"""Return a new ``AttributeHistory`` object for the given object/this attribute's key.
@@ -165,11 +165,11 @@ class InstrumentedAttribute(interfaces.PropComparator):
if callable_ is None:
self.initialize(obj)
else:
- obj._state[('callable', self)] = callable_
+ obj._state.callables[self] = callable_
def _get_callable(self, obj):
- if ('callable', self) in obj._state:
- return obj._state[('callable', self)]
+ if self in obj._state.callables:
+ return obj._state.callables[self]
elif self.callable_ is not None:
return self.callable_(obj)
else:
@@ -183,7 +183,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
"""
try:
- del obj._state[('callable', self)]
+ del obj._state.callables[self]
except KeyError:
pass
self.clear(obj)
@@ -223,10 +223,8 @@ class InstrumentedAttribute(interfaces.PropComparator):
state = obj._state
# if an instance-wide "trigger" was set, call that
# and start again
- if 'trigger' in state:
- trig = state['trigger']
- del state['trigger']
- trig()
+ if state.trigger:
+ state.call_trigger()
return self.get(obj, passive=passive)
callable_ = self._get_callable(obj)
@@ -265,11 +263,10 @@ class InstrumentedAttribute(interfaces.PropComparator):
"""
state = obj._state
- orig = state.get('original', None)
- if orig is not None:
- self.commit_to_state(orig, obj, value)
+ if state.committed_state is not None:
+ self.commit_to_state(state, obj, value)
# remove per-instance callable, if any
- state.pop(('callable', self), None)
+ state.callables.pop(self, None)
obj.__dict__[self.key] = value
return value
@@ -278,21 +275,21 @@ class InstrumentedAttribute(interfaces.PropComparator):
return value
def fire_append_event(self, obj, value, initiator):
- obj._state['modified'] = True
+ obj._state.modified = True
if self.trackparent and value is not None:
self.sethasparent(value, True)
for ext in self.extensions:
ext.append(obj, value, initiator or self)
def fire_remove_event(self, obj, value, initiator):
- obj._state['modified'] = True
+ obj._state.modified = True
if self.trackparent and value is not None:
self.sethasparent(value, False)
for ext in self.extensions:
ext.remove(obj, value, initiator or self)
def fire_replace_event(self, obj, value, previous, initiator):
- obj._state['modified'] = True
+ obj._state.modified = True
if self.trackparent:
if value is not None:
self.sethasparent(value, True)
@@ -334,7 +331,7 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
if self.mutable_scalars:
h = self.get_history(obj, passive=True)
if h is not None and h.is_modified():
- obj._state['modified'] = True
+ obj._state.modified = True
return True
else:
return False
@@ -354,10 +351,8 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
state = obj._state
# if an instance-wide "trigger" was set, call that
- if 'trigger' in state:
- trig = state['trigger']
- del state['trigger']
- trig()
+ if state.trigger:
+ state.call_trigger()
old = self.get(obj)
obj.__dict__[self.key] = value
@@ -415,7 +410,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
if self.key not in obj.__dict__:
return
- obj._state['modified'] = True
+ obj._state.modified = True
collection = self.get_collection(obj)
collection.clear_with_event()
@@ -453,10 +448,8 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
state = obj._state
# if an instance-wide "trigger" was set, call that
- if 'trigger' in state:
- trig = state['trigger']
- del state['trigger']
- trig()
+ if state.trigger:
+ state.call_trigger()
old = self.get(obj)
old_collection = self.get_collection(obj, old)
@@ -466,7 +459,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
collection=new_collection)
obj.__dict__[self.key] = user_data
- state['modified'] = True
+ state.modified = True
# mark all the old elements as detached from the parent
if old_collection:
@@ -477,17 +470,16 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
"""Set an attribute value on the given instance and 'commit' it."""
state = obj._state
- orig = state.get('original', None)
collection, user_data = self._build_collection(obj)
self._load_collection(obj, value or [], emit_events=False,
collection=collection)
value = user_data
- if orig is not None:
- self.commit_to_state(orig, obj, value)
+ if state.committed_state is not None:
+ self.commit_to_state(state, obj, value)
# remove per-instance callable, if any
- state.pop(('callable', self), None)
+ state.callables.pop(self, None)
obj.__dict__[self.key] = value
return value
@@ -543,38 +535,57 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
def remove(self, obj, child, initiator):
getattr(child.__class__, self.key).remove(child, obj, initiator)
-class CommittedState(object):
- """Store the original state of an object when the ``commit()`
- method on the attribute manager is called.
- """
-
-
- def __init__(self, manager, obj):
- self.data = {}
+class InstanceState(object):
+ """tracks state information at the instance level."""
+
+ def __init__(self, obj):
+ self.committed_state = None
+ self.modified = False
+ self.trigger = None
+ self.callables = {}
+ self.parents = {}
+
+ def __getstate__(self):
+ return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified}
+
+ def __setstate__(self, state):
+ self.committed_state = state['committed_state']
+ self.parents = state['parents']
+ self.modified = state['modified']
+ self.callables = {}
+ self.trigger = None
+
+ def call_trigger(self):
+ trig = self.trigger
+ self.trigger = None
+ trig()
+
+ def commit(self, manager, obj):
+ self.committed_state = {}
+ self.modified = False
for attr in manager.managed_attributes(obj.__class__):
attr.commit_to_state(self, obj)
def rollback(self, manager, obj):
- for attr in manager.managed_attributes(obj.__class__):
- if attr.key in self.data:
- if not hasattr(attr, 'get_collection'):
- obj.__dict__[attr.key] = self.data[attr.key]
+ if not self.committed_state:
+ manager._clear(obj)
+ else:
+ for attr in manager.managed_attributes(obj.__class__):
+ if attr.key in self.committed_state:
+ if not hasattr(attr, 'get_collection'):
+ obj.__dict__[attr.key] = self.committed_state[attr.key]
+ else:
+ collection = attr.get_collection(obj)
+ collection.clear_without_event()
+ for item in self.committed_state[attr.key]:
+ collection.append_without_event(item)
else:
- collection = attr.get_collection(obj)
- collection.clear_without_event()
- for item in self.data[attr.key]:
- collection.append_without_event(item)
- else:
- if attr.key in obj.__dict__:
- del obj.__dict__[attr.key]
-
- def __repr__(self):
- return "CommittedState: %s" % repr(self.data)
+ if attr.key in obj.__dict__:
+ del obj.__dict__[attr.key]
class AttributeHistory(object):
"""Calculate the *history* of a particular attribute on a
- particular instance, based on the ``CommittedState`` associated
- with the instance, if any.
+ particular instance.
"""
def __init__(self, attr, obj, current, passive=False):
@@ -583,9 +594,8 @@ class AttributeHistory(object):
# get the "original" value. if a lazy load was fired when we got
# the 'current' value, this "original" was also populated just
# now as well (therefore we have to get it second)
- orig = obj._state.get('original', None)
- if orig is not None:
- original = orig.data.get(attr.key)
+ if obj._state.committed_state:
+ original = obj._state.committed_state.get(attr.key, None)
else:
original = None
@@ -652,11 +662,7 @@ class AttributeManager(object):
"""
for o in obj:
- orig = o._state.get('original')
- if orig is not None:
- orig.rollback(self, o)
- else:
- self._clear(o)
+ o._state.rollback(self, o)
def _clear(self, obj):
for attr in self.managed_attributes(obj.__class__):
@@ -664,19 +670,12 @@ class AttributeManager(object):
del obj.__dict__[attr.key]
except KeyError:
pass
-
+
def commit(self, *obj):
- """Create a ``CommittedState`` instance for each object in the given list, representing
- its *unchanged* state, and associates it with the instance.
-
- ``AttributeHistory`` objects will indicate the modified state of
- instance attributes as compared to its value in this
- ``CommittedState`` object.
- """
+ """Establish the "committed state" for each object in the given list."""
for o in obj:
- o._state['original'] = CommittedState(self, o)
- o._state['modified'] = False
+ o._state.commit(self, o)
def managed_attributes(self, class_):
"""Return a list of all ``InstrumentedAttribute`` objects
@@ -706,7 +705,7 @@ class AttributeManager(object):
for attr in self.managed_attributes(object.__class__):
if attr.check_mutable_modified(object):
return True
- return object._state.get('modified', False)
+ return object._state.modified
def get_history(self, obj, key, **kwargs):
"""Return a new ``AttributeHistory`` object for the given
@@ -743,12 +742,10 @@ class AttributeManager(object):
removed.
"""
+ s = obj._state
self._clear(obj)
- try:
- del obj._state['original']
- except KeyError:
- pass
- obj._state['trigger'] = callable
+ s.committed_state = None
+ s.trigger = callable
def untrigger_history(self, obj):
"""Remove a trigger function set by trigger_history.
@@ -756,14 +753,14 @@ class AttributeManager(object):
Does not restore the previous state of the object.
"""
- del obj._state['trigger']
+ obj._state.trigger = None
def has_trigger(self, obj):
"""Return True if the given object has a trigger function set
by ``trigger_history()``.
"""
- return 'trigger' in obj._state
+ return obj._state.trigger is not None
def reset_instance_attribute(self, obj, key):
"""Remove any per-instance callable functions corresponding to
@@ -774,16 +771,6 @@ class AttributeManager(object):
attr = getattr(obj.__class__, key)
attr.reset(obj)
- def reset_class_managed(self, class_):
- """Remove all ``InstrumentedAttribute`` property objects from
- the given class.
- """
-
- for attr in self.noninherited_managed_attributes(class_):
- delattr(class_, attr.key)
- self._inherited_attribute_cache.pop(class_,None)
- self._noninherited_attribute_cache.pop(class_,None)
-
def is_class_managed(self, class_, key):
"""Return True if the given `key` correponds to an
instrumented property on the given class.
@@ -826,7 +813,71 @@ class AttributeManager(object):
return getattr(obj_or_cls, key)
else:
return getattr(obj_or_cls.__class__, key)
+
+ def manage(self, obj):
+ if not hasattr(obj, '_state'):
+ obj._state = InstanceState(obj)
+
+ def new_instance(self, class_):
+ """create a new instance of class_ without its __init__() method being called."""
+
+ s = class_.__new__(class_)
+ s._state = InstanceState(s)
+ return s
+
+ def register_class(self, class_, extra_init=None, on_exception=None):
+ """decorate the constructor of the given class to establish attribute
+ management on new instances."""
+ oldinit = None
+ doinit = False
+
+ def init(instance, *args, **kwargs):
+ instance._state = InstanceState(instance)
+
+ if extra_init:
+ extra_init(class_, oldinit, instance, args, kwargs)
+
+ if doinit:
+ try:
+ oldinit(instance, *args, **kwargs)
+ except:
+ if on_exception:
+ on_exception(class_, oldinit, instance, args, kwargs)
+ raise
+
+ # override oldinit
+ oldinit = class_.__init__
+ if oldinit is None or not hasattr(oldinit, '_oldinit'):
+ init._oldinit = oldinit
+ class_.__init__ = init
+ # if oldinit is already one of our 'init' methods, replace it
+ elif hasattr(oldinit, '_oldinit'):
+ init._oldinit = oldinit._oldinit
+ class_.__init = init
+ oldinit = oldinit._oldinit
+
+ if oldinit is not None:
+ doinit = oldinit is not object.__init__
+ try:
+ init.__name__ = oldinit.__name__
+ init.__doc__ = oldinit.__doc__
+ except:
+ # cant set __name__ in py 2.3 !
+ pass
+
+ def unregister_class(self, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+
+ for attr in self.noninherited_managed_attributes(class_):
+ delattr(class_, attr.key)
+ self._inherited_attribute_cache.pop(class_,None)
+ self._noninherited_attribute_cache.pop(class_,None)
+
def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
"""Register an attribute at the class level to be instrumented
for all instances of the class.
@@ -837,13 +888,6 @@ class AttributeManager(object):
self._inherited_attribute_cache.pop(class_, None)
self._noninherited_attribute_cache.pop(class_, None)
- if not hasattr(class_, '_state'):
- def _get_state(self):
- if not hasattr(self, '_sa_attr_state'):
- self._sa_attr_state = {}
- return self._sa_attr_state
- class_._state = property(_get_state)
-
typecallable = kwargs.pop('typecallable', None)
if isinstance(typecallable, InstrumentedAttribute):
typecallable = None
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index 1d4b5f6c9..aa5105150 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -34,7 +34,7 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
old_collection = self.get(obj).assign(value)
# TODO: emit events ???
- state['modified'] = True
+ state.modified = True
def delete(self, *args, **kwargs):
raise NotImplementedError()
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 960282255..5d495d7a9 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -198,14 +198,9 @@ class Mapper(object):
def dispose(self):
# disaable any attribute-based compilation
self.__props_init = True
- attribute_manager.reset_class_managed(self.class_)
if hasattr(self.class_, 'c'):
del self.class_.c
- if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'):
- if self.class_.__init__._oldinit is not None:
- self.class_.__init__ = self.class_.__init__._oldinit
- else:
- delattr(self.class_, '__init__')
+ attribute_manager.unregister_class(self.class_)
def compile(self):
"""Compile this mapper into its final internal format.
@@ -664,34 +659,14 @@ class Mapper(object):
if not self.non_primary and (self.class_key in mapper_registry):
raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper, or to create a new primary mapper, remove this mapper first via sqlalchemy.orm.clear_mapper(mapper), or preferably sqlalchemy.orm.clear_mappers() to clear all mappers." % (self.class_, self.entity_name))
- attribute_manager.reset_class_managed(self.class_)
-
- oldinit = self.class_.__init__
- doinit = oldinit is not None and oldinit is not object.__init__
-
- def init(instance, *args, **kwargs):
+ def extra_init(class_, oldinit, instance, args, kwargs):
self.compile()
- self.extension.init_instance(self, self.class_, oldinit, instance, args, kwargs)
-
- if doinit:
- try:
- oldinit(instance, *args, **kwargs)
- except:
- # call init_failed but suppress exceptions into warnings so that original __init__
- # exception is raised
- util.warn_exception(self.extension.init_failed, self, self.class_, oldinit, instance, args, kwargs)
- raise
-
- # override oldinit, ensuring that its not already a Mapper-decorated init method
- if oldinit is None or not hasattr(oldinit, '_oldinit'):
- try:
- init.__name__ = oldinit.__name__
- init.__doc__ = oldinit.__doc__
- except:
- # cant set __name__ in py 2.3 !
- pass
- init._oldinit = oldinit
- self.class_.__init__ = init
+ self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+
+ def on_exception(class_, oldinit, instance, args, kwargs):
+ util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
+
+ attribute_manager.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
_COMPILE_MUTEX.acquire()
try:
@@ -1436,7 +1411,7 @@ class Mapper(object):
# plugin point
instance = extension.create_instance(self, context, row, self.class_)
if instance is EXT_CONTINUE:
- instance = self._create_instance(context.session)
+ instance = attribute_manager.new_instance(self.class_)
instance._entity_name = self.entity_name
if self.__should_log_debug:
self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
@@ -1459,12 +1434,6 @@ class Mapper(object):
return instance
- def _create_instance(self, session):
- obj = self.class_.__new__(self.class_)
- obj._entity_name = self.entity_name
-
- return obj
-
def _deferred_inheritance_condition(self, needs_tables):
cond = self.inherit_condition
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index ebd4bd3d3..b616570ab 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -847,7 +847,7 @@ class Session(object):
try:
key = getattr(object, '_instance_key', None)
if key is None:
- merged = mapper.class_.__new__(mapper.class_)
+ merged = attribute_manager.new_instance(mapper.class_)
else:
if key in self.identity_map:
merged = self.identity_map[key]
@@ -940,16 +940,9 @@ class Session(object):
"or is already persistent in a "
"different Session" % repr(obj))
else:
- m = _class_mapper(obj.__class__, entity_name=kwargs.get('entity_name', None))
-
- # this would be a nice exception to raise...however this is incompatible with a contextual
- # session which puts all objects into the session upon construction.
- #if m._is_orphan(object):
- # raise exceptions.InvalidRequestError("Instance '%s' is an orphan, "
- # "and must be attached to a parent "
- # "object to be saved" % (repr(object)))
-
- m._assign_entity_name(obj)
+ # TODO: consolidate the steps here
+ attribute_manager.manage(obj)
+ obj._entity_name = kwargs.get('entity_name', None)
self._attach(obj)
self.uow.register_new(obj)