Remove duplicates of cmp_to_key (#12542, reviewed by Raymond Hettinger)

This commit is contained in:
Éric Araujo 2011-07-26 15:13:47 +02:00
parent 45dedaafc2
commit 41bade96a4
2 changed files with 13 additions and 26 deletions

View file

@ -4,17 +4,10 @@
import sys import sys
import os import os
from functools import cmp_to_key
from test import support, seq_tests from test import support, seq_tests
def CmpToKey(mycmp):
'Convert a cmp= function into a key= function'
class K(object):
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj) == -1
return K
class CommonTest(seq_tests.CommonTest): class CommonTest(seq_tests.CommonTest):
@ -443,7 +436,7 @@ def revcmp(a, b):
return 1 return 1
else: # a > b else: # a > b
return -1 return -1
u.sort(key=CmpToKey(revcmp)) u.sort(key=cmp_to_key(revcmp))
self.assertEqual(u, self.type2test([2,1,0,-1,-2])) self.assertEqual(u, self.type2test([2,1,0,-1,-2]))
# The following dumps core in unpatched Python 1.5: # The following dumps core in unpatched Python 1.5:
@ -456,7 +449,7 @@ def myComparison(x,y):
else: # xmod > ymod else: # xmod > ymod
return 1 return 1
z = self.type2test(range(12)) z = self.type2test(range(12))
z.sort(key=CmpToKey(myComparison)) z.sort(key=cmp_to_key(myComparison))
self.assertRaises(TypeError, z.sort, 2) self.assertRaises(TypeError, z.sort, 2)
@ -468,7 +461,8 @@ def selfmodifyingComparison(x,y):
return -1 return -1
else: # x > y else: # x > y
return 1 return 1
self.assertRaises(ValueError, z.sort, key=CmpToKey(selfmodifyingComparison)) self.assertRaises(ValueError, z.sort,
key=cmp_to_key(selfmodifyingComparison))
self.assertRaises(TypeError, z.sort, 42, 42, 42, 42) self.assertRaises(TypeError, z.sort, 42, 42, 42, 42)

View file

@ -2,18 +2,11 @@
import random import random
import sys import sys
import unittest import unittest
from functools import cmp_to_key
verbose = support.verbose verbose = support.verbose
nerrors = 0 nerrors = 0
def CmpToKey(mycmp):
'Convert a cmp= function into a key= function'
class K(object):
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj) == -1
return K
def check(tag, expected, raw, compare=None): def check(tag, expected, raw, compare=None):
global nerrors global nerrors
@ -23,7 +16,7 @@ def check(tag, expected, raw, compare=None):
orig = raw[:] # save input in case of error orig = raw[:] # save input in case of error
if compare: if compare:
raw.sort(key=CmpToKey(compare)) raw.sort(key=cmp_to_key(compare))
else: else:
raw.sort() raw.sort()
@ -108,7 +101,7 @@ def __repr__(self):
print(" Checking against an insane comparison function.") print(" Checking against an insane comparison function.")
print(" If the implementation isn't careful, this may segfault.") print(" If the implementation isn't careful, this may segfault.")
s = x[:] s = x[:]
s.sort(key=CmpToKey(lambda a, b: int(random.random() * 3) - 1)) s.sort(key=cmp_to_key(lambda a, b: int(random.random() * 3) - 1))
check("an insane function left some permutation", x, s) check("an insane function left some permutation", x, s)
if len(x) >= 2: if len(x) >= 2:
@ -165,12 +158,12 @@ def mutating_cmp(x, y):
L.pop() L.pop()
return (x > y) - (x < y) return (x > y) - (x < y)
L = [1,2] L = [1,2]
self.assertRaises(ValueError, L.sort, key=CmpToKey(mutating_cmp)) self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
def mutating_cmp(x, y): def mutating_cmp(x, y):
L.append(3) L.append(3)
del L[:] del L[:]
return (x > y) - (x < y) return (x > y) - (x < y)
self.assertRaises(ValueError, L.sort, key=CmpToKey(mutating_cmp)) self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
memorywaster = [memorywaster] memorywaster = [memorywaster]
#============================================================================== #==============================================================================
@ -185,7 +178,7 @@ def test_decorated(self):
def my_cmp(x, y): def my_cmp(x, y):
xlower, ylower = x.lower(), y.lower() xlower, ylower = x.lower(), y.lower()
return (xlower > ylower) - (xlower < ylower) return (xlower > ylower) - (xlower < ylower)
copy.sort(key=CmpToKey(my_cmp)) copy.sort(key=cmp_to_key(my_cmp))
def test_baddecorator(self): def test_baddecorator(self):
data = 'The quick Brown fox Jumped over The lazy Dog'.split() data = 'The quick Brown fox Jumped over The lazy Dog'.split()
@ -261,8 +254,8 @@ def my_cmp(x, y):
def my_cmp_reversed(x, y): def my_cmp_reversed(x, y):
x0, y0 = x[0], y[0] x0, y0 = x[0], y[0]
return (y0 > x0) - (y0 < x0) return (y0 > x0) - (y0 < x0)
data.sort(key=CmpToKey(my_cmp), reverse=True) data.sort(key=cmp_to_key(my_cmp), reverse=True)
copy1.sort(key=CmpToKey(my_cmp_reversed)) copy1.sort(key=cmp_to_key(my_cmp_reversed))
self.assertEqual(data, copy1) self.assertEqual(data, copy1)
copy2.sort(key=lambda x: x[0], reverse=True) copy2.sort(key=lambda x: x[0], reverse=True)
self.assertEqual(data, copy2) self.assertEqual(data, copy2)