mirror of
				https://github.com/python/cpython.git
				synced 2025-10-23 01:43:53 +00:00 
			
		
		
		
	 84634254fe
			
		
	
	
		84634254fe
		
			
		
	
	
	
	
		
			
			gh-121025: Improve partialmethod.__repr__ (GH-121033)
It no longer contains redundant commas and spaces.
(cherry picked from commit d2646e3f45)
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
		
	
			
		
			
				
	
	
		
			3204 lines
		
	
	
	
		
			114 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			3204 lines
		
	
	
	
		
			114 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import abc
 | |
| import builtins
 | |
| import collections
 | |
| import collections.abc
 | |
| import copy
 | |
| from itertools import permutations
 | |
| import pickle
 | |
| from random import choice
 | |
| import sys
 | |
| from test import support
 | |
| import threading
 | |
| import time
 | |
| import typing
 | |
| import unittest
 | |
| import unittest.mock
 | |
| import weakref
 | |
| import gc
 | |
| from weakref import proxy
 | |
| import contextlib
 | |
| from inspect import Signature
 | |
| 
 | |
| from test.support import import_helper
 | |
| from test.support import threading_helper
 | |
| 
 | |
| import functools
 | |
| 
 | |
# Pure-Python flavour of functools: importing '_functools' is blocked.
py_functools = import_helper.import_fresh_module(
    'functools', blocked=['_functools'])
# Flavour of functools loaded together with a fresh '_functools'; the test
# classes below skip themselves when this value is falsy.
c_functools = import_helper.import_fresh_module(
    'functools', fresh=['_functools'])

decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
 | |
| 
 | |
| 
 | |
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module that was registered under *name* on entry is put back when
    the ``with`` block exits, whether or not the block raised.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        sys.modules[name] = saved
 | |
| 
 | |
def capture(*args, **kw):
    """Return the received arguments as an ``(args, kw)`` pair."""
    received = (args, kw)
    return received
 | |
| 
 | |
| 
 | |
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
 | |
| 
 | |
class MyTuple(tuple):
    """Plain tuple subclass that adds no behaviour of its own."""
 | |
| 
 | |
class BadTuple(tuple):
    """Tuple subclass whose ``+`` produces a list instead of a tuple."""

    def __add__(self, other):
        return [*self, *other]
 | |
| 
 | |
class MyDict(dict):
    """Plain dict subclass that adds no behaviour of its own."""
 | |
| 
 | |
| 
 | |
class TestPartial:
    # Mixin with the checks shared by every partial implementation.
    # Concrete test classes combine it with unittest.TestCase and provide
    # two class attributes:
    #   partial -- the callable under test
    #   module  -- the module installed as 'functools' by the pickle tests

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must raise TypeError, either when
        # the partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the offending value.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        # An exception raised by the wrapped function must reach the caller,
        # however the arguments are split between partial and call.
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial must report the same func/args/keywords
        # as the equivalent single partial.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_optimization_bug(self):
        partial = self.partial
        class Builder:
            def __call__(self, tag, *children, **attrib):
                return (tag, children, attrib)

            def __getattr__(self, tag):
                return partial(self, tag)

        B = Builder()
        m = B.m
        # assertEqual rather than a bare assert, so the check still runs
        # when Python is started with -O.
        self.assertEqual(m(1, 2, a=2), ('m', (1, 2), dict(a=2)))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword ordering is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = f"{self.partial.__module__}.{self.partial.__qualname__}"

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        # Each section makes the partial refer to itself (as func, as a
        # positional argument, as a keyword value) and resets the state in
        # ``finally`` so the cycle does not outlive the check.
        name = f"{self.partial.__module__}.{self.partial.__qualname__}"

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        # Round-trip through every pickle protocol with the implementation
        # under test installed as 'functools'.
        with replaced_module('functools', self.module):
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares attr, args and keywords with the original.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy shares neither the containers nor their contents.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        # The state tuple is (func, args, keywords, __dict__).
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # Malformed states: wrong length, not a tuple, non-callable func,
        # args that is not a tuple, keywords that is not a dict.
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state must come back as exact
        # tuple/dict, both in the attributes and in the call arguments.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with replaced_module('functools', self.module):
            # Partial that is its own func: pickling must raise
            # RecursionError.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    # gh-117008: Small limit since pickle uses C stack memory
                    with support.infinite_recursion(100):
                        with self.assertRaises(RecursionError):
                            pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Partial stored in its own args: the copy must refer to itself.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Partial stored in its own keywords: same expectation.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Sequence-like object of length 4 that is not a tuple; item 1
            # builds a new large tuple on every access.
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
 | |
| 
 | |
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    # Runs the shared TestPartial checks, plus the extra ones below, against
    # c_functools.partial.
    if c_functools:
        module = c_functools
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # Deleting __dict__ must be refused with TypeError.
        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
 | |
| 
 | |
| 
 | |
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared TestPartial checks against py_functools.partial."""

    partial = py_functools.partial
    module = py_functools
 | |
| 
 | |
| 
 | |
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Subclass of c_functools.partial that adds nothing of its own."""
 | |
| 
 | |
class PyPartialSubclass(py_functools.partial):
    """Subclass of py_functools.partial that adds nothing of its own."""
 | |
