summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/associationproxy.py
blob: c9160ded4672e7e6c4aa90b836bf67523b7c3452 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""contains the AssociationProxy class, a Python property object which
provides transparent proxied access to the endpoint of an association object.

See the example examples/association/proxied_association.py.
"""

from sqlalchemy.orm import class_mapper

class AssociationProxy(object):
    """a property object that automatically sets up AssociationLists on a parent object.

    When the proxied relation is list-based (``uselist``), reading the
    attribute on an instance returns an _AssociationList view over the
    association collection; that view is cached on the instance.  When the
    relation is scalar, reading returns the target attribute of the single
    associated object (or None if there is no associated object).
    """
    def __init__(self, targetcollection, attr, creator=None):
        """create a new association property.

        targetcollection - the attribute name which stores the collection of Associations

        attr - name of the attribute on the Association in which to get/set target values

        creator - optional callable which is used to create a new association object.  this
        callable is given the "proxied" object as its first argument, plus any keyword
        arguments passed through create().
        if creator is not given, the association object is created using the class associated
        with the targetcollection attribute, using its __init__() constructor and setting
        the proxied attribute.
        """
        self.targetcollection = targetcollection
        self.attr = attr
        self.creator = creator

    def __init_deferred(self):
        # Resolve the relation's target class and list-ness lazily, from the
        # mapper of the owning class recorded by __get__ / __set__.
        prop = class_mapper(self._owner_class).props[self.targetcollection]
        self._cls = prop.mapper.class_
        self._uselist = prop.uselist

    def _get_class(self):
        try:
            return self._cls
        except AttributeError:
            self.__init_deferred()
            return self._cls

    def _get_uselist(self):
        try:
            return self._uselist
        except AttributeError:
            self.__init_deferred()
            return self._uselist

    # class of the association objects, resolved on first access
    cls = property(_get_class)
    # whether targetcollection is list-based, resolved on first access
    uselist = property(_get_uselist)

    def create(self, target, **kw):
        """create a new association object holding target.

        Uses the user-supplied creator if there is one; otherwise
        instantiates self.cls with **kw and sets self.attr to target.
        """
        if self.creator is not None:
            return self.creator(target, **kw)
        assoc = self.cls(**kw)
        setattr(assoc, self.attr, target)
        return assoc

    def __get__(self, obj, owner):
        self._owner_class = owner
        if obj is None:
            return self
        if self.uselist:
            storage_key = '_AssociationProxy_%s' % self.targetcollection
            try:
                return getattr(obj, storage_key)
            except AttributeError:
                proxied = _AssociationList(self, obj)
                setattr(obj, storage_key, proxied)
                return proxied
        target = getattr(obj, self.targetcollection)
        if target is None:
            # no associated object: report "no value" rather than raising
            # AttributeError from getattr(None, attr)
            return None
        return getattr(target, self.attr)

    def __set__(self, obj, value):
        # __init_deferred needs the owning class; if this assignment is the
        # first access to the proxy, __get__ has not recorded it yet.
        if not hasattr(self, '_owner_class'):
            self._owner_class = obj.__class__
        if self.uselist:
            setattr(obj, self.targetcollection, [self.create(x) for x in value])
        else:
            setattr(obj, self.targetcollection, self.create(value))

    def __delete__(self, obj):
        # descriptor deletion hook (``del obj.attr``); the former __del__
        # was the object finalizer, never invoked with ``obj``.
        delattr(obj, self.targetcollection)

class _AssociationList(object):
    """generic proxying list which proxies list operations to a different 
    list-holding attribute of the parent object, converting Association objects
    to and from a target attribute on each Association object."""
    def __init__(self, proxy, parent):
        """create a new AssociationList."""
        self.proxy = proxy
        self.parent = parent
    def append(self, item, **kw):
        a = self.proxy.create(item, **kw)
        getattr(self.parent, self.proxy.targetcollection).append(a)
    def __iter__(self):
        return iter([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
    def __repr__(self):
        return repr([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
    def __len__(self):
        return len(getattr(self.parent, self.proxy.targetcollection))
    def __getitem__(self, index):
        return getattr(getattr(self.parent, self.proxy.targetcollection)[index], self.proxy.attr)
    def __setitem__(self, index, value):
        a = self.proxy.create(item)
        getattr(self.parent, self.proxy.targetcollection)[index] = a