summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2005-12-06 03:32:24 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2005-12-06 03:32:24 +0000
commit19aae75e6aea39e59357f334039baf6861647e40 (patch)
treef4d11218c23f1fb3b626f9608c6d9ff6215972bd /lib/sqlalchemy
parent2e1bda393c6137949326ac6a88e2e2ef41d83449 (diff)
downloadsqlalchemy-19aae75e6aea39e59357f334039baf6861647e40.tar.gz
first take at backreference handlers
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/attributes.py46
-rw-r--r--lib/sqlalchemy/util.py12
2 files changed, 50 insertions, 8 deletions
diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py
index aa768532d..9bdfcab6a 100644
--- a/lib/sqlalchemy/attributes.py
+++ b/lib/sqlalchemy/attributes.py
@@ -44,10 +44,11 @@ class PropHistory(object):
"""manages the value of a particular scalar attribute on a particular object instance."""
# make our own NONE to distinguish from "None"
NONE = object()
- def __init__(self, obj, key, **kwargs):
+ def __init__(self, obj, key, backrefmanager=None, **kwargs):
self.obj = obj
self.key = key
self.orig = PropHistory.NONE
+ self.backrefmanager = backrefmanager
def gethistory(self, *args, **kwargs):
return self
def history_contains(self, obj):
@@ -61,9 +62,13 @@ class PropHistory(object):
raise ("assigning a list to scalar property '%s' on '%s' instance %d" % (self.key, self.obj.__class__.__name__, id(self.obj)))
self.orig = self.obj.__dict__.get(self.key, None)
self.obj.__dict__[self.key] = value
+ if self.backrefmanager is not None and self.orig is not value:
+ self.backrefmanager.set(self.obj, value, self.orig)
def delattr(self):
self.orig = self.obj.__dict__.get(self.key, None)
self.obj.__dict__[self.key] = None
+ if self.backrefmanager is not None:
+ self.backrefmanager.set(self.obj, None, self.orig)
def rollback(self):
if self.orig is not PropHistory.NONE:
self.obj.__dict__[self.key] = self.orig
@@ -88,9 +93,10 @@ class PropHistory(object):
class ListElement(util.HistoryArraySet):
"""manages the value of a particular list-based attribute on a particular object instance."""
- def __init__(self, obj, key, data=None, **kwargs):
+ def __init__(self, obj, key, data=None, backrefmanager=None, **kwargs):
self.obj = obj
self.key = key
+ self.backrefmanager = backrefmanager
# if we are given a list, try to behave nicely with an existing
# list that might be set on the object already
try:
@@ -120,11 +126,15 @@ class ListElement(util.HistoryArraySet):
res = util.HistoryArraySet._setrecord(self, item)
if res:
self.list_value_changed(self.obj, self.key, item, self, False)
+ if self.backrefmanager is not None:
+ self.backrefmanager.append(self.obj, item)
return res
def _delrecord(self, item):
res = util.HistoryArraySet._delrecord(self, item)
if res:
self.list_value_changed(self.obj, self.key, item, self, True)
+ if self.backrefmanager is not None:
+ self.backrefmanager.delete(self.obj, item)
return res
class CallableProp(object):
@@ -175,6 +185,38 @@ class CallableProp(object):
def rollback(self):
pass
+class BackrefManager(object):
+ def __init__(self, key):
+ self.key = key
+ def append(self, parent, child):
+ pass
+ def delete(self, parent, child):
+ pass
+ def set(self, parent, child, oldchild):
+ pass
+
+
+class ListBackrefManager(BackrefManager):
+ def append(self, parent, child):
+ getattr(child, self.key).append(parent)
+ def delete(self, parent, child):
+ getattr(child, self.key).remove(parent)
+
+class OneToManyBackrefManager(BackrefManager):
+ def append(self, parent, child):
+ setattr(child, self.key, parent)
+ def delete(self, parent, child):
+ setattr(child, self.key, None)
+
+class ManyToOneBackrefManager(BackrefManager):
+ def set(self, parent, child, oldchild):
+ if oldchild is not None:
+ try:
+ getattr(oldchild, self.key).remove(parent)
+ except:
+ print "wha? oldchild is ", repr(oldchild)
+ if child is not None:
+ getattr(child, self.key).append(parent)
class AttributeManager(object):
"""maintains a set of per-attribute callable/history manager objects for a set of objects."""
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index c5ac8b979..443db2e3d 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -213,9 +213,9 @@ class HistoryArraySet(UserList.UserList):
self.records[item] = False
elif val is True:
del self.records[item]
+ return True
except KeyError:
- pass
- return True
+ return False
def commit(self):
for key in self.records.keys():
value = self.records[key]
@@ -274,11 +274,11 @@ class HistoryArraySet(UserList.UserList):
self.data.insert(i, item)
def pop(self, i=-1):
item = self.data[i]
- self._delrecord(item)
- return self.data.pop(i)
+ if self._delrecord(item):
+ return self.data.pop(i)
def remove(self, item):
- self._delrecord(item)
- self.data.remove(item)
+ if self._delrecord(item):
+ self.data.remove(item)
def __add__(self, other):
raise NotImplementedError()
def __radd__(self, other):