| 
 | |
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the TestPartialC checks with CPartialSubclass as the partial."""

    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
 | |
| 
 | |
class TestPartialPySubclass(TestPartialPy):
    """Repeat the TestPartialPy checks with PyPartialSubclass as the partial."""

    partial = PyPartialSubclass
 | |
| 
 | |
class TestPartialMethod(unittest.TestCase):
    # Tests for functools.partialmethod, exercised through the descriptors
    # declared on the nested class A and its shared instance ``a``.

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        a = self.a
        eq = self.assertEqual

        eq(a.nothing(), ((a,), {}))
        eq(a.nothing(5), ((a, 5), {}))
        eq(a.nothing(c=6), ((a,), {'c': 6}))
        eq(a.nothing(5, c=6), ((a, 5), {'c': 6}))

        eq(a.positional(), ((a, 1), {}))
        eq(a.positional(5), ((a, 1, 5), {}))
        eq(a.positional(c=6), ((a, 1), {'c': 6}))
        eq(a.positional(5, c=6), ((a, 1, 5), {'c': 6}))

        eq(a.keywords(), ((a,), {'a': 2}))
        eq(a.keywords(5), ((a, 5), {'a': 2}))
        eq(a.keywords(c=6), ((a,), {'a': 2, 'c': 6}))
        eq(a.keywords(5, c=6), ((a, 5), {'a': 2, 'c': 6}))

        eq(a.both(), ((a, 3), {'b': 4}))
        eq(a.both(5), ((a, 3, 5), {'b': 4}))
        eq(a.both(c=6), ((a, 3), {'b': 4, 'c': 6}))
        eq(a.both(5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

        # Same call made through the class, with the instance passed
        # explicitly.
        eq(self.A.both(a, 5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

        eq(a.spec_keywords(), ((a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        a = self.a
        eq = self.assertEqual

        eq(a.nested(), ((a, 1, 5), {}))
        eq(a.nested(6), ((a, 1, 5, 6), {}))
        eq(a.nested(d=7), ((a, 1, 5), {'d': 7}))
        eq(a.nested(6, d=7), ((a, 1, 5, 6), {'d': 7}))

        eq(self.A.nested(a, 6, d=7), ((a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        a = self.a
        eq = self.assertEqual

        eq(a.over_partial(), ((a, 7), {'c': 6}))
        eq(a.over_partial(5), ((a, 7, 5), {'c': 6}))
        eq(a.over_partial(d=8), ((a, 7), {'c': 6, 'd': 8}))
        eq(a.over_partial(5, d=8), ((a, 7, 5), {'c': 6, 'd': 8}))

        eq(self.A.over_partial(a, 5, d=8), ((a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        for attr in ('both', 'nested', 'over_partial'):
            self.assertIs(getattr(obj, attr).__self__, obj)
        # The classmethod-based descriptor binds to the class either way.
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        for attr in ('both', 'nested', 'over_partial', 'static'):
            self.assertFalse(hasattr(getattr(obj, attr), "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # static and cls must behave the same through the class and
        # through an instance.
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                eq = self.assertEqual

                eq(obj.static(), ((8,), {}))
                eq(obj.static(5), ((8, 5), {}))
                eq(obj.static(d=8), ((8,), {'d': 8}))
                eq(obj.static(5, d=8), ((8, 5), {'d': 8}))

                eq(obj.cls(), ((self.A,), {'d': 9}))
                eq(obj.cls(5), ((self.A, 5), {'d': 9}))
                eq(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                eq(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        a = self.a
        self.assertEqual(a.keywords(a=3), ((a,), {'a': 3}))
        self.assertEqual(self.A.keywords(a, a=3), ((a,), {'a': 3}))

    def test_invalid_args(self):
        # Each class body below must fail with TypeError while it is being
        # executed.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        # vars() returns the raw partialmethod objects without binding them.
        descriptors = vars(self.A)
        expectations = [
            ('nothing', 'functools.partialmethod({})'),
            ('positional', 'functools.partialmethod({}, 1)'),
            ('keywords', 'functools.partialmethod({}, a=2)'),
            ('spec_keywords', 'functools.partialmethod({}, self=1, func=2)'),
            ('both', 'functools.partialmethod({}, 3, b=4)'),
        ]
        for attr, template in expectations:
            self.assertEqual(repr(descriptors[attr]),
                             template.format(capture))

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        A = self.A
        for func in [A.static, A.cls, A.over_partial, A.nested, A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
 | |
| 
 | |
| 
 | |
| class TestUpdateWrapper(unittest.TestCase):
 | |
| 
 | |
|     def check_wrapper(self, wrapper, wrapped,
 | |
|                       assigned=functools.WRAPPER_ASSIGNMENTS,
 | |
|                       updated=functools.WRAPPER_UPDATES):
 | |
|         # Check attributes were assigned
 | |
|         for name in assigned:
 | |
|             self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
 | |
|         # Check attributes were updated
 | |
|         for name in updated:
 | |
|             wrapper_attr = getattr(wrapper, name)
 | |
|             wrapped_attr = getattr(wrapped, name)
 | |
|             for key in wrapped_attr:
 | |
|                 if name == "__dict__" and key == "__wrapped__":
 | |
|                     # __wrapped__ is overwritten by the update code
 | |
|                     continue
 | |
|                 self.assertIs(wrapped_attr[key], wrapper_attr[key])
 | |
|         # Check __wrapped__
 | |
|         self.assertIs(wrapper.__wrapped__, wrapped)
 | |
| 
 | |
| 
 | |
|     def _default_update(self):
 | |
|         def f[T](a:'This is a new annotation'):
 | |
|             """This is a test"""
 | |
|             pass
 | |
|         f.attr = 'This is also a test'
 | |
|         f.__wrapped__ = "This is a bald faced lie"
 | |
|         def wrapper(b:'This is the prior annotation'):
 | |
|             pass
 | |
|         functools.update_wrapper(wrapper, f)
 | |
|         return wrapper, f
 | |
| 
 | |
|     def test_default_update(self):
 | |
|         wrapper, f = self._default_update()
 | |
|         self.check_wrapper(wrapper, f)
 | |
|         T, = f.__type_params__
 | |
|         self.assertIs(wrapper.__wrapped__, f)
 | |
|         self.assertEqual(wrapper.__name__, 'f')
 | |
|         self.assertEqual(wrapper.__qualname__, f.__qualname__)
 | |
|         self.assertEqual(wrapper.attr, 'This is also a test')
 | |
|         self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
 | |
|         self.assertNotIn('b', wrapper.__annotations__)
 | |
|         self.assertEqual(wrapper.__type_params__, (T,))
 | |
| 
 | |
|     @unittest.skipIf(sys.flags.optimize >= 2,
 | |
|                      "Docstrings are omitted with -O2 and above")
 | |
|     def test_default_update_doc(self):
 | |
|         wrapper, f = self._default_update()
 | |
|         self.assertEqual(wrapper.__doc__, 'This is a test')
 | |
| 
 | |
|     def test_no_update(self):
 | |
|         def f():
 | |
|             """This is a test"""
 | |
|             pass
 | |
|         f.attr = 'This is also a test'
 | |
|         def wrapper():
 | |
|             pass
 | |
|         functools.update_wrapper(wrapper, f, (), ())
 | |
|         self.check_wrapper(wrapper, f, (), ())
 | |
|         self.assertEqual(wrapper.__name__, 'wrapper')
 | |
|         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
 | |
|         self.assertEqual(wrapper.__doc__, None)
 | |
|         self.assertEqual(wrapper.__annotations__, {})
 | |
|         self.assertFalse(hasattr(wrapper, 'attr'))
 | |
| 
 | |
|     def test_selective_update(self):
 | |
|         def f():
 | |
|             pass
 | |
|         f.attr = 'This is a different test'
 | |
|         f.dict_attr = dict(a=1, b=2, c=3)
 | |
|         def wrapper():
 | |
|             pass
 | |
|         wrapper.dict_attr = {}
 | |
|         assign = ('attr',)
 | |
|         update = ('dict_attr',)
 | |
|         functools.update_wrapper(wrapper, f, assign, update)
 | |
|         self.check_wrapper(wrapper, f, assign, update)
 | |
|         self.assertEqual(wrapper.__name__, 'wrapper')
 | |
|         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
 | |
|         self.assertEqual(wrapper.__doc__, None)
 | |
|         self.assertEqual(wrapper.attr, 'This is a different test')
 | |
|         self.assertEqual(wrapper.dict_attr, f.dict_attr)
 | |
| 
 | |
|     def test_missing_attributes(self):
 | |
|         def f():
 | |
|             pass
 | |
|         def wrapper():
 | |
|             pass
 | |
|         wrapper.dict_attr = {}
 | |
|         assign = ('attr',)
 | |
|         update = ('dict_attr',)
 | |
|         # Missing attributes on wrapped object are ignored
 | |
|         functools.update_wrapper(wrapper, f, assign, update)
 | |
|         self.assertNotIn('attr', wrapper.__dict__)
 | |
|         self.assertEqual(wrapper.dict_attr, {})
 | |
|         # Wrapper must have expected attributes for updating
 | |
|         del wrapper.dict_attr
 | |
|         with self.assertRaises(AttributeError):
 | |
|             functools.update_wrapper(wrapper, f, assign, update)
 | |
|         wrapper.dict_attr = 1
 | |
|         with self.assertRaises(AttributeError):
 | |
|             functools.update_wrapper(wrapper, f, assign, update)
 | |
| 
 | |
|     @support.requires_docstrings
 | |
|     @unittest.skipIf(sys.flags.optimize >= 2,
 | |
|                      "Docstrings are omitted with -O2 and above")
 | |
|     def test_builtin_update(self):
 | |
|         # Test for bug #1576241
 | |
|         def wrapper():
 | |
|             pass
 | |
|         functools.update_wrapper(wrapper, max)
 | |
|         self.assertEqual(wrapper.__name__, 'max')
 | |
|         self.assertTrue(wrapper.__doc__.startswith('max('))
 | |
|         self.assertEqual(wrapper.__annotations__, {})
 | |
| 
 | |
|     def test_update_type_wrapper(self):
 | |
|         def wrapper(*args): pass
 | |
| 
 | |
|         functools.update_wrapper(wrapper, type)
 | |
|         self.assertEqual(wrapper.__name__, 'type')
 | |
|         self.assertEqual(wrapper.__annotations__, {})
 | |
|         self.assertEqual(wrapper.__type_params__, ())
 | |
| 
 | |
| 
 | |
class TestWraps(TestUpdateWrapper):
    # Exercises functools.wraps(); the methods below override the
    # same-named ones inherited from TestUpdateWrapper.

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        def wrapper():
            pass
        wrapper = functools.wraps(f)(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty `assigned`/`updated` tuples: nothing is copied over.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass
        wrapper = functools.wraps(f, (), ())(wrapper)

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Give the function an empty dict for wraps() to update.
            func.dict_attr = {}
            return func

        assigned = ('attr',)
        updated = ('dict_attr',)

        # add_dict_attr is applied first, wraps() second.
        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
 | |
| 
 | |
| 
 | |
| class TestReduce:
 | |
|     def test_reduce(self):
 | |
|         class Squares:
 | |
|             def __init__(self, max):
 | |
|                 self.max = max
 | |
|                 self.sofar = []
 | |
| 
 | |
|             def __len__(self):
 | |
|                 return len(self.sofar)
 | |
| 
 | |
|             def __getitem__(self, i):
 | |
|                 if not 0 <= i < self.max: raise IndexError
 | |
|                 n = len(self.sofar)
 | |
|                 while n <= i:
 | |
|                     self.sofar.append(n*n)
 | |
|                     n += 1
 | |
|                 return self.sofar[i]
 | |
|         def add(x, y):
 | |
|             return x + y
 | |
|         self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
 | |
|         self.assertEqual(
 | |
|             self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
 | |
|             ['a','c','d','w']
 | |
|         )
 | |
|         self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
 | |
|         self.assertEqual(
 | |
|             self.reduce(lambda x, y: x*y, range(2,21), 1),
 | |
|             2432902008176640000
 | |
|         )
 | |
|         self.assertEqual(self.reduce(add, Squares(10)), 285)
 | |
|         self.assertEqual(self.reduce(add, Squares(10), 0), 285)
 | |
|         self.assertEqual(self.reduce(add, Squares(0), 0), 0)
 | |
|         self.assertRaises(TypeError, self.reduce)
 | |
|         self.assertRaises(TypeError, self.reduce, 42, 42)
 | |
|         self.assertRaises(TypeError, self.reduce, 42, 42, 42)
 | |
|         self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
 | |
|         self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
 | |
|         self.assertRaises(TypeError, self.reduce, 42, (42, 42))
 | |
|         self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
 | |
|         self.assertRaises(TypeError, self.reduce, add, "")
 | |
|         self.assertRaises(TypeError, self.reduce, add, ())
 | |
|         self.assertRaises(TypeError, self.reduce, add, object())
 | |
| 
 | |
|         class TestFailingIter:
 | |
|             def __iter__(self):
 | |
|                 raise RuntimeError
 | |
|         self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
 | |
| 
 | |
|         self.assertEqual(self.reduce(add, [], None), None)
 | |
|         self.assertEqual(self.reduce(add, [], 42), 42)
 | |
| 
 | |
|         class BadSeq:
 | |
|             def __getitem__(self, index):
 | |
|                 raise ValueError
 | |
|         self.assertRaises(ValueError, self.reduce, 42, BadSeq())
 | |
| 
 | |
|     # Test reduce()'s use of iterators.
 | |
|     def test_iterator_usage(self):
 | |
|         class SequenceClass:
 | |
|             def __init__(self, n):
 | |
|                 self.n = n
 | |
|             def __getitem__(self, i):
 | |
|                 if 0 <= i < self.n:
 | |
|                     return i
 | |
|                 else:
 | |
|                     raise IndexError
 | |
| 
 | |
|         from operator import add
 | |
|         self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
 | |
|         self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
 | |
|         self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
 | |
|         self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
 | |
|         self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
 | |
|         self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
 | |
| 
 | |
|         d = {"one": 1, "two": 2, "three": 3}
 | |
|         self.assertEqual(self.reduce(add, d), "".join(d.keys()))
 | |
| 
 | |
| 
 | |
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Runs the TestReduce checks against c_functools.reduce.
    # The class body executes even when the skip applies, so the attribute
    # lookup is guarded against a falsy c_functools.
    if c_functools:
        reduce = c_functools.reduce
 | |
| 
 | |
| 
 | |
class TestReducePy(TestReduce, unittest.TestCase):
    # Runs the TestReduce checks against py_functools.reduce.
    # staticmethod() keeps the function from being bound to the test
    # instance when accessed as self.reduce.
    reduce = staticmethod(py_functools.reduce)
 | |
| 
 | |
| 
 | |
| class TestCmpToKey:
 | |
| 
 | |
|     def test_cmp_to_key(self):
 | |
|         def cmp1(x, y):
 | |
|             return (x > y) - (x < y)
 | |
|         key = self.cmp_to_key(cmp1)
 | |
|         self.assertEqual(key(3), key(3))
 | |
|         self.assertGreater(key(3), key(1))
 | |
|         self.assertGreaterEqual(key(3), key(3))
 | |
| 
 | |
|         def cmp2(x, y):
 | |
|             return int(x) - int(y)
 | |
|         key = self.cmp_to_key(cmp2)
 | |
|         self.assertEqual(key(4.0), key('4'))
 | |
|         self.assertLess(key(2), key('35'))
 | |
|         self.assertLessEqual(key(2), key('35'))
 | |
|         self.assertNotEqual(key(2), key('35'))
 | |
| 
 | |
|     def test_cmp_to_key_arguments(self):
 | |
|         def cmp1(x, y):
 | |
|             return (x > y) - (x < y)
 | |
|         key = self.cmp_to_key(mycmp=cmp1)
 | |
|         self.assertEqual(key(obj=3), key(obj=3))
 | |
|         self.assertGreater(key(obj=3), key(obj=1))
 | |
|         with self.assertRaises((TypeError, AttributeError)):
 | |
|             key(3) > 1    # rhs is not a K object
 | |
|         with self.assertRaises((TypeError, AttributeError)):
 | |
|             1 < key(3)    # lhs is not a K object
 | |
|         with self.assertRaises(TypeError):
 | |
|             key = self.cmp_to_key()             # too few args
 | |
|         with self.assertRaises(TypeError):
 | |
|             key = self.cmp_to_key(cmp1, None)   # too many args
 | |
|         key = self.cmp_to_key(cmp1)
 | |
|         with self.assertRaises(TypeError):
 | |
|             key()                                    # too few args
 | |
|         with self.assertRaises(TypeError):
 | |
|             key(None, None)                          # too many args
 | |
| 
 | |
|     def test_bad_cmp(self):
 | |
|         def cmp1(x, y):
 | |
|             raise ZeroDivisionError
 | |
|         key = self.cmp_to_key(cmp1)
 | |
|         with self.assertRaises(ZeroDivisionError):
 | |
|             key(3) > key(1)
 | |
| 
 | |
|         class BadCmp:
 | |
|             def __lt__(self, other):
 | |
|                 raise ZeroDivisionError
 | |
|         def cmp1(x, y):
 | |
|             return BadCmp()
 | |
|         with self.assertRaises(ZeroDivisionError):
 | |
|             key(3) > key(1)
 | |
| 
 | |
|     def test_obj_field(self):
 | |
|         def cmp1(x, y):
 | |
|             return (x > y) - (x < y)
 | |
|         key = self.cmp_to_key(mycmp=cmp1)
 | |
|         self.assertEqual(key(50).obj, 50)
 | |
| 
 | |
|     def test_sort_int(self):
 | |
|         def mycmp(x, y):
 | |
|             return y - x
 | |
|         self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
 | |
|                          [4, 3, 2, 1, 0])
 | |
| 
 | |
|     def test_sort_int_str(self):
 | |
|         def mycmp(x, y):
 | |
|             x, y = int(x), int(y)
 | |
|             return (x > y) - (x < y)
 | |
|         values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
 | |
|         values = sorted(values, key=self.cmp_to_key(mycmp))
 | |
|         self.assertEqual([int(value) for value in values],
 | |
|                          [0, 1, 1, 2, 3, 4, 5, 7, 10])
 | |
| 
 | |
|     def test_hash(self):
 | |
|         def mycmp(x, y):
 | |
|             return y - x
 | |
|         key = self.cmp_to_key(mycmp)
 | |
|         k = key(10)
 | |
|         self.assertRaises(TypeError, hash, k)
 | |
|         self.assertNotIsInstance(k, collections.abc.Hashable)
 | |
| 
 | |
|     @unittest.skipIf(support.MISSING_C_DOCSTRINGS,
 | |
|                      "Signature information for builtins requires docstrings")
 | |
|     def test_cmp_to_signature(self):
 | |
|         sig = Signature.from_callable(self.cmp_to_key)
 | |
|         self.assertEqual(str(sig), '(mycmp)')
 | |
|         def mycmp(x, y):
 | |
|             return y - x
 | |
|         sig = Signature.from_callable(self.cmp_to_key(mycmp))
 | |
|         self.assertEqual(str(sig), '(obj)')
 | |
| 
 | |
| 
 | |
| 
 | |
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey checks against c_functools.cmp_to_key.
    # The class body executes even when the skip applies, so the attribute
    # lookup is guarded against a falsy c_functools.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # Ensure that the type disallows instantiation (bpo-43916)
        key_type = type(c_functools.cmp_to_key(None))
        support.check_disallow_instantiation(self, key_type)
 | |
| 
 | |
| 
 | |
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey checks against py_functools.cmp_to_key.
    # staticmethod() keeps the function from being bound to the test
    # instance when accessed as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
 | |
| 
 | |
| 
 | |
| class TestTotalOrdering(unittest.TestCase):
 | |
| 
 | |
|     def test_total_ordering_lt(self):
 | |
|         @functools.total_ordering
 | |
|         class A:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __lt__(self, other):
 | |
|                 return self.value < other.value
 | |
|             def __eq__(self, other):
 | |
|                 return self.value == other.value
 | |
|         self.assertTrue(A(1) < A(2))
 | |
|         self.assertTrue(A(2) > A(1))
 | |
|         self.assertTrue(A(1) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(1))
 | |
|         self.assertTrue(A(2) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(2))
 | |
|         self.assertFalse(A(1) > A(2))
 | |
| 
 | |
|     def test_total_ordering_le(self):
 | |
|         @functools.total_ordering
 | |
|         class A:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __le__(self, other):
 | |
|                 return self.value <= other.value
 | |
|             def __eq__(self, other):
 | |
|                 return self.value == other.value
 | |
|         self.assertTrue(A(1) < A(2))
 | |
|         self.assertTrue(A(2) > A(1))
 | |
|         self.assertTrue(A(1) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(1))
 | |
|         self.assertTrue(A(2) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(2))
 | |
|         self.assertFalse(A(1) >= A(2))
 | |
| 
 | |
|     def test_total_ordering_gt(self):
 | |
|         @functools.total_ordering
 | |
|         class A:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __gt__(self, other):
 | |
|                 return self.value > other.value
 | |
|             def __eq__(self, other):
 | |
|                 return self.value == other.value
 | |
|         self.assertTrue(A(1) < A(2))
 | |
|         self.assertTrue(A(2) > A(1))
 | |
|         self.assertTrue(A(1) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(1))
 | |
|         self.assertTrue(A(2) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(2))
 | |
|         self.assertFalse(A(2) < A(1))
 | |
| 
 | |
|     def test_total_ordering_ge(self):
 | |
|         @functools.total_ordering
 | |
|         class A:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __ge__(self, other):
 | |
|                 return self.value >= other.value
 | |
|             def __eq__(self, other):
 | |
|                 return self.value == other.value
 | |
|         self.assertTrue(A(1) < A(2))
 | |
|         self.assertTrue(A(2) > A(1))
 | |
|         self.assertTrue(A(1) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(1))
 | |
|         self.assertTrue(A(2) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(2))
 | |
|         self.assertFalse(A(2) <= A(1))
 | |
| 
 | |
|     def test_total_ordering_no_overwrite(self):
 | |
|         # new methods should not overwrite existing
 | |
|         @functools.total_ordering
 | |
|         class A(int):
 | |
|             pass
 | |
|         self.assertTrue(A(1) < A(2))
 | |
|         self.assertTrue(A(2) > A(1))
 | |
|         self.assertTrue(A(1) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(1))
 | |
|         self.assertTrue(A(2) <= A(2))
 | |
|         self.assertTrue(A(2) >= A(2))
 | |
| 
 | |
|     def test_no_operations_defined(self):
 | |
|         with self.assertRaises(ValueError):
 | |
|             @functools.total_ordering
 | |
|             class A:
 | |
|                 pass
 | |
| 
 | |
|     def test_notimplemented(self):
 | |
|         # Verify NotImplemented results are correctly handled
 | |
|         @functools.total_ordering
 | |
|         class ImplementsLessThan:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThan):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __lt__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThan):
 | |
|                     return self.value < other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ImplementsLessThanEqualTo:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThanEqualTo):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __le__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThanEqualTo):
 | |
|                     return self.value <= other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ImplementsGreaterThan:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThan):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __gt__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThan):
 | |
|                     return self.value > other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ImplementsGreaterThanEqualTo:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThanEqualTo):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __ge__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThanEqualTo):
 | |
|                     return self.value >= other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented)
 | |
|         self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented)
 | |
| 
 | |
|     def test_type_error_when_not_implemented(self):
 | |
|         # bug 10042; ensure stack overflow does not occur
 | |
|         # when decorated types return NotImplemented
 | |
|         @functools.total_ordering
 | |
|         class ImplementsLessThan:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThan):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __lt__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThan):
 | |
|                     return self.value < other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ImplementsGreaterThan:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThan):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __gt__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThan):
 | |
|                     return self.value > other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ImplementsLessThanEqualTo:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThanEqualTo):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __le__(self, other):
 | |
|                 if isinstance(other, ImplementsLessThanEqualTo):
 | |
|                     return self.value <= other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ImplementsGreaterThanEqualTo:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThanEqualTo):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __ge__(self, other):
 | |
|                 if isinstance(other, ImplementsGreaterThanEqualTo):
 | |
|                     return self.value >= other.value
 | |
|                 return NotImplemented
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class ComparatorNotImplemented:
 | |
|             def __init__(self, value):
 | |
|                 self.value = value
 | |
|             def __eq__(self, other):
 | |
|                 if isinstance(other, ComparatorNotImplemented):
 | |
|                     return self.value == other.value
 | |
|                 return False
 | |
|             def __lt__(self, other):
 | |
|                 return NotImplemented
 | |
| 
 | |
|         with self.subTest("LT < 1"), self.assertRaises(TypeError):
 | |
|             ImplementsLessThan(-1) < 1
 | |
| 
 | |
|         with self.subTest("LT < LE"), self.assertRaises(TypeError):
 | |
|             ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
 | |
| 
 | |
|         with self.subTest("LT < GT"), self.assertRaises(TypeError):
 | |
|             ImplementsLessThan(1) < ImplementsGreaterThan(1)
 | |
| 
 | |
|         with self.subTest("LE <= LT"), self.assertRaises(TypeError):
 | |
|             ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
 | |
| 
 | |
|         with self.subTest("LE <= GE"), self.assertRaises(TypeError):
 | |
|             ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
 | |
| 
 | |
|         with self.subTest("GT > GE"), self.assertRaises(TypeError):
 | |
|             ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
 | |
| 
 | |
|         with self.subTest("GT > LT"), self.assertRaises(TypeError):
 | |
|             ImplementsGreaterThan(5) > ImplementsLessThan(5)
 | |
| 
 | |
|         with self.subTest("GE >= GT"), self.assertRaises(TypeError):
 | |
|             ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
 | |
| 
 | |
|         with self.subTest("GE >= LE"), self.assertRaises(TypeError):
 | |
|             ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
 | |
| 
 | |
|         with self.subTest("GE when equal"):
 | |
|             a = ComparatorNotImplemented(8)
 | |
|             b = ComparatorNotImplemented(8)
 | |
|             self.assertEqual(a, b)
 | |
|             with self.assertRaises(TypeError):
 | |
|                 a >= b
 | |
| 
 | |
|         with self.subTest("LE when equal"):
 | |
|             a = ComparatorNotImplemented(9)
 | |
|             b = ComparatorNotImplemented(9)
 | |
|             self.assertEqual(a, b)
 | |
|             with self.assertRaises(TypeError):
 | |
|                 a <= b
 | |
| 
 | |
|     def test_pickle(self):
 | |
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
 | |
|             for name in '__lt__', '__gt__', '__le__', '__ge__':
 | |
|                 with self.subTest(method=name, proto=proto):
 | |
|                     method = getattr(Orderable_LT, name)
 | |
|                     method_copy = pickle.loads(pickle.dumps(method, proto))
 | |
|                     self.assertIs(method_copy, method)
 | |
| 
 | |
| 
 | |
|     def test_total_ordering_for_metaclasses_issue_44605(self):
 | |
| 
 | |
|         @functools.total_ordering
 | |
|         class SortableMeta(type):
 | |
|             def __new__(cls, name, bases, ns):
 | |
|                 return super().__new__(cls, name, bases, ns)
 | |
| 
 | |
|             def __lt__(self, other):
 | |
|                 if not isinstance(other, SortableMeta):
 | |
|                     pass
 | |
|                 return self.__name__ < other.__name__
 | |
| 
 | |
|             def __eq__(self, other):
 | |
|                 if not isinstance(other, SortableMeta):
 | |
|                     pass
 | |
|                 return self.__name__ == other.__name__
 | |
| 
 | |
|         class B(metaclass=SortableMeta):
 | |
|             pass
 | |
| 
 | |
|         class A(metaclass=SortableMeta):
 | |
|             pass
 | |
| 
 | |
|         self.assertTrue(A < B)
 | |
|         self.assertFalse(A > B)
 | |
| 
 | |
| 
 | |
@functools.total_ordering
class Orderable_LT:
    # Module-level class that defines only __lt__ and __eq__; the other
    # ordering methods come from total_ordering.

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        mine, theirs = self.value, other.value
        return mine < theirs

    def __eq__(self, other):
        mine, theirs = self.value, other.value
        return mine == theirs
 | |
| 
 | |
| 
 | |
| class TestCache:
 | |
|     # This tests that the pass-through is working as designed.
 | |
|     # The underlying functionality is tested in TestLRU.
 | |
| 
 | |
|     def test_cache(self):
 | |
|         @self.module.cache
 | |
|         def fib(n):
 | |
|             if n < 2:
 | |
|                 return n
 | |
|             return fib(n-1) + fib(n-2)
 | |
|         self.assertEqual([fib(n) for n in range(16)],
 | |
|             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
 | |
|         fib.cache_clear()
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
 | |
| 
 | |
| 
 | |
| class TestLRU:
 | |
| 
 | |
|     def test_lru(self):
 | |
|         def orig(x, y):
 | |
|             return 3 * x + y
 | |
|         f = self.module.lru_cache(maxsize=20)(orig)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(maxsize, 20)
 | |
|         self.assertEqual(currsize, 0)
 | |
|         self.assertEqual(hits, 0)
 | |
|         self.assertEqual(misses, 0)
 | |
| 
 | |
|         domain = range(5)
 | |
|         for i in range(1000):
 | |
|             x, y = choice(domain), choice(domain)
 | |
|             actual = f(x, y)
 | |
|             expected = orig(x, y)
 | |
|             self.assertEqual(actual, expected)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertTrue(hits > misses)
 | |
|         self.assertEqual(hits + misses, 1000)
 | |
|         self.assertEqual(currsize, 20)
 | |
| 
 | |
|         f.cache_clear()   # test clearing
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(hits, 0)
 | |
|         self.assertEqual(misses, 0)
 | |
|         self.assertEqual(currsize, 0)
 | |
|         f(x, y)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(hits, 0)
 | |
|         self.assertEqual(misses, 1)
 | |
|         self.assertEqual(currsize, 1)
 | |
| 
 | |
|         # Test bypassing the cache
 | |
|         self.assertIs(f.__wrapped__, orig)
 | |
|         f.__wrapped__(x, y)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(hits, 0)
 | |
|         self.assertEqual(misses, 1)
 | |
|         self.assertEqual(currsize, 1)
 | |
| 
 | |
|         # test size zero (which means "never-cache")
 | |
|         @self.module.lru_cache(0)
 | |
|         def f():
 | |
|             nonlocal f_cnt
 | |
|             f_cnt += 1
 | |
|             return 20
 | |
|         self.assertEqual(f.cache_info().maxsize, 0)
 | |
|         f_cnt = 0
 | |
|         for i in range(5):
 | |
|             self.assertEqual(f(), 20)
 | |
|         self.assertEqual(f_cnt, 5)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(hits, 0)
 | |
|         self.assertEqual(misses, 5)
 | |
|         self.assertEqual(currsize, 0)
 | |
| 
 | |
|         # test size one
 | |
|         @self.module.lru_cache(1)
 | |
|         def f():
 | |
|             nonlocal f_cnt
 | |
|             f_cnt += 1
 | |
|             return 20
 | |
|         self.assertEqual(f.cache_info().maxsize, 1)
 | |
|         f_cnt = 0
 | |
|         for i in range(5):
 | |
|             self.assertEqual(f(), 20)
 | |
|         self.assertEqual(f_cnt, 1)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(hits, 4)
 | |
|         self.assertEqual(misses, 1)
 | |
|         self.assertEqual(currsize, 1)
 | |
| 
 | |
|         # test size two
 | |
|         @self.module.lru_cache(2)
 | |
|         def f(x):
 | |
|             nonlocal f_cnt
 | |
|             f_cnt += 1
 | |
|             return x*10
 | |
|         self.assertEqual(f.cache_info().maxsize, 2)
 | |
|         f_cnt = 0
 | |
|         for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
 | |
|             #    *  *              *                          *
 | |
|             self.assertEqual(f(x), x*10)
 | |
|         self.assertEqual(f_cnt, 4)
 | |
|         hits, misses, maxsize, currsize = f.cache_info()
 | |
|         self.assertEqual(hits, 12)
 | |
|         self.assertEqual(misses, 4)
 | |
|         self.assertEqual(currsize, 2)
 | |
| 
 | |
|     def test_lru_no_args(self):
 | |
|         @self.module.lru_cache
 | |
|         def square(x):
 | |
|             return x ** 2
 | |
| 
 | |
|         self.assertEqual(list(map(square, [10, 20, 10])),
 | |
|                          [100, 400, 100])
 | |
|         self.assertEqual(square.cache_info().hits, 1)
 | |
|         self.assertEqual(square.cache_info().misses, 2)
 | |
|         self.assertEqual(square.cache_info().maxsize, 128)
 | |
|         self.assertEqual(square.cache_info().currsize, 2)
 | |
| 
 | |
|     def test_lru_bug_35780(self):
 | |
|         # C version of the lru_cache was not checking to see if
 | |
|         # the user function call has already modified the cache
 | |
|         # (this arises in recursive calls and in multi-threading).
 | |
|         # This cause the cache to have orphan links not referenced
 | |
|         # by the cache dictionary.
 | |
| 
 | |
|         once = True                 # Modified by f(x) below
 | |
| 
 | |
|         @self.module.lru_cache(maxsize=10)
 | |
|         def f(x):
 | |
|             nonlocal once
 | |
|             rv = f'.{x}.'
 | |
|             if x == 20 and once:
 | |
|                 once = False
 | |
|                 rv = f(x)
 | |
|             return rv
 | |
| 
 | |
|         # Fill the cache
 | |
|         for x in range(15):
 | |
|             self.assertEqual(f(x), f'.{x}.')
 | |
|         self.assertEqual(f.cache_info().currsize, 10)
 | |
| 
 | |
|         # Make a recursive call and make sure the cache remains full
 | |
|         self.assertEqual(f(20), '.20.')
 | |
|         self.assertEqual(f.cache_info().currsize, 10)
 | |
| 
 | |
|     def test_lru_bug_36650(self):
 | |
|         # C version of lru_cache was treating a call with an empty **kwargs
 | |
|         # dictionary as being distinct from a call with no keywords at all.
 | |
|         # This did not result in an incorrect answer, but it did trigger
 | |
|         # an unexpected cache miss.
 | |
| 
 | |
|         @self.module.lru_cache()
 | |
|         def f(x):
 | |
|             pass
 | |
| 
 | |
|         f(0)
 | |
|         f(0, **{})
 | |
|         self.assertEqual(f.cache_info().hits, 1)
 | |
| 
 | |
    def test_lru_hash_only_once(self):
        # To protect against weird reentrancy bugs and to improve
        # efficiency when faced with slow __hash__ methods, the
        # LRU cache guarantees that it will call __hash__ only
        # once per use as an argument to the cached function.
        #
        # The cache_info() tuples compared below are
        # (hits, misses, maxsize, currsize).

        @self.module.lru_cache(maxsize=1)
        def f(x, y):
            return x * 3 + y

        # Simulate the integer 5
        mock_int = unittest.mock.Mock()
        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
        mock_int.__hash__ = unittest.mock.Mock(return_value=999)

        # Add to cache:  One use as an argument gives one call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 1)
        self.assertEqual(f.cache_info(), (0, 1, 1, 1))

        # Cache hit: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 1, 1, 1))

        # Cache eviction: No use as an argument gives no additional call
        # (maxsize is 1, so caching (6, 2) pushes out the mock's entry)
        self.assertEqual(f(6, 2), 20)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 2, 1, 1))

        # Cache miss: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 3)
        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
 | |
| 
 | |
|     def test_lru_reentrancy_with_len(self):
 | |
|         # Test to make sure the LRU cache code isn't thrown-off by
 | |
|         # caching the built-in len() function.  Since len() can be
 | |
|         # cached, we shouldn't use it inside the lru code itself.
 | |
|         old_len = builtins.len
 | |
|         try:
 | |
|             builtins.len = self.module.lru_cache(4)(len)
 | |
|             for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
 | |
|                 self.assertEqual(len('abcdefghijklmn'[:i]), i)
 | |
|         finally:
 | |
|             builtins.len = old_len
 | |
| 
 | |
|     def test_lru_star_arg_handling(self):
 | |
|         # Test regression that arose in ea064ff3c10f
 | |
|         @self.module.lru_cache()
 | |
|         def f(*args):
 | |
|             return args
 | |
| 
 | |
|         self.assertEqual(f(1, 2), (1, 2))
 | |
|         self.assertEqual(f((1, 2)), ((1, 2),))
 | |
| 
 | |
|     def test_lru_type_error(self):
 | |
|         # Regression test for issue #28653.
 | |
|         # lru_cache was leaking when one of the arguments
 | |
|         # wasn't cacheable.
 | |
| 
 | |
|         @self.module.lru_cache(maxsize=None)
 | |
|         def infinite_cache(o):
 | |
|             pass
 | |
| 
 | |
|         @self.module.lru_cache(maxsize=10)
 | |
|         def limited_cache(o):
 | |
|             pass
 | |
| 
 | |
|         with self.assertRaises(TypeError):
 | |
|             infinite_cache([])
 | |
| 
 | |
|         with self.assertRaises(TypeError):
 | |
|             limited_cache([])
 | |
| 
 | |
|     def test_lru_with_maxsize_none(self):
 | |
|         @self.module.lru_cache(maxsize=None)
 | |
|         def fib(n):
 | |
|             if n < 2:
 | |
|                 return n
 | |
|             return fib(n-1) + fib(n-2)
 | |
|         self.assertEqual([fib(n) for n in range(16)],
 | |
|             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
 | |
|         fib.cache_clear()
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
 | |
| 
 | |
|     def test_lru_with_maxsize_negative(self):
 | |
|         @self.module.lru_cache(maxsize=-10)
 | |
|         def eq(n):
 | |
|             return n
 | |
|         for i in (0, 1):
 | |
|             self.assertEqual([eq(n) for n in range(150)], list(range(150)))
 | |
|         self.assertEqual(eq.cache_info(),
 | |
|             self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
 | |
| 
 | |
|     def test_lru_with_exceptions(self):
 | |
|         # Verify that user_function exceptions get passed through without
 | |
|         # creating a hard-to-read chained exception.
 | |
|         # http://bugs.python.org/issue13177
 | |
|         for maxsize in (None, 128):
 | |
|             @self.module.lru_cache(maxsize)
 | |
|             def func(i):
 | |
|                 return 'abc'[i]
 | |
|             self.assertEqual(func(0), 'a')
 | |
|             with self.assertRaises(IndexError) as cm:
 | |
|                 func(15)
 | |
|             self.assertIsNone(cm.exception.__context__)
 | |
|             # Verify that the previous exception did not result in a cached entry
 | |
|             with self.assertRaises(IndexError):
 | |
|                 func(15)
 | |
| 
 | |
|     def test_lru_with_types(self):
 | |
|         for maxsize in (None, 128):
 | |
|             @self.module.lru_cache(maxsize=maxsize, typed=True)
 | |
|             def square(x):
 | |
|                 return x * x
 | |
|             self.assertEqual(square(3), 9)
 | |
|             self.assertEqual(type(square(3)), type(9))
 | |
|             self.assertEqual(square(3.0), 9.0)
 | |
|             self.assertEqual(type(square(3.0)), type(9.0))
 | |
|             self.assertEqual(square(x=3), 9)
 | |
|             self.assertEqual(type(square(x=3)), type(9))
 | |
|             self.assertEqual(square(x=3.0), 9.0)
 | |
|             self.assertEqual(type(square(x=3.0)), type(9.0))
 | |
|             self.assertEqual(square.cache_info().hits, 4)
 | |
|             self.assertEqual(square.cache_info().misses, 4)
 | |
| 
 | |
|     def test_lru_cache_typed_is_not_recursive(self):
 | |
|         cached = self.module.lru_cache(typed=True)(repr)
 | |
| 
 | |
|         self.assertEqual(cached(1), '1')
 | |
|         self.assertEqual(cached(True), 'True')
 | |
|         self.assertEqual(cached(1.0), '1.0')
 | |
|         self.assertEqual(cached(0), '0')
 | |
|         self.assertEqual(cached(False), 'False')
 | |
|         self.assertEqual(cached(0.0), '0.0')
 | |
| 
 | |
|         self.assertEqual(cached((1,)), '(1,)')
 | |
|         self.assertEqual(cached((True,)), '(1,)')
 | |
|         self.assertEqual(cached((1.0,)), '(1,)')
 | |
|         self.assertEqual(cached((0,)), '(0,)')
 | |
|         self.assertEqual(cached((False,)), '(0,)')
 | |
|         self.assertEqual(cached((0.0,)), '(0,)')
 | |
| 
 | |
|         class T(tuple):
 | |
|             pass
 | |
| 
 | |
|         self.assertEqual(cached(T((1,))), '(1,)')
 | |
|         self.assertEqual(cached(T((True,))), '(1,)')
 | |
|         self.assertEqual(cached(T((1.0,))), '(1,)')
 | |
|         self.assertEqual(cached(T((0,))), '(0,)')
 | |
|         self.assertEqual(cached(T((False,))), '(0,)')
 | |
|         self.assertEqual(cached(T((0.0,))), '(0,)')
 | |
| 
 | |
|     def test_lru_with_keyword_args(self):
 | |
|         @self.module.lru_cache()
 | |
|         def fib(n):
 | |
|             if n < 2:
 | |
|                 return n
 | |
|             return fib(n=n-1) + fib(n=n-2)
 | |
|         self.assertEqual(
 | |
|             [fib(n=number) for number in range(16)],
 | |
|             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
 | |
|         )
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
 | |
|         fib.cache_clear()
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
 | |
| 
 | |
|     def test_lru_with_keyword_args_maxsize_none(self):
 | |
|         @self.module.lru_cache(maxsize=None)
 | |
|         def fib(n):
 | |
|             if n < 2:
 | |
|                 return n
 | |
|             return fib(n=n-1) + fib(n=n-2)
 | |
|         self.assertEqual([fib(n=number) for number in range(16)],
 | |
|             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
 | |
|         fib.cache_clear()
 | |
|         self.assertEqual(fib.cache_info(),
 | |
|             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
 | |
| 
 | |
|     def test_kwargs_order(self):
 | |
|         # PEP 468: Preserving Keyword Argument Order
 | |
|         @self.module.lru_cache(maxsize=10)
 | |
|         def f(**kwargs):
 | |
|             return list(kwargs.items())
 | |
|         self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
 | |
|         self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
 | |
|         self.assertEqual(f.cache_info(),
 | |
|             self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
 | |
| 
 | |
|     def test_lru_cache_decoration(self):
 | |
|         def f(zomg: 'zomg_annotation'):
 | |
|             """f doc string"""
 | |
|             return 42
 | |
|         g = self.module.lru_cache()(f)
 | |
|         for attr in self.module.WRAPPER_ASSIGNMENTS:
 | |
|             self.assertEqual(getattr(g, attr), getattr(f, attr))
 | |
| 
 | |
    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded(self):
        # Hammer one cached function from several threads at once, first
        # only filling the cache and then filling it while another thread
        # keeps clearing it.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Workers block on this event so that they all start together.
        start = threading.Event()
        def full(k):
            # Call f(k, 0) m times; each thread uses its own k.
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            # Repeatedly empty the cache while the workers are running.
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # Use a very short switch interval for the duration of the test;
        # the previous value is restored in the finally clause.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with threading_helper.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can these be unequal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            # (no statistics are asserted for this phase)
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with threading_helper.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
 | |
| 
 | |
    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        # Each barrier has n+1 parties: the n worker threads plus the
        # thread running this test.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Waits inside the user function until all n workers and the
            # main thread have reached the barrier, so no call can return
            # before every worker has entered f with the same argument.
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with threading_helper.start_threads(threads):
            for i in range(m):
                # Each barrier is reset after it has been passed so that
                # it can be reused in the next round.
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # All n calls of a round are misses, yet they add only
                # one entry: (hits, misses, maxsize, currsize).
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
 | |
| 
 | |
|     @threading_helper.requires_working_threading()
 | |
|     def test_lru_cache_threaded3(self):
 | |
|         @self.module.lru_cache(maxsize=2)
 | |
|         def f(x):
 | |
|             time.sleep(.01)
 | |
|             return 3 * x
 | |
|         def test(i, x):
 | |
|             with self.subTest(thread=i):
 | |
|                 self.assertEqual(f(x), 3 * x, i)
 | |
|         threads = [threading.Thread(target=test, args=(i, v))
 | |
|                    for i, v in enumerate([1, 2, 2, 3, 2])]
 | |
|         with threading_helper.start_threads(threads):
 | |
|             pass
 | |
| 
 | |
    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                # Comparing an instance whose x is 2 calls the cached
                # function again from inside the comparison.
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value
 | |
| 
 | |
    def test_lru_method(self):
        # lru_cache applied to a method: a single cache of size 2 lives
        # on the class and is shared by all instances.
        class X(int):
            # Class-level default; "self.f_cnt += 1" below creates a
            # per-instance counter the first time f really runs.
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        # a and b are both X(5), so as ints they compare and hash equal;
        # c is a different value.
        a = X(5)
        b = X(5)
        c = X(7)
        # cache_info() tuples are (hits, misses, maxsize, currsize).
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        # The statistics are the same whether read through the class or
        # through any instance.
        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())
 | |
| 
 | |
|     def test_pickle(self):
 | |
|         cls = self.__class__
 | |
|         for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
 | |
|             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
 | |
|                 with self.subTest(proto=proto, func=f):
 | |
|                     f_copy = pickle.loads(pickle.dumps(f, proto))
 | |
|                     self.assertIs(f_copy, f)
 | |
| 
 | |
|     def test_copy(self):
 | |
|         cls = self.__class__
 | |
|         def orig(x, y):
 | |
|             return 3 * x + y
 | |
|         part = self.module.partial(orig, 2)
 | |
|         funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
 | |
|                  self.module.lru_cache(2)(part))
 | |
|         for f in funcs:
 | |
|             with self.subTest(func=f):
 | |
|                 f_copy = copy.copy(f)
 | |
|                 self.assertIs(f_copy, f)
 | |
| 
 | |
|     def test_deepcopy(self):
 | |
|         cls = self.__class__
 | |
|         def orig(x, y):
 | |
|             return 3 * x + y
 | |
|         part = self.module.partial(orig, 2)
 | |
|         funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
 | |
|                  self.module.lru_cache(2)(part))
 | |
|         for f in funcs:
 | |
|             with self.subTest(func=f):
 | |
|                 f_copy = copy.deepcopy(f)
 | |
|                 self.assertIs(f_copy, f)
 | |
| 
 | |
|     def test_lru_cache_parameters(self):
 | |
|         @self.module.lru_cache(maxsize=2)
 | |
|         def f():
 | |
|             return 1
 | |
|         self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
 | |
| 
 | |
|         @self.module.lru_cache(maxsize=1000, typed=True)
 | |
|         def f():
 | |
|             return 1
 | |
|         self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
 | |
| 
 | |
    @support.suppress_immortalization()
    def test_lru_cache_weakrefable(self):
        # Cached functions, methods and static methods must accept weak
        # references, and those references must die once the wrapped
        # objects have been deleted and collected.
        @self.module.lru_cache
        def test_function(x):
            return x

        class A:
            @self.module.lru_cache
            def test_method(self, x):
                return (self, x)

            @staticmethod
            @self.module.lru_cache
            def test_staticmethod(x):
                # NOTE(review): "self" here is the enclosing test case
                # captured by closure; the function is never called in
                # this test.
                return (self, x)

        refs = [weakref.ref(test_function),
                weakref.ref(A.test_method),
                weakref.ref(A.test_staticmethod)]

        # While the originals are alive every reference resolves.
        for ref in refs:
            self.assertIsNotNone(ref())

        del A
        del test_function
        gc.collect()

        # After deletion and a collection pass every reference is dead.
        for ref in refs:
            self.assertIsNone(ref())
 | |
| 
 | |
|     def test_common_signatures(self):
 | |
|         def orig(a, /, b, c=True): ...
 | |
|         lru = self.module.lru_cache(1)(orig)
 | |
| 
 | |
|         self.assertEqual(str(Signature.from_callable(lru)), '(a, /, b, c=True)')
 | |
|         self.assertEqual(str(Signature.from_callable(lru.cache_info)), '()')
 | |
|         self.assertEqual(str(Signature.from_callable(lru.cache_clear)), '()')
 | |
| 
 | |
    @support.skip_on_s390x
    @unittest.skipIf(support.is_wasi, "WASI has limited C stack")
    def test_lru_recursion(self):
        # Deep recursion through a cached function: a depth derived from
        # the C recursion limit must succeed, and (for the C
        # implementation) a much larger depth must raise RecursionError.

        @self.module.lru_cache
        def fib(n):
            if n <= 1:
                return n
            return fib(n-1) + fib(n-2)

        # The successful deep call is only attempted on non-debug builds.
        if not support.Py_DEBUG:
            depth = support.get_c_recursion_limit()*2//7
            with support.infinite_recursion():
                fib(depth)
        if self.module == c_functools:
            # Start from an empty cache so that fib(10000) really recurses.
            fib.cache_clear()
            with support.infinite_recursion():
                with self.assertRaises(RecursionError):
                    fib(10000)
 | |
| 
 | |
| 
 | |
# Module-level function cached with the pure-Python lru_cache; TestLRUPy
# stores it as its cached_func fixture.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y
 | |
| 
 | |
# Counterpart of py_cached_func for the C implementation; only defined
# when c_functools is truthy (TestLRUC is skipped otherwise).
if c_functools:
    @c_functools.lru_cache()
    def c_cached_func(x, y):
        return 3 * x + y
 | |
| 
 | |
| 
 | |
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against py_functools.
    module = py_functools
    # Stored in a one-element tuple; the tests read it as cached_func[0].
    cached_func = py_cached_func,

    # Cached method and static method fixtures read by the pickle and
    # copy tests through the class.
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
 | |
| 
 | |
| 
 | |
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against c_functools.  The body is
    # guarded because it uses c_functools and c_cached_func while the
    # class is being created, before the skip decorator is applied.
    if c_functools:
        module = c_functools
        # Stored in a one-element tuple; the tests read it as cached_func[0].
        cached_func = c_cached_func,

        # Cached method and static method fixtures read by the pickle and
        # copy tests through the class.
        @module.lru_cache()
        def cached_meth(self, x, y):
            return 3 * x + y

        @staticmethod
        @module.lru_cache()
        def cached_staticmeth(x, y):
            return 3 * x + y
 | |
| 
 | |
| 
 | |
| class TestSingleDispatch(unittest.TestCase):
 | |
|     def test_simple_overloads(self):
 | |
|         @functools.singledispatch
 | |
|         def g(obj):
 | |
|             return "base"
 | |
|         def g_int(i):
 | |
|             return "integer"
 | |
|         g.register(int, g_int)
 | |
|         self.assertEqual(g("str"), "base")
 | |
|         self.assertEqual(g(1), "integer")
 | |
|         self.assertEqual(g([1,2,3]), "base")
 | |
| 
 | |
|     def test_mro(self):
 | |
|         @functools.singledispatch
 | |
|         def g(obj):
 | |
|             return "base"
 | |
|         class A:
 | |
|             pass
 | |
|         class C(A):
 | |
|             pass
 | |
|         class B(A):
 | |
|             pass
 | |
|         class D(C, B):
 | |
|             pass
 | |
|         def g_A(a):
 | |
|             return "A"
 | |
|         def g_B(b):
 | |
|             return "B"
 | |
|         g.register(A, g_A)
 | |
|         g.register(B, g_B)
 | |
|         self.assertEqual(g(A()), "A")
 | |
|         self.assertEqual(g(B()), "B")
 | |
|         self.assertEqual(g(C()), "A")
 | |
|         self.assertEqual(g(D()), "B")
 | |
| 
 | |
|     def test_register_decorator(self):
 | |
|         @functools.singledispatch
 | |
|         def g(obj):
 | |
|             return "base"
 | |
|         @g.register(int)
 | |
|         def g_int(i):
 | |
|             return "int %s" % (i,)
 | |
|         self.assertEqual(g(""), "base")
 | |
|         self.assertEqual(g(12), "int 12")
 | |
|         self.assertIs(g.dispatch(int), g_int)
 | |
|         self.assertIs(g.dispatch(object), g.dispatch(str))
 | |
|         # Note: in the assert above this is not g.
 | |
|         # @singledispatch returns the wrapper.
 | |
| 
 | |
|     def test_wrapping_attributes(self):
 | |
|         @functools.singledispatch
 | |
|         def g(obj):
 | |
|             "Simple test"
 | |
|             return "Test"
 | |
|         self.assertEqual(g.__name__, "g")
 | |
|         if sys.flags.optimize < 2:
 | |
|             self.assertEqual(g.__doc__, "Simple test")
 | |
| 
 | |
|     @unittest.skipUnless(decimal, 'requires _decimal')
 | |
|     @support.cpython_only
 | |
|     def test_c_classes(self):
 | |
|         @functools.singledispatch
 | |
|         def g(obj):
 | |
|             return "base"
 | |
|         @g.register(decimal.DecimalException)
 | |
|         def _(obj):
 | |
|             return obj.args
 | |
|         subn = decimal.Subnormal("Exponent < Emin")
 | |
|         rnd = decimal.Rounded("Number got rounded")
 | |
|         self.assertEqual(g(subn), ("Exponent < Emin",))
 | |
|         self.assertEqual(g(rnd), ("Number got rounded",))
 | |
|         @g.register(decimal.Subnormal)
 | |
|         def _(obj):
 | |
|             return "Too small to care."
 | |
|         self.assertEqual(g(subn), "Too small to care.")
 | |
|         self.assertEqual(g(rnd), ("Number got rounded",))
 | |
| 
 | |
|     def test_compose_mro(self):
 | |
|         # None of the examples in this test depend on haystack ordering.
 | |
|         c = collections.abc
 | |
|         mro = functools._compose_mro
 | |
|         bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
 | |
|         for haystack in permutations(bases):
 | |
|             m = mro(dict, haystack)
 | |
|             self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
 | |
|                                  c.Collection, c.Sized, c.Iterable,
 | |
|                                  c.Container, object])
 | |
|         bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
 | |
|         for haystack in permutations(bases):
 | |
|             m = mro(collections.ChainMap, haystack)
 | |
|             self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
 | |
|                                  c.Collection, c.Sized, c.Iterable,
 | |
|                                  c.Container, object])
 | |
| 
 | |
|         # If there's a generic function with implementations registered for
 | |
|         # both Sized and Container, passing a defaultdict to it results in an
 | |
|         # ambiguous dispatch which will cause a RuntimeError (see
 | |
|         # test_mro_conflicts).
 | |
|         bases = [c.Container, c.Sized, str]
 | |
|         for haystack in permutations(bases):
 | |
|             m = mro(collections.defaultdict, [c.Sized, c.Container, str])
 | |
|             self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
 | |
|                                  c.Container, object])
 | |
| 
 | |
|         # MutableSequence below is registered directly on D. In other words, it
 | |
|         # precedes MutableMapping which means single dispatch will always
 | |
|         # choose MutableSequence here.
 | |
|         class D(collections.defaultdict):
 | |
|             pass
 | |
|         c.MutableSequence.register(D)
 | |
|         bases = [c.MutableSequence, c.MutableMapping]
 | |
|         for haystack in permutations(bases):
 | |
|             m = mro(D, haystack)
 | |
|             self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
 | |
|                                  collections.defaultdict, dict, c.MutableMapping, c.Mapping,
 | |
|                                  c.Collection, c.Sized, c.Iterable, c.Container,
 | |
|                                  object])
 | |
| 
 | |
|         # Container and Callable are registered on different base classes and
 | |
|         # a generic function supporting both should always pick the Callable
 | |
|         # implementation if a C instance is passed.
 | |
|         class C(collections.defaultdict):
 | |
|             def __call__(self):
 | |
|                 pass
 | |
|         bases = [c.Sized, c.Callable, c.Container, c.Mapping]
 | |
|         for haystack in permutations(bases):
 | |
|             m = mro(C, haystack)
 | |
|             self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
 | |
|                                  c.Collection, c.Sized, c.Iterable,
 | |
|                                  c.Container, object])
 | |
| 
 | |
|     def test_register_abc(self):
 | |
|         c = collections.abc
 | |
|         d = {"a": "b"}
 | |
|         l = [1, 2, 3]
 | |
|         s = {object(), None}
 | |
|         f = frozenset(s)
 | |
|         t = (1, 2, 3)
 | |
|         @functools.singledispatch
 | |
|         def g(obj):
 | |
|             return "base"
 | |
|         self.assertEqual(g(d), "base")
 | |
|         self.assertEqual(g(l), "base")
 | |
|         self.assertEqual(g(s), "base")
 | |
|         self.assertEqual(g(f), "base")
 | |
|         self.assertEqual(g(t), "base")
 | |
|         g.register(c.Sized, lambda obj: "sized")
 | |
|         self.assertEqual(g(d), "sized")
 | |
|         self.assertEqual(g(l), "sized")
 | |
|         self.assertEqual(g(s), "sized")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sized")
 | |
|         g.register(c.MutableMapping, lambda obj: "mutablemapping")
 | |
|         self.assertEqual(g(d), "mutablemapping")
 | |
|         self.assertEqual(g(l), "sized")
 | |
|         self.assertEqual(g(s), "sized")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sized")
 | |
|         g.register(collections.ChainMap, lambda obj: "chainmap")
 | |
|         self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
 | |
|         self.assertEqual(g(l), "sized")
 | |
|         self.assertEqual(g(s), "sized")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sized")
 | |
|         g.register(c.MutableSequence, lambda obj: "mutablesequence")
 | |
|         self.assertEqual(g(d), "mutablemapping")
 | |
|         self.assertEqual(g(l), "mutablesequence")
 | |
|         self.assertEqual(g(s), "sized")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sized")
 | |
|         g.register(c.MutableSet, lambda obj: "mutableset")
 | |
|         self.assertEqual(g(d), "mutablemapping")
 | |
|         self.assertEqual(g(l), "mutablesequence")
 | |
|         self.assertEqual(g(s), "mutableset")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sized")
 | |
|         g.register(c.Mapping, lambda obj: "mapping")
 | |
|         self.assertEqual(g(d), "mutablemapping")  # not specific enough
 | |
|         self.assertEqual(g(l), "mutablesequence")
 | |
|         self.assertEqual(g(s), "mutableset")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sized")
 | |
|         g.register(c.Sequence, lambda obj: "sequence")
 | |
|         self.assertEqual(g(d), "mutablemapping")
 | |
|         self.assertEqual(g(l), "mutablesequence")
 | |
|         self.assertEqual(g(s), "mutableset")
 | |
|         self.assertEqual(g(f), "sized")
 | |
|         self.assertEqual(g(t), "sequence")
 | |
|         g.register(c.Set, lambda obj: "set")
 | |
|         self.assertEqual(g(d), "mutablemapping")
 | |
|         self.assertEqual(g(l), "mutablesequence")
 | |
|         self.assertEqual(g(s), "mutableset")
 | |
|         self.assertEqual(g(f), "set")
 | |
|         self.assertEqual(g(t), "sequence")
 | |
|         g.register(dict, lambda obj: "dict")
 | |
|         self.assertEqual(g(d), "dict")
 | |
|         self.assertEqual(g(l), "mutablesequence")
 | |
|         self.assertEqual(g(s), "mutableset")
 | |
|         self.assertEqual(g(f), "set")
 | |
|         self.assertEqual(g(t), "sequence")
 | |
|         g.register(list, lambda obj: "list")
 | |
|         self.assertEqual(g(d), "dict")
 | |
|         self.assertEqual(g(l), "list")
 | |
|         self.assertEqual(g(s), "mutableset")
 | |
|         self.assertEqual(g(f), "set")
 | |
|         self.assertEqual(g(t), "sequence")
 | |
|         g.register(set, lambda obj: "concrete-set")
 | |
|         self.assertEqual(g(d), "dict")
 | |
|         self.assertEqual(g(l), "list")
 | |
|         self.assertEqual(g(s), "concrete-set")
 | |
|         self.assertEqual(g(f), "set")
 | |
|         self.assertEqual(g(t), "sequence")
 | |
|         g.register(frozenset, lambda obj: "frozen-set")
 | |
|         self.assertEqual(g(d), "dict")
 | |
|         self.assertEqual(g(l), "list")
 | |
|         self.assertEqual(g(s), "concrete-set")
 | |
|         self.assertEqual(g(f), "frozen-set")
 | |
|         self.assertEqual(g(t), "sequence")
 | |
|         g.register(tuple, lambda obj: "tuple")
 | |
|         self.assertEqual(g(d), "dict")
 | |
|         self.assertEqual(g(l), "list")
 | |
|         self.assertEqual(g(s), "concrete-set")
 | |
|         self.assertEqual(g(f), "frozen-set")
 | |
|         self.assertEqual(g(t), "tuple")
 | |
| 
 | |
|     def test_c3_abc(self):
 | |
|         c = collections.abc
 | |
|         mro = functools._c3_mro
 | |
|         class A(object):
 | |
|             pass
 | |
|         class B(A):
 | |
|             def __len__(self):
 | |
|                 return 0   # implies Sized
 | |
|         @c.Container.register
 | |
|         class C(object):
 | |
|             pass
 | |
|         class D(object):
 | |
|             pass   # unrelated
 | |
|         class X(D, C, B):
 | |
|             def __call__(self):
 | |
|                 pass   # implies Callable
 | |
|         expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
 | |
|         for abcs in permutations([c.Sized, c.Callable, c.Container]):
 | |
|             self.assertEqual(mro(X, abcs=abcs), expected)
 | |
|         # unrelated ABCs don't appear in the resulting MRO
 | |
|         many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
 | |
|         self.assertEqual(mro(X, abcs=many_abcs), expected)
 | |
| 
 | |
|     def test_false_meta(self):
 | |
|         # see issue23572
 | |
|         class MetaA(type):
 | |
|             def __len__(self):
 | |
|                 return 0
 | |
|         class A(metaclass=MetaA):
 | |
|             pass
 | |
|         class AA(A):
 | |
|             pass
 | |
|         @functools.singledispatch
 | |
|         def fun(a):
 | |
|             return 'base A'
 | |
|         @fun.register(A)
 | |
|         def _(a):
 | |
|             return 'fun A'
 | |
|         aa = AA()
 | |
|         self.assertEqual(fun(aa), 'fun A')
 | |
| 
 | |
    def test_mro_conflicts(self):
        # Exercise singledispatch resolution when a class relates to several
        # ABCs at once (by inheritance, by ABC.register(), or by merely
        # defining the right special method).  The steps below are
        # order-dependent: each register() call changes what the following
        # assertions expect.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        # --- O: explicitly inherits from Sized ---
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        # Nothing is registered on g yet, so the default implementation runs.
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # --- P: no ABC bases, ABCs attached only through register() ---
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # Two registered ABCs now match P and the call must fail.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # Either ordering of the two candidates in the message is accepted.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        # --- Q: like O, but Container is never registered on it ---
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        # --- h: a second generic function with only Sized and Container ---
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # --- R: a defaultdict subclass additionally registered as a
        # MutableSequence; i has implementations for both ABCs ---
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        # --- T: Sized is an explicit base, listed after a plain class ---
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        # --- U: Sized only by virtue of defining __len__() ---
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # --- V: Sized listed before the plain base S; j dispatches on S and
        # Container only ---
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
 | |
| 
 | |
    def test_cache_invalidation(self):
        # Observe the dispatch cache of a singledispatch function by
        # substituting a recording mapping for weakref.WeakKeyDictionary
        # while the generic function is created, then check exactly when the
        # cache is read, written and cleared.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # Mapping that records the key of every item read (get_ops) and
            # every item write (set_ops), in order.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                # Look up first so that a failed lookup (KeyError) is not
                # recorded as a get.
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Empties the stored items only; the recorded get_ops and
                # set_ops are kept.
                self.data.clear()

        td = TracingDict()
        # Every WeakKeyDictionary() call made inside this block returns td.
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            # First call per type: one cache write, no successful cache read.
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated calls are served from the cache: reads only.
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering an implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # First ABC registered on g; the cache is emptied again.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() reads from the same cache as a call does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            # The stale entries are only dropped on the next dispatch.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # Explicit cache clearing empties the mapping as well.
            g._clear_cache()
            self.assertEqual(len(td), 0)
 | |
| 
 | |
|     def test_annotations(self):
 | |
|         @functools.singledispatch
 | |
|         def i(arg):
 | |
|             return "base"
 | |
|         @i.register
 | |
|         def _(arg: collections.abc.Mapping):
 | |
|             return "mapping"
 | |
|         @i.register
 | |
|         def _(arg: "collections.abc.Sequence"):
 | |
|             return "sequence"
 | |
|         self.assertEqual(i(None), "base")
 | |
|         self.assertEqual(i({"a": 1}), "mapping")
 | |
|         self.assertEqual(i([1, 2, 3]), "sequence")
 | |
|         self.assertEqual(i((1, 2, 3)), "sequence")
 | |
|         self.assertEqual(i("str"), "sequence")
 | |
| 
 | |
|         # Registering classes as callables doesn't work with annotations,
 | |
|         # you need to pass the type explicitly.
 | |
|         @i.register(str)
 | |
|         class _:
 | |
|             def __init__(self, arg):
 | |
|                 self.arg = arg
 | |
| 
 | |
|             def __eq__(self, other):
 | |
|                 return self.arg == other
 | |
|         self.assertEqual(i("str"), "str")
 | |
| 
 | |
|     def test_method_register(self):
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             def t(self, arg):
 | |
|                 self.arg = "base"
 | |
|             @t.register(int)
 | |
|             def _(self, arg):
 | |
|                 self.arg = "int"
 | |
|             @t.register(str)
 | |
|             def _(self, arg):
 | |
|                 self.arg = "str"
 | |
|         a = A()
 | |
| 
 | |
|         a.t(0)
 | |
|         self.assertEqual(a.arg, "int")
 | |
|         aa = A()
 | |
|         self.assertFalse(hasattr(aa, 'arg'))
 | |
|         a.t('')
 | |
|         self.assertEqual(a.arg, "str")
 | |
|         aa = A()
 | |
|         self.assertFalse(hasattr(aa, 'arg'))
 | |
|         a.t(0.0)
 | |
|         self.assertEqual(a.arg, "base")
 | |
|         aa = A()
 | |
|         self.assertFalse(hasattr(aa, 'arg'))
 | |
| 
 | |
|     def test_staticmethod_register(self):
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             @staticmethod
 | |
|             def t(arg):
 | |
|                 return arg
 | |
|             @t.register(int)
 | |
|             @staticmethod
 | |
|             def _(arg):
 | |
|                 return isinstance(arg, int)
 | |
|             @t.register(str)
 | |
|             @staticmethod
 | |
|             def _(arg):
 | |
|                 return isinstance(arg, str)
 | |
|         a = A()
 | |
| 
 | |
|         self.assertTrue(A.t(0))
 | |
|         self.assertTrue(A.t(''))
 | |
|         self.assertEqual(A.t(0.0), 0.0)
 | |
| 
 | |
|     def test_slotted_class(self):
 | |
|         class Slot:
 | |
|             __slots__ = ('a', 'b')
 | |
|             @functools.singledispatchmethod
 | |
|             def go(self, item, arg):
 | |
|                 pass
 | |
| 
 | |
|             @go.register
 | |
|             def _(self, item: int, arg):
 | |
|                 return item + arg
 | |
| 
 | |
|         s = Slot()
 | |
|         self.assertEqual(s.go(1, 1), 2)
 | |
| 
 | |
|     def test_classmethod_slotted_class(self):
 | |
|         class Slot:
 | |
|             __slots__ = ('a', 'b')
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod
 | |
|             def go(cls, item, arg):
 | |
|                 pass
 | |
| 
 | |
|             @go.register
 | |
|             @classmethod
 | |
|             def _(cls, item: int, arg):
 | |
|                 return item + arg
 | |
| 
 | |
|         s = Slot()
 | |
|         self.assertEqual(s.go(1, 1), 2)
 | |
|         self.assertEqual(Slot.go(1, 1), 2)
 | |
| 
 | |
|     def test_staticmethod_slotted_class(self):
 | |
|         class A:
 | |
|             __slots__ = ['a']
 | |
|             @functools.singledispatchmethod
 | |
|             @staticmethod
 | |
|             def t(arg):
 | |
|                 return arg
 | |
|             @t.register(int)
 | |
|             @staticmethod
 | |
|             def _(arg):
 | |
|                 return isinstance(arg, int)
 | |
|             @t.register(str)
 | |
|             @staticmethod
 | |
|             def _(arg):
 | |
|                 return isinstance(arg, str)
 | |
|         a = A()
 | |
| 
 | |
|         self.assertTrue(A.t(0))
 | |
|         self.assertTrue(A.t(''))
 | |
|         self.assertEqual(A.t(0.0), 0.0)
 | |
|         self.assertTrue(a.t(0))
 | |
|         self.assertTrue(a.t(''))
 | |
|         self.assertEqual(a.t(0.0), 0.0)
 | |
| 
 | |
|     def test_assignment_behavior(self):
 | |
|         # see gh-106448
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             def t(arg):
 | |
|                 return arg
 | |
| 
 | |
|         a = A()
 | |
|         a.t.foo = 'bar'
 | |
|         a2 = A()
 | |
|         with self.assertRaises(AttributeError):
 | |
|             a2.t.foo
 | |
| 
 | |
|     def test_classmethod_register(self):
 | |
|         class A:
 | |
|             def __init__(self, arg):
 | |
|                 self.arg = arg
 | |
| 
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod
 | |
|             def t(cls, arg):
 | |
|                 return cls("base")
 | |
|             @t.register(int)
 | |
|             @classmethod
 | |
|             def _(cls, arg):
 | |
|                 return cls("int")
 | |
|             @t.register(str)
 | |
|             @classmethod
 | |
|             def _(cls, arg):
 | |
|                 return cls("str")
 | |
| 
 | |
|         self.assertEqual(A.t(0).arg, "int")
 | |
|         self.assertEqual(A.t('').arg, "str")
 | |
|         self.assertEqual(A.t(0.0).arg, "base")
 | |
| 
 | |
|     def test_callable_register(self):
 | |
|         class A:
 | |
|             def __init__(self, arg):
 | |
|                 self.arg = arg
 | |
| 
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod
 | |
|             def t(cls, arg):
 | |
|                 return cls("base")
 | |
| 
 | |
|         @A.t.register(int)
 | |
|         @classmethod
 | |
|         def _(cls, arg):
 | |
|             return cls("int")
 | |
|         @A.t.register(str)
 | |
|         @classmethod
 | |
|         def _(cls, arg):
 | |
|             return cls("str")
 | |
| 
 | |
|         self.assertEqual(A.t(0).arg, "int")
 | |
|         self.assertEqual(A.t('').arg, "str")
 | |
|         self.assertEqual(A.t(0.0).arg, "base")
 | |
| 
 | |
|     def test_abstractmethod_register(self):
 | |
|         class Abstract(metaclass=abc.ABCMeta):
 | |
| 
 | |
|             @functools.singledispatchmethod
 | |
|             @abc.abstractmethod
 | |
|             def add(self, x, y):
 | |
|                 pass
 | |
| 
 | |
|         self.assertTrue(Abstract.add.__isabstractmethod__)
 | |
|         self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
 | |
| 
 | |
|         with self.assertRaises(TypeError):
 | |
|             Abstract()
 | |
| 
 | |
|     def test_type_ann_register(self):
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             def t(self, arg):
 | |
|                 return "base"
 | |
|             @t.register
 | |
|             def _(self, arg: int):
 | |
|                 return "int"
 | |
|             @t.register
 | |
|             def _(self, arg: str):
 | |
|                 return "str"
 | |
|         a = A()
 | |
| 
 | |
|         self.assertEqual(a.t(0), "int")
 | |
|         self.assertEqual(a.t(''), "str")
 | |
|         self.assertEqual(a.t(0.0), "base")
 | |
| 
 | |
|     def test_staticmethod_type_ann_register(self):
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             @staticmethod
 | |
|             def t(arg):
 | |
|                 return arg
 | |
|             @t.register
 | |
|             @staticmethod
 | |
|             def _(arg: int):
 | |
|                 return isinstance(arg, int)
 | |
|             @t.register
 | |
|             @staticmethod
 | |
|             def _(arg: str):
 | |
|                 return isinstance(arg, str)
 | |
|         a = A()
 | |
| 
 | |
|         self.assertTrue(A.t(0))
 | |
|         self.assertTrue(A.t(''))
 | |
|         self.assertEqual(A.t(0.0), 0.0)
 | |
| 
 | |
|     def test_classmethod_type_ann_register(self):
 | |
|         class A:
 | |
|             def __init__(self, arg):
 | |
|                 self.arg = arg
 | |
| 
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod
 | |
|             def t(cls, arg):
 | |
|                 return cls("base")
 | |
|             @t.register
 | |
|             @classmethod
 | |
|             def _(cls, arg: int):
 | |
|                 return cls("int")
 | |
|             @t.register
 | |
|             @classmethod
 | |
|             def _(cls, arg: str):
 | |
|                 return cls("str")
 | |
| 
 | |
|         self.assertEqual(A.t(0).arg, "int")
 | |
|         self.assertEqual(A.t('').arg, "str")
 | |
|         self.assertEqual(A.t(0.0).arg, "base")
 | |
| 
 | |
|     def test_method_wrapping_attributes(self):
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             def func(self, arg: int) -> str:
 | |
|                 """My function docstring"""
 | |
|                 return str(arg)
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod
 | |
|             def cls_func(cls, arg: int) -> str:
 | |
|                 """My function docstring"""
 | |
|                 return str(arg)
 | |
|             @functools.singledispatchmethod
 | |
|             @staticmethod
 | |
|             def static_func(arg: int) -> str:
 | |
|                 """My function docstring"""
 | |
|                 return str(arg)
 | |
| 
 | |
|         for meth in (
 | |
|             A.func,
 | |
|             A().func,
 | |
|             A.cls_func,
 | |
|             A().cls_func,
 | |
|             A.static_func,
 | |
|             A().static_func
 | |
|         ):
 | |
|             with self.subTest(meth=meth):
 | |
|                 self.assertEqual(meth.__doc__,
 | |
|                                  ('My function docstring'
 | |
|                                   if support.HAVE_DOCSTRINGS
 | |
|                                   else None))
 | |
|                 self.assertEqual(meth.__annotations__['arg'], int)
 | |
| 
 | |
|         self.assertEqual(A.func.__name__, 'func')
 | |
|         self.assertEqual(A().func.__name__, 'func')
 | |
|         self.assertEqual(A.cls_func.__name__, 'cls_func')
 | |
|         self.assertEqual(A().cls_func.__name__, 'cls_func')
 | |
|         self.assertEqual(A.static_func.__name__, 'static_func')
 | |
|         self.assertEqual(A().static_func.__name__, 'static_func')
 | |
| 
 | |
|     def test_double_wrapped_methods(self):
 | |
|         def classmethod_friendly_decorator(func):
 | |
|             wrapped = func.__func__
 | |
|             @classmethod
 | |
|             @functools.wraps(wrapped)
 | |
|             def wrapper(*args, **kwargs):
 | |
|                 return wrapped(*args, **kwargs)
 | |
|             return wrapper
 | |
| 
 | |
|         class WithoutSingleDispatch:
 | |
|             @classmethod
 | |
|             @contextlib.contextmanager
 | |
|             def cls_context_manager(cls, arg: int) -> str:
 | |
|                 try:
 | |
|                     yield str(arg)
 | |
|                 finally:
 | |
|                     return 'Done'
 | |
| 
 | |
|             @classmethod_friendly_decorator
 | |
|             @classmethod
 | |
|             def decorated_classmethod(cls, arg: int) -> str:
 | |
|                 return str(arg)
 | |
| 
 | |
|         class WithSingleDispatch:
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod
 | |
|             @contextlib.contextmanager
 | |
|             def cls_context_manager(cls, arg: int) -> str:
 | |
|                 """My function docstring"""
 | |
|                 try:
 | |
|                     yield str(arg)
 | |
|                 finally:
 | |
|                     return 'Done'
 | |
| 
 | |
|             @functools.singledispatchmethod
 | |
|             @classmethod_friendly_decorator
 | |
|             @classmethod
 | |
|             def decorated_classmethod(cls, arg: int) -> str:
 | |
|                 """My function docstring"""
 | |
|                 return str(arg)
 | |
| 
 | |
|         # These are sanity checks
 | |
|         # to test the test itself is working as expected
 | |
|         with WithoutSingleDispatch.cls_context_manager(5) as foo:
 | |
|             without_single_dispatch_foo = foo
 | |
| 
 | |
|         with WithSingleDispatch.cls_context_manager(5) as foo:
 | |
|             single_dispatch_foo = foo
 | |
| 
 | |
|         self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
 | |
|         self.assertEqual(single_dispatch_foo, '5')
 | |
| 
 | |
|         self.assertEqual(
 | |
|             WithoutSingleDispatch.decorated_classmethod(5),
 | |
|             WithSingleDispatch.decorated_classmethod(5)
 | |
|         )
 | |
| 
 | |
|         self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
 | |
| 
 | |
|         # Behavioural checks now follow
 | |
|         for method_name in ('cls_context_manager', 'decorated_classmethod'):
 | |
|             with self.subTest(method=method_name):
 | |
|                 self.assertEqual(
 | |
|                     getattr(WithSingleDispatch, method_name).__name__,
 | |
|                     getattr(WithoutSingleDispatch, method_name).__name__
 | |
|                 )
 | |
| 
 | |
|                 self.assertEqual(
 | |
|                     getattr(WithSingleDispatch(), method_name).__name__,
 | |
|                     getattr(WithoutSingleDispatch(), method_name).__name__
 | |
|                 )
 | |
| 
 | |
|         for meth in (
 | |
|             WithSingleDispatch.cls_context_manager,
 | |
|             WithSingleDispatch().cls_context_manager,
 | |
|             WithSingleDispatch.decorated_classmethod,
 | |
|             WithSingleDispatch().decorated_classmethod
 | |
|         ):
 | |
|             with self.subTest(meth=meth):
 | |
|                 self.assertEqual(meth.__doc__,
 | |
|                                  ('My function docstring'
 | |
|                                   if support.HAVE_DOCSTRINGS
 | |
|                                   else None))
 | |
|                 self.assertEqual(meth.__annotations__['arg'], int)
 | |
| 
 | |
|         self.assertEqual(
 | |
|             WithSingleDispatch.cls_context_manager.__name__,
 | |
|             'cls_context_manager'
 | |
|         )
 | |
|         self.assertEqual(
 | |
|             WithSingleDispatch().cls_context_manager.__name__,
 | |
|             'cls_context_manager'
 | |
|         )
 | |
|         self.assertEqual(
 | |
|             WithSingleDispatch.decorated_classmethod.__name__,
 | |
|             'decorated_classmethod'
 | |
|         )
 | |
|         self.assertEqual(
 | |
|             WithSingleDispatch().decorated_classmethod.__name__,
 | |
|             'decorated_classmethod'
 | |
|         )
 | |
| 
 | |
    def test_invalid_registrations(self):
        # register() must reject arguments it cannot dispatch on with a
        # TypeError whose message names the offending object or annotation.
        msg_prefix = "Invalid first argument to `register()`: "
        msg_suffix = (
            ". Use either `@register(some_class)` or plain `@register` on an "
            "annotated function."
        )
        @functools.singledispatch
        def i(arg):
            return "base"
        # Case 1: the explicit argument to register() is not a class.
        with self.assertRaises(TypeError) as exc:
            @i.register(42)
            def _(arg):
                return "I annotated with a non-type"
        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))
        # Case 2: bare @register on a function without any annotation.
        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg):
                return "I forgot to annotate"
        # The expected text embeds the qualified name of the nested function,
        # so it depends on this method's name and on its enclosing class.
        self.assertTrue(str(exc.exception).startswith(msg_prefix +
            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
        ))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))

        # Case 3: the annotation is a parameterized generic, not a class.
        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg: typing.Iterable[str]):
                # At runtime, dispatching on generics is impossible.
                # When registering implementations with singledispatch, avoid
                # types from `typing`. Instead, annotate with regular types
                # or ABCs.
                return "I annotated with a generic collection"
        self.assertTrue(str(exc.exception).startswith(
            "Invalid annotation for 'arg'."
        ))
        self.assertTrue(str(exc.exception).endswith(
            'typing.Iterable[str] is not a class.'
        ))

        # Case 4: a union in which one member is not a class.
        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg: typing.Union[int, typing.Iterable[str]]):
                return "Invalid Union"
        self.assertTrue(str(exc.exception).startswith(
            "Invalid annotation for 'arg'."
        ))
        self.assertTrue(str(exc.exception).endswith(
            'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
        ))
 | |
| 
 | |
|     def test_invalid_positional_argument(self):
 | |
|         @functools.singledispatch
 | |
|         def f(*args, **kwargs):
 | |
|             pass
 | |
|         msg = 'f requires at least 1 positional argument'
 | |
|         with self.assertRaisesRegex(TypeError, msg):
 | |
|             f()
 | |
|         msg = 'f requires at least 1 positional argument'
 | |
|         with self.assertRaisesRegex(TypeError, msg):
 | |
|             f(a=1)
 | |
| 
 | |
|     def test_invalid_positional_argument_singledispatchmethod(self):
 | |
|         class A:
 | |
|             @functools.singledispatchmethod
 | |
|             def t(self, *args, **kwargs):
 | |
|                 pass
 | |
|         msg = 't requires at least 1 positional argument'
 | |
|         with self.assertRaisesRegex(TypeError, msg):
 | |
|             A().t()
 | |
|         msg = 't requires at least 1 positional argument'
 | |
|         with self.assertRaisesRegex(TypeError, msg):
 | |
|             A().t(a=1)
 | |
| 
 | |
|     def test_union(self):
 | |
|         @functools.singledispatch
 | |
|         def f(arg):
 | |
|             return "default"
 | |
| 
 | |
|         @f.register
 | |
|         def _(arg: typing.Union[str, bytes]):
 | |
|             return "typing.Union"
 | |
| 
 | |
|         @f.register
 | |
|         def _(arg: int | float):
 | |
|             return "types.UnionType"
 | |
| 
 | |
|         self.assertEqual(f([]), "default")
 | |
|         self.assertEqual(f(""), "typing.Union")
 | |
|         self.assertEqual(f(b""), "typing.Union")
 | |
|         self.assertEqual(f(1), "types.UnionType")
 | |
|         self.assertEqual(f(1.0), "types.UnionType")
 | |
| 
 | |
|     def test_union_conflict(self):
 | |
|         @functools.singledispatch
 | |
|         def f(arg):
 | |
|             return "default"
 | |
| 
 | |
|         @f.register
 | |
|         def _(arg: typing.Union[str, bytes]):
 | |
|             return "typing.Union"
 | |
| 
 | |
|         @f.register
 | |
|         def _(arg: int | str):
 | |
|             return "types.UnionType"
 | |
| 
 | |
|         self.assertEqual(f([]), "default")
 | |
|         self.assertEqual(f(""), "types.UnionType")  # last one wins
 | |
|         self.assertEqual(f(b""), "typing.Union")
 | |
|         self.assertEqual(f(1), "types.UnionType")
 | |
| 
 | |
|     def test_union_None(self):
 | |
|         @functools.singledispatch
 | |
|         def typing_union(arg):
 | |
|             return "default"
 | |
| 
 | |
|         @typing_union.register
 | |
|         def _(arg: typing.Union[str, None]):
 | |
|             return "typing.Union"
 | |
| 
 | |
|         self.assertEqual(typing_union(1), "default")
 | |
|         self.assertEqual(typing_union(""), "typing.Union")
 | |
|         self.assertEqual(typing_union(None), "typing.Union")
 | |
| 
 | |
|         @functools.singledispatch
 | |
|         def types_union(arg):
 | |
|             return "default"
 | |
| 
 | |
|         @types_union.register
 | |
|         def _(arg: int | None):
 | |
|             return "types.UnionType"
 | |
| 
 | |
|         self.assertEqual(types_union(""), "default")
 | |
|         self.assertEqual(types_union(1), "types.UnionType")
 | |
|         self.assertEqual(types_union(None), "types.UnionType")
 | |
| 
 | |
|     def test_register_genericalias(self):
 | |
|         @functools.singledispatch
 | |
|         def f(arg):
 | |
|             return "default"
 | |
| 
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(list[int], lambda arg: "types.GenericAlias")
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(typing.List[int], lambda arg: "typing.GenericAlias")
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
 | |
| 
 | |
|         self.assertEqual(f([1]), "default")
 | |
|         self.assertEqual(f([1.0]), "default")
 | |
|         self.assertEqual(f(""), "default")
 | |
|         self.assertEqual(f(b""), "default")
 | |
| 
 | |
|     def test_register_genericalias_decorator(self):
 | |
|         @functools.singledispatch
 | |
|         def f(arg):
 | |
|             return "default"
 | |
| 
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(list[int])
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(typing.List[int])
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(list[int] | str)
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
 | |
|             f.register(typing.List[int] | str)
 | |
| 
 | |
|     def test_register_genericalias_annotation(self):
 | |
|         @functools.singledispatch
 | |
|         def f(arg):
 | |
|             return "default"
 | |
| 
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
 | |
|             @f.register
 | |
|             def _(arg: list[int]):
 | |
|                 return "types.GenericAlias"
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
 | |
|             @f.register
 | |
|             def _(arg: typing.List[float]):
 | |
|                 return "typing.GenericAlias"
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
 | |
|             @f.register
 | |
|             def _(arg: list[int] | str):
 | |
|                 return "types.UnionType(types.GenericAlias)"
 | |
|         with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
 | |
|             @f.register
 | |
|             def _(arg: typing.List[float] | bytes):
 | |
|                 return "typing.Union[typing.GenericAlias]"
 | |
| 
 | |
|         self.assertEqual(f([1]), "default")
 | |
|         self.assertEqual(f([1.0]), "default")
 | |
|         self.assertEqual(f(""), "default")
 | |
|         self.assertEqual(f(b""), "default")
 | |
| 
 | |
| 
 | |
class CachedCostItem:
    """Fixture whose ``cost`` is computed on first access and then cached."""

    # Class-level starting value; the first ``cost`` access shadows it with
    # an incremented instance attribute.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
| 
 | |
| 
 | |
class OptionallyCachedCostItem:
    """Fixture exposing one computation both uncached and cached.

    ``get_cost`` recomputes on every call, while ``cached_cost`` wraps the
    same function in a cached_property under a different attribute name.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # The attribute name intentionally differs from the wrapped function's.
    cached_cost = py_functools.cached_property(get_cost)
