diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-10-15 14:31:02 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-10-15 14:31:02 -0400 |
commit | 80c68c0e22e2b45b3eaffcb7485d6a9f5eb02ba4 (patch) | |
tree | 5df8ed1160fd4410ba70d485a6077085a9762761 /test | |
parent | f1eea63468d8e5d84edceb2b0028984e5917dde0 (diff) | |
download | sqlalchemy-80c68c0e22e2b45b3eaffcb7485d6a9f5eb02ba4.tar.gz |
- Reinstated "comparator_factory" argument to
composite(), removed when 0.7 was released.
[ticket:2248]
Diffstat (limited to 'test')
-rw-r--r-- | test/orm/test_composites.py | 128 |
1 files changed, 127 insertions, 1 deletions
diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 0d3cc20d6..0c16e57a1 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -5,7 +5,7 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey, func, \ util, select from test.lib.schema import Table, Column from sqlalchemy.orm import mapper, relationship, backref, \ - class_mapper, \ + class_mapper, CompositeProperty, \ validates, aliased from sqlalchemy.orm import attributes, \ composite, relationship, \ @@ -634,3 +634,129 @@ class ConfigurationTest(fixtures.MappedTest): deferred=True) }) self._test_roundtrip() + +class ComparatorTest(fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('edge', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('x1', Integer), + Column('y1', Integer), + Column('x2', Integer), + Column('y2', Integer), + ) + + @classmethod + def setup_mappers(cls): + class Point(cls.Comparable): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return [self.x, self.y] + def __eq__(self, other): + return isinstance(other, Point) and \ + other.x == self.x and \ + other.y == self.y + def __ne__(self, other): + return not isinstance(other, Point) or \ + not self.__eq__(other) + + class Edge(cls.Comparable): + def __init__(self, start, end): + self.start = start + self.end = end + + def __eq__(self, other): + return isinstance(other, Edge) and \ + other.id == self.id + + def _fixture(self, custom): + edge, Edge, Point = (self.tables.edge, + self.classes.Edge, + self.classes.Point) + + if custom: + class CustomComparator(sa.orm.CompositeProperty.Comparator): + def near(self, other, d): + clauses = self.__clause_element__().clauses + diff_x = clauses[0] - other.x + diff_y = clauses[1] - other.y + return diff_x * diff_x + diff_y * diff_y <= d * d + + mapper(Edge, edge, properties={ + 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1, + comparator_factory=CustomComparator), + 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) + }) + else: + mapper(Edge, edge, properties={ + 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1), + 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) + }) + + def test_comparator_behavior_default(self): + self._fixture(False) + self._test_comparator_behavior() + + def test_comparator_behavior_custom(self): + self._fixture(True) + self._test_comparator_behavior() + + def _test_comparator_behavior(self): + Edge, Point = (self.classes.Edge, + self.classes.Point) + + sess = Session() + e1 = Edge(Point(3, 4), Point(5, 6)) + e2 = Edge(Point(14, 5), Point(2, 7)) + sess.add_all([e1, e2]) + sess.commit() + + assert sess.query(Edge).\ + filter(Edge.start==Point(3, 4)).one() is \ + e1 + + assert sess.query(Edge).\ + filter(Edge.start!=Point(3, 4)).first() is \ + e2 + + eq_( + sess.query(Edge).filter(Edge.start==None).all(), + [] + ) + + def test_default_comparator_factory(self): + self._fixture(False) + Edge = self.classes.Edge + start_prop = Edge.start.property + + assert start_prop.comparator_factory is CompositeProperty.Comparator + + def test_custom_comparator_factory(self): + self._fixture(True) + Edge, Point = (self.classes.Edge, + self.classes.Point) + + edge_1, edge_2 = Edge(Point(0, 0), Point(3, 5)), \ + Edge(Point(0, 1), Point(3, 5)) + + sess = Session() + sess.add_all([edge_1, edge_2]) + sess.commit() + + near_edges = sess.query(Edge).filter( + Edge.start.near(Point(1, 1), 1) + ).all() + + assert edge_1 not in near_edges + assert edge_2 in near_edges + + near_edges = sess.query(Edge).filter( + Edge.start.near(Point(0, 1), 1) + ).all() + + assert edge_1 in near_edges and edge_2 in near_edges + + |