summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/associationproxy.py
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2007-08-23 15:48:51 +0000
committerJason Kirtland <jek@discorporate.us>2007-08-23 15:48:51 +0000
commit6ebcdbd6d86787dedb12df6b86116bbf484e8ebb (patch)
tree5b5830f7d3da0214543171ce58e0f2359c70b88b /lib/sqlalchemy/ext/associationproxy.py
parent51dc8b088d37b7132f207949a0a00cd3db651e37 (diff)
downloadsqlalchemy-6ebcdbd6d86787dedb12df6b86116bbf484e8ebb.tar.gz
Expand custom assocproxy getter/setter support to scalar proxies
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py15
1 files changed, 13 insertions, 2 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 5f75bfeb7..0130721e1 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -148,9 +148,11 @@ class AssociationProxy(object):
return
elif self.scalar is None:
self.scalar = self._target_is_scalar()
+ if self.scalar:
+ self._initialize_scalar_accessors()
if self.scalar:
- return getattr(getattr(obj, self.target_collection), self.value_attr)
+ return self._scalar_get(getattr(obj, self.target_collection))
else:
try:
return getattr(obj, self.key)
@@ -162,6 +164,8 @@ class AssociationProxy(object):
def __set__(self, obj, values):
if self.scalar is None:
self.scalar = self._target_is_scalar()
+ if self.scalar:
+ self._initialize_scalar_accessors()
if self.scalar:
creator = self.creator and self.creator or self.target_class
@@ -169,7 +173,7 @@ class AssociationProxy(object):
if target is None:
setattr(obj, self.target_collection, creator(values))
else:
- setattr(target, self.value_attr, values)
+ self._scalar_set(target, values)
else:
proxy = self.__get__(obj, None)
proxy.clear()
@@ -178,6 +182,13 @@ class AssociationProxy(object):
def __delete__(self, obj):
delattr(obj, self.key)
+ def _initialize_scalar_accessors(self):
+ if self.getset_factory:
+ get, set = self.getset_factory(None, self)
+ else:
+ get, set = self._default_getset(None)
+ self._scalar_get, self._scalar_set = get, set
+
def _default_getset(self, collection_class):
attr = self.value_attr
getter = util.attrgetter(attr)