| 
 | |
| 
 | |
class CachedCostItemWithSlots:
    """Fixture without an instance ``__dict__``.

    cached_property has nowhere to store its value on such an instance, so
    reading ``cost`` is expected to raise TypeError before the wrapped
    function is ever called.
    """

    # A real one-element tuple: ``('_cost')`` without the trailing comma is
    # just the string '_cost' in parentheses.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
| 
 | |
| 
 | |
class TestCachedProperty(unittest.TestCase):
    """Tests for ``cached_property`` as exposed by ``py_functools``."""

    def test_cached(self):
        item = CachedCostItem()
        first = item.cost
        second = item.cost
        self.assertEqual(first, 2)
        self.assertEqual(second, 2)  # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # Evaluated left to right: uncached, cached, uncached, cached.
        observed = [
            item.get_cost(),
            item.cached_cost,
            item.get_cost(),
            item.cached_cost,
        ]
        self.assertEqual(observed, [2, 3, 4, 3])

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(TypeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        self.assertEqual(str(ctx.exception), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = []

        @py_functools.cached_property
        def _cp(_self):
            # Each real invocation is recorded; the running total is the value.
            calls.append(None)
            return len(calls)

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        class Foo:
            pass

        # Assigning after class creation bypasses __set_name__.
        Foo.cp = py_functools.cached_property(lambda s: None)

        expected = "Cannot use cached_property instance without calling __set_name__ on it."
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        expected = "The cost of the item." if support.HAVE_DOCSTRINGS else None
        self.assertEqual(CachedCostItem.cost.__doc__, expected)

    def test_module(self):
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__module__, CachedCostItem.__module__)

    def test_subclass_with___set__(self):
        """Caching still works for a subclass defining __set__."""
        class readonly_cached_property(py_functools.cached_property):
            def __set__(self, obj, value):
                raise AttributeError("read only property")

        class Test:
            def __init__(self, prop):
                self._prop = prop

            @readonly_cached_property
            def prop(self):
                return self._prop

        instance = Test(1)
        self.assertEqual(instance.prop, 1)
        # Changing the backing attribute must not affect the cached value.
        instance._prop = 999
        self.assertEqual(instance.prop, 1)
| 
 | |
| 
 | |
# Run the test suite when this file is executed directly as a script.
if '__main__' == __name__:
    unittest.main()