# Mirror of https://github.com/python/cpython.git
# (synced 2025-11-04 07:31:38 +00:00) -- 192 lines, 6.1 KiB, Python.
"""Unit tests for collections.defaultdict."""
 | 
						|
 | 
						|
import os
 | 
						|
import copy
 | 
						|
import pickle
 | 
						|
import tempfile
 | 
						|
import unittest
 | 
						|
 | 
						|
from collections import defaultdict
 | 
						|
 | 
						|
def foobar():
    """Return the built-in ``list`` type itself (not an instance).

    A named module-level callable, used as the ``default_factory`` of the
    defaultdicts in the shallow/deep copy tests.
    """
    factory = list
    return factory
 | 
						|
 | 
						|
class TestDefaultDict(unittest.TestCase):
    """Unit tests for ``collections.defaultdict``.

    Covers construction, missing-key handling, ``__missing__``, repr,
    copying (``copy()``, ``copy.copy``, ``copy.deepcopy``), pickling and
    the ``|`` / ``|=`` operators.
    """

    def test_basic(self):
        # A bare defaultdict starts without a factory; one can be assigned.
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # Merely reading a missing key inserts a fresh factory value.
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every missing key must get its own list.  Check each pair
        # explicitly: a chained ``a is not b is not c`` never compares a
        # with c.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])

        # Keyword arguments after the factory populate the mapping.
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())

        # With the factory cleared, a missing key raises KeyError(key).
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        with self.assertRaises(KeyError) as cm:
            d2[15]
        self.assertEqual(cm.exception.args, (15,))

        # A non-callable factory is rejected at construction time.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        # __missing__ raises KeyError without a factory and returns the
        # factory's result once one is set.
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # With a None factory the repr round-trips through eval.
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")

        d2 = defaultdict(int)
        self.assertIs(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")

        def foo():
            return 43
        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_copy(self):
        # copy() must return a defaultdict that keeps the factory.
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})

        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertIs(d3.default_factory, list)
        self.assertEqual(d3, {})

        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertIs(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})
        # Inserting into the copy must not touch the original.
        self.assertNotIn(12, d1)

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)

        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        # Values are copied too, not shared.
        self.assertIsNot(d1[1], d2[1])

        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        # A tuple key must show up intact as args[0] of the KeyError rather
        # than being unpacked into the exception's args.
        d1 = defaultdict()
        with self.assertRaises(KeyError) as cm:
            d1[(1,)]
        self.assertEqual(cm.exception.args[0], (1,))

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        # of the defaultdict itself.  The nested self-reference must be
        # rendered as "..." instead of recursing.
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory

            def _factory(self):
                return []

        d = sub()
        self.assertRegex(repr(d),
            r"sub\(<bound method .*sub\._factory "
            r"of sub\(\.\.\., \{\}\)>, \{\}\)")

    def test_callable_arg(self):
        # A dict is not callable, so it is not a valid factory.
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickling(self):
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                s = pickle.dumps(d, proto)
                o = pickle.loads(s)
                self.assertEqual(d, o)
                # The factory must survive the round-trip, not just the items.
                self.assertIs(o.default_factory, int)

    def test_union(self):
        i = defaultdict(int, {1: 1, 2: 2})
        s = defaultdict(str, {0: "zero", 1: "one"})

        # Result keeps the left operand's factory; right-hand values win
        # for shared keys; key order is left keys then new right keys.
        i_s = i | s
        self.assertIs(i_s.default_factory, int)
        self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_s), [1, 2, 0])

        s_i = s | i
        self.assertIs(s_i.default_factory, str)
        self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(s_i), [0, 1, 2])

        # Mixing with a plain dict still yields the defaultdict's factory.
        i_ds = i | dict(s)
        self.assertIs(i_ds.default_factory, int)
        self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i_ds), [1, 2, 0])

        ds_i = dict(s) | i
        self.assertIs(ds_i.default_factory, int)
        self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
        self.assertEqual(list(ds_i), [0, 1, 2])

        # ``|`` accepts only mappings, in either position.
        with self.assertRaises(TypeError):
            i | list(s.items())
        with self.assertRaises(TypeError):
            list(s.items()) | i

        # We inherit a fine |= from dict, so just a few sanity checks here:
        i |= list(s.items())
        self.assertIs(i.default_factory, int)
        self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
        self.assertEqual(list(i), [1, 2, 0])

        with self.assertRaises(TypeError):
            i |= None
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run directly as a script (not imported by a test runner): load and
    # run the TestCase classes defined in this module.
    unittest.main(module="__main__")
 |