diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:12:07 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:12:07 +0000 |
commit | 5b7363c3386dd297146d6fe1a65aa1c4b76c377d (patch) | |
tree | 22e20f1bcb19552ffb2c23b0f913e5e094239e41 /examples/derived_attributes/attributes.py | |
parent | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (diff) | |
download | sqlalchemy-5b7363c3386dd297146d6fe1a65aa1c4b76c377d.tar.gz |
- merged ants' derived attributes example from 0.4 branch
- disabled PG schema test for now (want to see the buildbot succeed)
Diffstat (limited to 'examples/derived_attributes/attributes.py')
-rw-r--r-- | examples/derived_attributes/attributes.py | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/examples/derived_attributes/attributes.py b/examples/derived_attributes/attributes.py new file mode 100644 index 000000000..f53badc74 --- /dev/null +++ b/examples/derived_attributes/attributes.py @@ -0,0 +1,127 @@ +"""A couple of helper descriptors to allow to use the same code as query +criterion creators and as instance code. As this doesn't do advanced +magic recompiling, you can only use basic expression-like code.""" + +import new + +class MethodDescriptor(object): + def __init__(self, func): + self.func = func + def __get__(self, instance, owner): + if instance is None: + return new.instancemethod(self.func, owner, owner.__class__) + else: + return new.instancemethod(self.func, instance, owner) + +class PropertyDescriptor(object): + def __init__(self, fget, fset, fdel): + self.fget = fget + self.fset = fset + self.fdel = fdel + def __get__(self, instance, owner): + if instance is None: + return self.fget(owner) + else: + return self.fget(instance) + def __set__(self, instance, value): + self.fset(instance, value) + def __delete__(self, instance): + self.fdel(instance) + +def hybrid(func): + return MethodDescriptor(func) + +def hybrid_property(fget, fset=None, fdel=None): + return PropertyDescriptor(fget, fset, fdel) + +### Example code + +from sqlalchemy import * +from sqlalchemy.orm import * + +metadata = MetaData('sqlite://') +metadata.bind.echo = True + +print "Set up database metadata" + +interval_table1 = Table('interval1', metadata, + Column('id', Integer, primary_key=True), + Column('start', Integer, nullable=False), + Column('end', Integer, nullable=False)) + +interval_table2 = Table('interval2', metadata, + Column('id', Integer, primary_key=True), + Column('start', Integer, nullable=False), + Column('length', Integer, nullable=False)) + +metadata.create_all() + +# A base class for intervals + +class BaseInterval(object): + @hybrid + def contains(self,point): + return (self.start <= point) & (point < self.end) + + @hybrid + def intersects(self, other): + return (self.start < other.end) & (self.end > other.start) + + def __repr__(self): + return "%s(%s..%s)" % (self.__class__.__name__, self.start, self.end) + +# Interval stored as endpoints + +class Interval1(BaseInterval): + def __init__(self, start, end): + self.start = start + self.end = end + + length = hybrid_property(lambda s: s.end - s.start) + +mapper(Interval1, interval_table1) + +# Interval stored as start and length + +class Interval2(BaseInterval): + def __init__(self, start, length): + self.start = start + self.length = length + + end = hybrid_property(lambda s: s.start + s.length) + +mapper(Interval2, interval_table2) + +print "Create the data" + +session = create_session() + +intervals = [Interval1(1,4), Interval1(3,15), Interval1(11,16)] + +for interval in intervals: + session.save(interval) + session.save(Interval2(interval.start, interval.length)) + +session.flush() + +print "Clear the cache and do some queries" + +session.clear() + +for Interval in (Interval1, Interval2): + print "Querying using interval class %s" % Interval.__name__ + + print + print '-- length less than 10' + print [(i, i.length) for i in session.query(Interval).filter(Interval.length < 10).all()] + + print + print '-- contains 12' + print session.query(Interval).filter(Interval.contains(12)).all() + + print + print '-- intersects 2..10' + other = Interval1(2,10) + result = session.query(Interval).filter(Interval.intersects(other)).order_by(Interval.length).all() + print [(interval, interval.intersects(other)) for interval in result] + |