summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2010-04-05 18:56:31 +0000
committerRaymond Hettinger <python@rcn.com>2010-04-05 18:56:31 +0000
commit4af7429af46652eb0e6c000f123ec59e701f221d (patch)
tree96ef0250b66151466e1327814c53b718462dea4a /Lib
parent25adc5318176b4109a5a38974f4c0ccae36c40da (diff)
downloadcpython-4af7429af46652eb0e6c000f123ec59e701f221d.tar.gz
Forward port total_ordering() and cmp_to_key().
Diffstat (limited to 'Lib')
-rw-r--r--Lib/functools.py47
-rw-r--r--Lib/pstats.py12
-rw-r--r--Lib/test/test_functools.py84
-rw-r--r--Lib/unittest/loader.py3
-rw-r--r--Lib/unittest/util.py9
5 files changed, 134 insertions, 21 deletions
diff --git a/Lib/functools.py b/Lib/functools.py
index a54f030832..539dc90ecd 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -49,3 +49,50 @@ def wraps(wrapped,
"""
return partial(update_wrapper, wrapped=wrapped,
assigned=assigned, updated=updated)
+
+def total_ordering(cls):
+ 'Class decorator that fills-in missing ordering methods'
+ convert = {
+ '__lt__': [('__gt__', lambda self, other: other < self),
+ ('__le__', lambda self, other: not other < self),
+ ('__ge__', lambda self, other: not self < other)],
+ '__le__': [('__ge__', lambda self, other: other <= self),
+ ('__lt__', lambda self, other: not other <= self),
+ ('__gt__', lambda self, other: not self <= other)],
+ '__gt__': [('__lt__', lambda self, other: other > self),
+ ('__ge__', lambda self, other: not other > self),
+ ('__le__', lambda self, other: not self > other)],
+ '__ge__': [('__le__', lambda self, other: other >= self),
+ ('__gt__', lambda self, other: not other >= self),
+ ('__lt__', lambda self, other: not self >= other)]
+ }
+ roots = set(dir(cls)) & set(convert)
+ assert roots, 'must define at least one ordering operation: < > <= >='
+ root = max(roots) # prefer __lt __ to __le__ to __gt__ to __ge__
+ for opname, opfunc in convert[root]:
+ if opname not in roots:
+ opfunc.__name__ = opname
+ opfunc.__doc__ = getattr(int, opname).__doc__
+ setattr(cls, opname, opfunc)
+ return cls
+
+def cmp_to_key(mycmp):
+ 'Convert a cmp= function into a key= function'
+ class K(object):
+ def __init__(self, obj, *args):
+ self.obj = obj
+ def __lt__(self, other):
+ return mycmp(self.obj, other.obj) < 0
+ def __gt__(self, other):
+ return mycmp(self.obj, other.obj) > 0
+ def __eq__(self, other):
+ return mycmp(self.obj, other.obj) == 0
+ def __le__(self, other):
+ return mycmp(self.obj, other.obj) <= 0
+ def __ge__(self, other):
+ return mycmp(self.obj, other.obj) >= 0
+ def __ne__(self, other):
+ return mycmp(self.obj, other.obj) != 0
+ def __hash__(self):
+ raise TypeError('hash not implemented')
+ return K
diff --git a/Lib/pstats.py b/Lib/pstats.py
index e2fee37f0a..14c460680c 100644
--- a/Lib/pstats.py
+++ b/Lib/pstats.py
@@ -37,6 +37,7 @@ import os
import time
import marshal
import re
+from functools import cmp_to_key
__all__ = ["Stats"]
@@ -226,7 +227,7 @@ class Stats:
stats_list.append((cc, nc, tt, ct) + func +
(func_std_string(func), func))
- stats_list.sort(key=CmpToKey(TupleComp(sort_tuple).compare))
+ stats_list.sort(key=cmp_to_key(TupleComp(sort_tuple).compare))
self.fcn_list = fcn_list = []
for tuple in stats_list:
@@ -458,15 +459,6 @@ class TupleComp:
return direction
return 0
-def CmpToKey(mycmp):
- 'Convert a cmp= function into a key= function'
- class K(object):
- def __init__(self, obj):
- self.obj = obj
- def __lt__(self, other):
- return mycmp(self.obj, other.obj) == -1
- return K
-
#**************************************************************************
# func_name is a triple (file:string, line:int, name:string)
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index ae47dae95d..5cc2a50e3d 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -364,7 +364,89 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
-
+class TestCmpToKey(unittest.TestCase):
+ def test_cmp_to_key(self):
+ def mycmp(x, y):
+ return y - x
+ self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
+ [4, 3, 2, 1, 0])
+
+ def test_hash(self):
+ def mycmp(x, y):
+ return y - x
+ key = functools.cmp_to_key(mycmp)
+ k = key(10)
+ self.assertRaises(TypeError, hash(k))
+
+class TestTotalOrdering(unittest.TestCase):
+
+ def test_total_ordering_lt(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __lt__(self, other):
+ return self.value < other.value
+ self.assert_(A(1) < A(2))
+ self.assert_(A(2) > A(1))
+ self.assert_(A(1) <= A(2))
+ self.assert_(A(2) >= A(1))
+ self.assert_(A(2) <= A(2))
+ self.assert_(A(2) >= A(2))
+
+ def test_total_ordering_le(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __le__(self, other):
+ return self.value <= other.value
+ self.assert_(A(1) < A(2))
+ self.assert_(A(2) > A(1))
+ self.assert_(A(1) <= A(2))
+ self.assert_(A(2) >= A(1))
+ self.assert_(A(2) <= A(2))
+ self.assert_(A(2) >= A(2))
+
+ def test_total_ordering_gt(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __gt__(self, other):
+ return self.value > other.value
+ self.assert_(A(1) < A(2))
+ self.assert_(A(2) > A(1))
+ self.assert_(A(1) <= A(2))
+ self.assert_(A(2) >= A(1))
+ self.assert_(A(2) <= A(2))
+ self.assert_(A(2) >= A(2))
+
+ def test_total_ordering_ge(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __ge__(self, other):
+ return self.value >= other.value
+ self.assert_(A(1) < A(2))
+ self.assert_(A(2) > A(1))
+ self.assert_(A(1) <= A(2))
+ self.assert_(A(2) >= A(1))
+ self.assert_(A(2) <= A(2))
+ self.assert_(A(2) >= A(2))
+
+ def test_total_ordering_no_overwrite(self):
+ # new methods should not overwrite existing
+ @functools.total_ordering
+ class A(int):
+ raise Exception()
+ self.assert_(A(1) < A(2))
+ self.assert_(A(2) > A(1))
+ self.assert_(A(1) <= A(2))
+ self.assert_(A(2) >= A(1))
+ self.assert_(A(2) <= A(2))
+ self.assert_(A(2) >= A(2))
def test_main(verbose=None):
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index 5d11b6e8ff..f00f38d1a1 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -5,6 +5,7 @@ import re
import sys
import traceback
import types
+import functools
from fnmatch import fnmatch
@@ -141,7 +142,7 @@ class TestLoader(object):
testFnNames = testFnNames = list(filter(isTestMethod,
dir(testCaseClass)))
if self.sortTestMethodsUsing:
- testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing))
+ testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
return testFnNames
def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py
index 736c20274d..ea8a68dc9f 100644
--- a/Lib/unittest/util.py
+++ b/Lib/unittest/util.py
@@ -70,15 +70,6 @@ def unorderable_list_difference(expected, actual):
# anything left in actual is unexpected
return missing, actual
-def CmpToKey(mycmp):
- 'Convert a cmp= function into a key= function'
- class K(object):
- def __init__(self, obj, *args):
- self.obj = obj
- def __lt__(self, other):
- return mycmp(self.obj, other.obj) == -1
- return K
-
def three_way_cmp(x, y):
"""Return -1 if x < y, 0 if x == y and 1 if x > y"""
return (x > y) - (x < y)