summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES4
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py6
-rw-r--r--test/orm/test_composites.py128
3 files changed, 135 insertions, 3 deletions
diff --git a/CHANGES b/CHANGES
index bcb433d78..cddd86566 100644
--- a/CHANGES
+++ b/CHANGES
@@ -85,6 +85,10 @@ CHANGES
deferred=True option failed due to missing
import [ticket:2253]
+ - Reinstated "comparator_factory" argument to
+ composite(), removed when 0.7 was released.
+ [ticket:2248]
+
- Fixed bug in query.join() which would occur
in a complex multiple-overlapping path scenario,
where the same table could be joined to
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index cb31fadac..594705a8a 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -79,6 +79,8 @@ class CompositeProperty(DescriptorProperty):
self.active_history = kwargs.get('active_history', False)
self.deferred = kwargs.get('deferred', False)
self.group = kwargs.get('group', None)
+ self.comparator_factory = kwargs.pop('comparator_factory',
+ self.__class__.Comparator)
util.set_creation_order(self)
self._create_descriptor()
@@ -257,11 +259,11 @@ class CompositeProperty(DescriptorProperty):
)
def _comparator_factory(self, mapper):
- return CompositeProperty.Comparator(self)
+ return self.comparator_factory(self)
class Comparator(PropComparator):
def __init__(self, prop, adapter=None):
- self.prop = prop
+ self.prop = self.property = prop
self.adapter = adapter
def __clause_element__(self):
diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py
index 0d3cc20d6..0c16e57a1 100644
--- a/test/orm/test_composites.py
+++ b/test/orm/test_composites.py
@@ -5,7 +5,7 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey, func, \
util, select
from test.lib.schema import Table, Column
from sqlalchemy.orm import mapper, relationship, backref, \
- class_mapper, \
+ class_mapper, CompositeProperty, \
validates, aliased
from sqlalchemy.orm import attributes, \
composite, relationship, \
@@ -634,3 +634,129 @@ class ConfigurationTest(fixtures.MappedTest):
deferred=True)
})
self._test_roundtrip()
+
+class ComparatorTest(fixtures.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('edge', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ )
+
+ @classmethod
+ def setup_mappers(cls):
+ class Point(cls.Comparable):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+ def __composite_values__(self):
+ return [self.x, self.y]
+ def __eq__(self, other):
+ return isinstance(other, Point) and \
+ other.x == self.x and \
+ other.y == self.y
+ def __ne__(self, other):
+ return not isinstance(other, Point) or \
+ not self.__eq__(other)
+
+ class Edge(cls.Comparable):
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ def __eq__(self, other):
+ return isinstance(other, Edge) and \
+ other.id == self.id
+
+ def _fixture(self, custom):
+ edge, Edge, Point = (self.tables.edge,
+ self.classes.Edge,
+ self.classes.Point)
+
+ if custom:
+ class CustomComparator(sa.orm.CompositeProperty.Comparator):
+ def near(self, other, d):
+ clauses = self.__clause_element__().clauses
+ diff_x = clauses[0] - other.x
+ diff_y = clauses[1] - other.y
+ return diff_x * diff_x + diff_y * diff_y <= d * d
+
+ mapper(Edge, edge, properties={
+ 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1,
+ comparator_factory=CustomComparator),
+ 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+ })
+ else:
+ mapper(Edge, edge, properties={
+ 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1),
+ 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+ })
+
+ def test_comparator_behavior_default(self):
+ self._fixture(False)
+ self._test_comparator_behavior()
+
+ def test_comparator_behavior_custom(self):
+ self._fixture(True)
+ self._test_comparator_behavior()
+
+ def _test_comparator_behavior(self):
+ Edge, Point = (self.classes.Edge,
+ self.classes.Point)
+
+ sess = Session()
+ e1 = Edge(Point(3, 4), Point(5, 6))
+ e2 = Edge(Point(14, 5), Point(2, 7))
+ sess.add_all([e1, e2])
+ sess.commit()
+
+ assert sess.query(Edge).\
+ filter(Edge.start==Point(3, 4)).one() is \
+ e1
+
+ assert sess.query(Edge).\
+ filter(Edge.start!=Point(3, 4)).first() is \
+ e2
+
+ eq_(
+ sess.query(Edge).filter(Edge.start==None).all(),
+ []
+ )
+
+ def test_default_comparator_factory(self):
+ self._fixture(False)
+ Edge = self.classes.Edge
+ start_prop = Edge.start.property
+
+ assert start_prop.comparator_factory is CompositeProperty.Comparator
+
+ def test_custom_comparator_factory(self):
+ self._fixture(True)
+ Edge, Point = (self.classes.Edge,
+ self.classes.Point)
+
+ edge_1, edge_2 = Edge(Point(0, 0), Point(3, 5)), \
+ Edge(Point(0, 1), Point(3, 5))
+
+ sess = Session()
+ sess.add_all([edge_1, edge_2])
+ sess.commit()
+
+ near_edges = sess.query(Edge).filter(
+ Edge.start.near(Point(1, 1), 1)
+ ).all()
+
+ assert edge_1 not in near_edges
+ assert edge_2 in near_edges
+
+ near_edges = sess.query(Edge).filter(
+ Edge.start.near(Point(0, 1), 1)
+ ).all()
+
+ assert edge_1 in near_edges and edge_2 in near_edges
+
+