diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-05-03 00:57:59 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-05-03 00:57:59 +0000 |
commit | 300d1d2c136462201c79ff19cb6b8c2bbc0c8dfd (patch) | |
tree | 1ae489dc49024405e3c1eca3dfc4f1b8daca6883 /lib/sqlalchemy/ext/associationproxy.py | |
parent | 612c49f545b5374be45dbb4da21a5d708ebb894f (diff) | |
download | sqlalchemy-300d1d2c136462201c79ff19cb6b8c2bbc0c8dfd.tar.gz |
- New association proxy implementation, implementing complete proxies to list, dict and set-based relation collections (and scalar relations). Extensive tests.
- Added util.duck_type_collection
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
-rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 664 |
1 files changed, 584 insertions, 80 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 65b95ccba..0913d6c48 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -6,116 +6,620 @@ transparent proxied access to the endpoint of an association object. See the example ``examples/association/proxied_association.py``. """ -from sqlalchemy.orm import class_mapper +from sqlalchemy.orm.attributes import InstrumentedList +import sqlalchemy.exceptions as exceptions +import sqlalchemy.orm as orm +import sqlalchemy.util as util + +def association_proxy(targetcollection, attr, **kw): + """Convenience function for use in mapped classes. Implements a Python + property representing a relation as a collection of simpler values. The + proxied property will mimic the collection type of the target (list, dict + or set), or in the case of a one to one relation, a simple scalar value. + + targetcollection + Name of the relation attribute we'll proxy to, usually created with + 'relation()' in a mapper setup. + + attr + Attribute on the associated instances we'll proxy for. For example, + given a target collection of [obj1, obj2], a list created by this proxy + property would look like + [getattr(obj1, attr), getattr(obj2, attr)] + + If the relation is one-to-one or otherwise uselist=False, then simply: + getattr(obj, attr) + + creator (optional) + When new items are added to this proxied collection, new instances of + the class collected by the target collection will be created. For + list and set collections, the target class constructor will be called + with the 'value' for the new instance. For dict types, two arguments + are passed: key and value. + + If you want to construct instances differently, supply a 'creator' + function that takes arguments as above and returns instances. + + For scalar relations, creator() will be called if the target is None. + If the target is present, set operations are proxied to setattr() on the + associated object. + + If you have an associated object with multiple attributes, you may set up + multiple association proxies mapping to different attributes. See the + unit tests for examples, and for examples of how creator() functions can + be used to construct the scalar relation on-demand in this situation. + + Passes along any other arguments to AssociationProxy + """ + + return AssociationProxy(targetcollection, attr, **kw) + class AssociationProxy(object): - """A property object that automatically sets up ``AssociationLists`` on a parent object.""" + """A property object that automatically sets up `AssociationLists` + on an object.""" - def __init__(self, targetcollection, attr, creator=None): - """Create a new association property. + def __init__(self, targetcollection, attr, creator=None, + proxy_factory=None, proxy_bulk_set=None): + """Arguments are: - targetcollection - The attribute name which stores the collection of Associations. + targetcollection + Name of the collection we'll proxy to, usually created with + 'relation()' in a mapper setup. - attr - Name of the attribute on the Association in which to get/set target values. + attr + Attribute on the collected instances we'll proxy for. For example, + given a target collection of [obj1, obj2], + a list created by this proxy property would look like + [getattr(obj1, attr), getattr(obj2, attr)] - creator - Optional callable which is used to create a new association - object. This callable is given a single argument which is - an instance of the *proxied* object. If creator is not - given, the association object is created using the class - associated with the targetcollection attribute, using its - ``__init__()`` constructor and setting the proxied - attribute. + creator + Optional. When new items are added to this proxied collection, new + instances of the class collected by the target collection will be + created. For list and set collections, the target class + constructor will be called with the 'value' for the new instance. + For dict types, two arguments are passed: key and value. + + If you want to construct instances differently, supply a 'creator' + function that takes arguments as above and returns instances. + + proxy_factory + Optional. The type of collection to emulate is determined by + sniffing the target collection. If your collection type can't be + determined by duck typing or you'd like to use a different collection + implementation, you may supply a factory function to produce those + collections. Only applicable to non-scalar relations. + + proxy_bulk_set + Optional, use with proxy_factory. See the _set() method for + details. """ - self.targetcollection = targetcollection - self.attr = attr + self.target_collection = targetcollection # backwards compat name... + self.value_attr = attr self.creator = creator - - def __init_deferred(self): - prop = class_mapper(self._owner_class).props[self.targetcollection] - self._cls = prop.mapper.class_ - self._uselist = prop.uselist + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set - def _get_class(self): - try: - return self._cls - except AttributeError: - self.__init_deferred() - return self._cls + self.scalar = None + self.owning_class = None + self.key = '_%s_%s_%s' % (type(self).__name__, + targetcollection, id(self)) + self.collection_class = None - def _get_uselist(self): - try: - return self._uselist - except AttributeError: - self.__init_deferred() - return self._uselist + def _get_property(self): + return orm.class_mapper(self.owning_class).props[self.target_collection] - cls = property(_get_class) - uselist = property(_get_uselist) + def _target_class(self): + return self._get_property().mapper.class_ + target_class = property(_target_class) - def create(self, target, **kw): - if self.creator is not None: - return self.creator(target, **kw) - else: - assoc = self.cls(**kw) - setattr(assoc, self.attr, target) - return assoc - - def __get__(self, obj, owner): - self._owner_class = owner + + def __get__(self, obj, class_): if obj is None: - return self - storage_key = '_AssociationProxy_%s_%s' % (self.targetcollection, self.attr) - if self.uselist: + self.owning_class = class_ + return + elif self.scalar is None: + self.scalar = not self._get_property().uselist + + if self.scalar: + return getattr(getattr(obj, self.target_collection), self.value_attr) + else: try: - return getattr(obj, storage_key) + return getattr(obj, self.key) except AttributeError: - a = _AssociationList(self, obj) - setattr(obj, storage_key, a) - return a + proxy = self._new(getattr(obj, self.target_collection)) + setattr(obj, self.key, proxy) + return proxy + + def __set__(self, obj, values): + if self.scalar: + creator = self.creator and self.creator or self.target_class + target = getattr(obj, self.target_collection) + if target is None: + setattr(obj, self.target_collection, creator(values)) + else: + setattr(target, self.value_attr, values) + else: + proxy = self.__get__(obj, None) + proxy.clear() + self._set(proxy, values) + + def __delete__(self, obj): + delattr(obj, self.key) + + def _new(self, collection): + creator = self.creator and self.creator or self.target_class + + # Prefer class typing here to spot dicts with the required append() + # method. + if isinstance(collection.data, dict): + self.collection_class = dict else: - return getattr(getattr(obj, self.targetcollection), self.attr) + self.collection_class = util.duck_type_collection(collection.data) + + if self.proxy_factory: + return self.proxy_factory(collection, creator, self.value_attr) - def __set__(self, obj, value): - if self.uselist: - setattr(obj, self.targetcollection, [self.create(x) for x in value]) + value_attr = self.value_attr + getter = lambda o: getattr(o, value_attr) + setter = lambda o, v: setattr(o, value_attr, v) + + if self.collection_class is list: + return _AssociationList(collection, creator, getter, setter) + elif self.collection_class is dict: + kv_setter = lambda o, k, v: setattr(o, value_attr, v) + return _AssociationDict(collection, creator, getter, setter) + elif self.collection_class is util.Set: + return _AssociationSet(collection, creator, getter, setter) else: - setattr(obj, self.targetcollection, self.create(value)) + raise exceptions.ArgumentError( + 'could not guess which interface to use for ' + 'collection_class "%s" backing "%s"; specify a ' + 'proxy_factory and proxy_bulk_set manually' % + (self.collection_class.__name__, self.target_collection)) - def __del__(self, obj): - delattr(obj, self.targetcollection) + def _set(self, proxy, values): + if self.proxy_bulk_set: + self.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + proxy.extend(values) + elif self.collection_class is dict: + proxy.update(values) + elif self.collection_class is util.Set: + proxy.update(values) + else: + raise exceptions.ArgumentError( + 'no proxy_bulk_set supplied for custom ' + 'collection_class implementation') class _AssociationList(object): - """Generic proxying list which proxies list operations to a - different list-holding attribute of the parent object, converting - Association objects to and from a target attribute on each - Association object. + """Generic proxying list which proxies list operations to a another list, + converting association objects to and from a simplified value. """ - def __init__(self, proxy, parent): - """Create a new ``AssociationList``.""" - self.proxy = proxy - self.parent = parent + def __init__(self, collection, creator, getter, setter): + """ + collection + A list-based collection of entities (usually an object attribute + managed by a SQLAlchemy relation()) + + creator + A function that creates new target entities. Given one parameter: + value. The assertion is assumed: + obj = creator(somevalue) + assert getter(obj) == somevalue + + getter + A function. Given an associated object, return the 'value'. + + setter + A function. Given an associated object and a value, store + that value on the object. + """ + + self.col = collection + self.creator = creator + self.getter = getter + self.setter = setter + + # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator. + # (see append() below) + def _create(self, value, **kw): + return self.creator(value, **kw) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, value): + return self.setter(object, value) - def append(self, item, **kw): - a = self.proxy.create(item, **kw) - getattr(self.parent, self.proxy.targetcollection).append(a) + def __len__(self): + return len(self.col) + + def __nonzero__(self): + return True if self.col else False + + def __getitem__(self, index): + return self._get(self.col[index]) + + def __setitem__(self, index, value): + self._set(self.col[index], value) + + def __delitem__(self, index): + del self.col[index] + + def __contains__(self, value): + for member in self.col: + if self._get(member) == value: + return True + return False + + def __getslice__(self, start, end): + return [self._get(member) for member in self.col[start:end]] + + def __setslice__(self, start, end, values): + members = [self._create(v) for v in values] + self.col[start:end] = members + + def __delslice__(self, start, end): + del self.col[start:end] def __iter__(self): - return iter([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)]) + """Iterate over proxied values. For the actual domain objects, + iterate over .col instead or just use the underlying collection + directly from its property on the parent.""" + for member in self.col: + yield self._get(member) + raise StopIteration + + # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator + # on append() only. (Can't on __setitem__, __contains__, etc., obviously.) + def append(self, value, **kw): + item = self._create(value, **kw) + self.col.append(item) + + def extend(self, values): + for v in values: + self.append(v) + + def insert(self, index, value): + self.col[index:index] = [self._create(value)] + + def clear(self): + del self.col[0:len(self.col)] + + def __eq__(self, other): return list(self) == other + def __ne__(self, other): return list(self) != other + def __lt__(self, other): return list(self) < other + def __le__(self, other): return list(self) <= other + def __gt__(self, other): return list(self) > other + def __ge__(self, other): return list(self) >= other + def __cmp__(self, other): return cmp(list(self), other) + + def copy(self): + return list(self) def __repr__(self): - return repr([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)]) + return repr(list(self)) + + def hash(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + +_NotProvided = object() +class _AssociationDict(object): + """Generic proxying list which proxies dict operations to a another dict, + converting association objects to and from a simplified value. + """ + + def __init__(self, collection, creator, getter, setter): + """ + collection + A list-based collection of entities (usually an object attribute + managed by a SQLAlchemy relation()) + + creator + A function that creates new target entities. Given two parameters: + key and value. The assertion is assumed: + obj = creator(somekey, somevalue) + assert getter(somekey) == somevalue + + getter + A function. Given an associated object and a key, return the 'value'. + + setter + A function. Given an associated object, a key and a value, store + that value on the object. + """ + + self.col = collection + self.creator = creator + self.getter = getter + self.setter = setter + + def _create(self, key, value): + return self.creator(key, value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, key, value): + return self.setter(object, key, value) def __len__(self): - return len(getattr(self.parent, self.proxy.targetcollection)) + return len(self.col) - def __getitem__(self, index): - return getattr(getattr(self.parent, self.proxy.targetcollection)[index], self.proxy.attr) + def __nonzero__(self): + return True if self.col else False - def __setitem__(self, index, value): - a = self.proxy.create(item) - getattr(self.parent, self.proxy.targetcollection)[index] = a + def __getitem__(self, key): + return self._get(self.col[key]) + + def __setitem__(self, key, value): + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key): + del self.col[key] + + def __contains__(self, key): + return key in self.col + has_key = __contains__ + + def __iter__(self): + return iter(self.col) + + def clear(self): + self.col.clear() + + def __eq__(self, other): return dict(self) == other + def __ne__(self, other): return dict(self) != other + def __lt__(self, other): return dict(self) < other + def __le__(self, other): return dict(self) <= other + def __gt__(self, other): return dict(self) > other + def __ge__(self, other): return dict(self) >= other + def __cmp__(self, other): return cmp(dict(self), other) + + def __repr__(self): + return repr(dict(self.items())) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + if key not in self.col: + self.col[key] = self._create(key, default) + return default + else: + return self[key] + + def keys(self): + return self.col.keys() + def iterkeys(self): + return self.col.iterkeys() + + def values(self): + return [ self._get(member) for member in self.col.values() ] + def itervalues(self): + for key in self.col: + yield self._get(self.col[key]) + raise StopIteration + + def items(self): + return [(k, self._get(self.col[k])) for k in self] + def iteritems(self): + for key in self.col: + yield (key, self._get(self.col[key])) + raise StopIteration + + def pop(self, key, default=_NotProvided): + if default is _NotProvided: + member = self.col.pop(key) + else: + member = self.col.pop(key, default) + return self._get(member) + + def popitem(self): + item = self.col.popitem() + return (item[0], self._get(item[1])) + + def update(self, *a, **kw): + if len(a) > 1: + raise TypeError('update expected at most 1 arguments, got %i' % + len(a)) + elif len(a) == 1: + seq_or_map = a[0] + for item in seq_or_map: + if isinstance(item, tuple): + self[item[0]] = item[1] + else: + self[item] = seq_or_map[item] + + for key, value in kw: + self[key] = value + + def copy(self): + return dict(self.items()) + + def hash(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + +class _AssociationSet(object): + """Generic proxying list which proxies set operations to a another set, + converting association objects to and from a simplified value. + """ + + def __init__(self, collection, creator, getter, setter): + """ + collection + A list-based collection of entities (usually an object attribute + managed by a SQLAlchemy relation()) + + creator + A function that creates new target entities. Given one parameter: + value. The assertion is assumed: + obj = creator(somevalue) + assert getter(obj) == somevalue + + getter + A function. Given an associated object, return the 'value'. + + setter + A function. Given an associated object and a value, store + that value on the object. + """ + + self.col = collection + self.creator = creator + self.getter = getter + self.setter = setter + + def _create(self, value): + return self.creator(value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, value): + return self.setter(object, value) + + def __len__(self): + return len(self.col) + + def __nonzero__(self): + return True if self.col else False + + def __contains__(self, value): + for member in self.col: + if self._get(member) == value: + return True + return False + + def __iter__(self): + """Iterate over proxied values. For the actual domain objects, + iterate over .col instead or just use the underlying collection + directly from its property on the parent.""" + for member in self.col: + yield self._get(member) + raise StopIteration + + def add(self, value): + if value not in self: + # must shove this through InstrumentedList.append() which will + # eventually call the collection_class .add() + self.col.append(self._create(value)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + break + + def remove(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + return + raise KeyError(value) + + def pop(self): + if not self.col: + raise KeyError('pop from an empty set') + # grumble, pop() is borked on InstrumentedList (#548) + if isinstance(self.col, InstrumentedList): + member = list(self.col)[0] + self.col.remove(member) + else: + member = self.col.pop() + return self._get(member) + + def update(self, other): + for value in other: + self.add(value) + + __ior__ = update + + def _set(self): + return util.Set(iter(self)) + + def union(self, other): + return util.Set(self).union(other) + + __or__ = union + + def difference(self, other): + return util.Set(self).difference(other) + + __sub__ = difference + + def difference_update(self, other): + for value in other: + self.discard(value) + + __isub__ = difference_update + + def intersection(self, other): + return util.Set(self).intersection(other) + + __and__ = intersection + + def intersection_update(self, other): + want, have = self.intersection(other), util.Set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + __iand__ = intersection_update + + def symmetric_difference(self, other): + return util.Set(self).symmetric_difference(other) + + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), util.Set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + __ixor__ = symmetric_difference_update + + def issubset(self, other): + return util.Set(self).issubset(other) + + def issuperset(self, other): + return util.Set(self).issuperset(other) + + def clear(self): + self.col.clear() + + def copy(self): + return util.Set(self) + + def __eq__(self, other): return util.Set(self) == other + def __ne__(self, other): return util.Set(self) != other + def __lt__(self, other): return util.Set(self) < other + def __le__(self, other): return util.Set(self) <= other + def __gt__(self, other): return util.Set(self) > other + def __ge__(self, other): return util.Set(self) >= other + + def __repr__(self): + return repr(util.Set(self)) + + def hash(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) |