diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 9cfaa4a86fa..14b501360d0 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -1871,6 +1871,35 @@ def test_pickle_iter(self): self.assertEqual(list(unpickled), expected) self.assertEqual(list(it), expected) + def test_unhashable_key(self): + d = frozendict(a=1) + key = [1, 2, 3] + + def check_unhashable_key(): + msg = "cannot use 'list' as a frozendict key (unhashable type: 'list')" + return self.assertRaisesRegex(TypeError, re.escape(msg)) + + with check_unhashable_key(): + key in d + with check_unhashable_key(): + d[key] + with check_unhashable_key(): + d.get(key) + + # Only TypeError exception is overridden, + # other exceptions are left unchanged. + class HashError: + def __hash__(self): + raise KeyError('error') + + key2 = HashError() + with self.assertRaises(KeyError): + key2 in d + with self.assertRaises(KeyError): + d[key2] + with self.assertRaises(KeyError): + d.get(key2) + if __name__ == "__main__": unittest.main() diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 0a8ba74c228..6c802ca569d 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2377,7 +2377,7 @@ PyDict_GetItem(PyObject *op, PyObject *key) } static void -dict_unhashable_type(PyObject *key) +dict_unhashable_type(PyObject *op, PyObject *key) { PyObject *exc = PyErr_GetRaisedException(); assert(exc != NULL); @@ -2386,9 +2386,14 @@ dict_unhashable_type(PyObject *key) return; } - PyErr_Format(PyExc_TypeError, - "cannot use '%T' as a dict key (%S)", - key, exc); + const char *errmsg; + if (PyObject_IsInstance(op, (PyObject*)&PyFrozenDict_Type)) { + errmsg = "cannot use '%T' as a frozendict key (%S)"; + } + else { + errmsg = "cannot use '%T' as a dict key (%S)"; + } + PyErr_Format(PyExc_TypeError, errmsg, key, exc); Py_DECREF(exc); } @@ -2401,7 +2406,7 @@ _PyDict_LookupIndexAndValue(PyDictObject *mp, PyObject *key, PyObject **value) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type((PyObject*)mp, key); return -1; } @@ -2505,7 +2510,7 @@ PyDict_GetItemRef(PyObject *op, PyObject *key, PyObject **result) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(op, key); *result = NULL; return -1; } @@ -2521,7 +2526,7 @@ _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject ** Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type((PyObject*)op, key); *result = NULL; return -1; } @@ -2559,7 +2564,7 @@ PyDict_GetItemWithError(PyObject *op, PyObject *key) } hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(op, key); return NULL; } @@ -2705,7 +2710,7 @@ setitem_take2_lock_held(PyDictObject *mp, PyObject *key, PyObject *value) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type((PyObject*)mp, key); Py_DECREF(key); Py_DECREF(value); return -1; @@ -2864,7 +2869,7 @@ PyDict_DelItem(PyObject *op, PyObject *key) assert(key); Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(op, key); return -1; } @@ -3183,7 +3188,7 @@ pop_lock_held(PyObject *op, PyObject *key, PyObject **result) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(op, key); if (result) { *result = NULL; } @@ -3596,7 +3601,7 @@ dict_subscript(PyObject *self, PyObject *key) hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(self, key); return NULL; } ix = _Py_dict_lookup_threadsafe(mp, key, hash, &value); @@ -4515,7 +4520,7 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value) hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type((PyObject*)self, key); return NULL; } ix = _Py_dict_lookup_threadsafe(self, key, hash, &val); @@ -4547,7 +4552,7 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(d, key); if (result) { *result = NULL; } @@ -4990,7 +4995,7 @@ dict_contains(PyObject *op, PyObject *key) { Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(key); + dict_unhashable_type(op, key); return -1; } @@ -7066,7 +7071,7 @@ _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value) if (value == NULL) { Py_hash_t hash = _PyObject_HashFast(name); if (hash == -1) { - dict_unhashable_type(name); + dict_unhashable_type((PyObject*)dict, name); return -1; } return _PyDict_DelItem_KnownHash_LockHeld((PyObject *)dict, name, hash);