mirror of
				https://github.com/python/cpython.git
				synced 2025-10-26 11:14:33 +00:00 
			
		
		
		
	 fd5c414880
			
		
	
	
		fd5c414880
		
			
		
	
	
	
	
		
			
			The symbol table handling of PEP 572's assignment expressions does not correctly resolve the scope of some variables in the presence of global/nonlocal keywords in conjunction with comprehensions.
		
			
				
	
	
		
			519 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			519 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
import unittest

# Module-level name used by NamedExpressionScopeTest:
# test_named_expression_global_scope rebinds it through ``:=`` and resets it
# to None afterwards; test_named_expression_global_scope_no_global_keyword
# checks that it is still None.
GLOBAL_VAR = None
 | |
| 
 | |
class NamedExpressionInvalidTest(unittest.TestCase):
    """Check that rejected uses of ``:=`` raise SyntaxError.

    Every test hands a source string to exec() and matches the resulting
    SyntaxError message with assertRaisesRegex.  The source strings are the
    subject under test, so their exact spelling matters.
    """

    def test_named_expression_invalid_01(self):
        # Unparenthesized ``:=`` as a whole statement.
        code = """x := 0"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_02(self):
        # Unparenthesized ``:=`` on the right-hand side of ``=``.
        code = """x = y := 0"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_03(self):
        # Same as 01, with a call as the assigned value.
        code = """y := f(x)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_04(self):
        # Same as 02, with a call as the assigned value.
        code = """y0 = y1 := f(x)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_06(self):
        # A tuple is not a valid ``:=`` target.
        code = """((a, b) := (1, 2))"""

        # The pattern accepts both the older "named assignment" wording and
        # the "assignment expressions" wording used by later CPython releases.
        with self.assertRaisesRegex(
            SyntaxError,
            "cannot use (named assignment|assignment expressions) with tuple",
        ):
            exec(code, {}, {})

    def test_named_expression_invalid_07(self):
        # Unparenthesized ``:=`` in a parameter default.
        code = """def spam(a = b := 42): pass"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_08(self):
        # Unparenthesized ``:=`` in a parameter annotation.
        code = """def spam(a: b := 42 = 5): pass"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_09(self):
        # Unparenthesized ``:=`` as a keyword-argument value.
        code = """spam(a=b := 'c')"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_10(self):
        # Same as 09, with spaces around ``=`` and a call as the value.
        code = """spam(x = y := f(x))"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_11(self):
        # ``b := 2`` is a positional argument placed after a keyword one.
        code = """spam(a=1, b := 2)"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_12(self):
        # Parenthesizing the ``:=`` does not change the argument ordering error.
        code = """spam(a=1, (b := 2))"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_13(self):
        # NOTE: same source string and expectation as
        # test_named_expression_invalid_12.
        code = """spam(a=1, (b := 2))"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_14(self):
        # Unparenthesized ``:=`` as a lambda body, nested inside another ``:=``.
        code = """(x := lambda: y := 1)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_15(self):
        # Here the lambda itself ends up as the ``:=`` target.
        code = """(lambda: x := 1)"""

        # The pattern accepts both the older "named assignment" wording and
        # the "assignment expressions" wording used by later CPython releases.
        with self.assertRaisesRegex(
            SyntaxError,
            "cannot use (named assignment|assignment expressions) with lambda",
        ):
            exec(code, {}, {})

    def test_named_expression_invalid_16(self):
        # Unparenthesized ``:=`` as the iterable of a comprehension.
        code = "[i + 1 for i in i := [1,2]]"

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_17(self):
        # Unparenthesized tuple of ``:=`` expressions as a comprehension element.
        code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"

        # Later CPython releases replace the generic "invalid syntax" with a
        # hint about missing parentheses; either wording is accepted.
        with self.assertRaisesRegex(
            SyntaxError,
            "invalid syntax"
            "|did you forget parentheses around the comprehension target",
        ):
            exec(code, {}, {})

    def test_named_expression_invalid_in_class_body(self):
        # A comprehension containing ``:=`` placed directly in a class body.
        code = """class Foo():
            [(42, 1 + ((( j := i )))) for i in range(5)]
        """

        with self.assertRaisesRegex(SyntaxError,
            "assignment expression within a comprehension cannot be used in a class body"):
            exec(code, {}, {})

    def test_named_expression_invalid_rebinding_comprehension_iteration_variable(self):
        # Each case is (label, name expected in the message, source).  In all
        # of them ``:=`` targets a name that is also a ``for`` target of the
        # same or an enclosing comprehension.
        cases = [
            ("Local reuse", 'i', "[i := 0 for i in range(5)]"),
            ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
            ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
            ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
            ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
            ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
            ("Unreachable nested reuse", 'i',
                "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
        ]
        for case, target, code in cases:
            msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {})

    def test_named_expression_invalid_rebinding_comprehension_inner_loop(self):
        # Each case is (label, name expected in the message, source).  The
        # ``:=`` target appears first and a later ``for`` clause of the same
        # comprehension reuses the name.
        cases = [
            ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
            ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
        ]
        for case, target, code in cases:
            msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}) # Module scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {}) # Class scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(f"lambda: {code}", {}) # Function scope

    def test_named_expression_invalid_comprehension_iterable_expression(self):
        # Each case is (label, source) with ``:=`` somewhere inside the
        # iterable expression of a comprehension's ``for`` clause.
        cases = [
            ("Top level", "[i for i in (i := range(5))]"),
            ("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
            ("Inside list", "[i for i in [2, 3, i := range(5)]]"),
            ("Different name", "[i for i in (j := range(5))]"),
            ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
            ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
            ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
            ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
            ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
        ]
        msg = "assignment expression cannot be used in a comprehension iterable expression"
        for case, code in cases:
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}) # Module scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {}) # Class scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(f"lambda: {code}", {}) # Function scope
 | |
| 
 | |
| 
 | |
class NamedExpressionAssignmentTest(unittest.TestCase):
    """Check the values bound and produced by ``:=`` in valid positions.

    The statements are written directly in the test bodies (not exec'd), so
    their exact syntactic form is what is being exercised.
    """

    def test_named_expression_assignment_01(self):
        # Parenthesized ``:=`` as an expression statement.
        (a := 10)

        self.assertEqual(a, 10)

    def test_named_expression_assignment_02(self):
        # Rebinding a name to its own current value.
        a = 20
        (a := a)

        self.assertEqual(a, 20)

    def test_named_expression_assignment_03(self):
        # The value is the whole ``1 + 2`` expression.
        (total := 1 + 2)

        self.assertEqual(total, 3)

    def test_named_expression_assignment_04(self):
        # A parenthesized tuple as the assigned value.
        (info := (1, 2, 3))

        self.assertEqual(info, (1, 2, 3))

    def test_named_expression_assignment_05(self):
        # Without inner parentheses only ``1`` is bound to x; the ``, 2``
        # belongs to the surrounding tuple display.
        (x := 1, 2)

        self.assertEqual(x, 1)

    def test_named_expression_assignment_06(self):
        # Nested ``:=`` expressions all receive the innermost value.
        (z := (y := (x := 0)))

        self.assertEqual(x, 0)
        self.assertEqual(y, 0)
        self.assertEqual(z, 0)

    def test_named_expression_assignment_07(self):
        (loc := (1, 2))

        self.assertEqual(loc, (1, 2))

    def test_named_expression_assignment_08(self):
        # Unparenthesized ``:=`` as an ``if`` condition.
        if spam := "eggs":
            self.assertEqual(spam, "eggs")
        else: self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_09(self):
        # ``:=`` as the right operand of ``and``.
        if True and (spam := True):
            self.assertTrue(spam)
        else: self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_10(self):
        # The value of the ``:=`` expression feeds a comparison.
        if (match := 10) == 10:
            pass
        else: self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_11(self):
        # ``y`` is bound in the comprehension filter and read in the element.
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])

    def test_named_expression_assignment_12(self):
        # ``y`` is bound in one element of the inner list and read in the next.
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])

    def test_named_expression_assignment_13(self):
        # Unparenthesized ``:=`` as the sole positional argument of a call.
        length = len(lines := [1, 2])

        self.assertEqual(length, 2)
        self.assertEqual(lines, [1,2])

    def test_named_expression_assignment_14(self):
        """
        Where all variables are positive integers, and a is at least as large
        as the n'th root of x, this algorithm returns the floor of the n'th
        root of x (and roughly doubling the number of accurate bits per
        iteration):
        """
        a = 9
        n = 2
        x = 3

        # ``d`` is rebound on every evaluation of the loop condition.
        while a > (d := x // a**(n-1)):
            a = ((n-1)*a + d) // n

        self.assertEqual(a, 1)

    def test_named_expression_assignment_15(self):
        # The binding happens even though the loop body never executes.
        while a := False:
            pass  # This will not run

        self.assertEqual(a, False)

    def test_named_expression_assignment_16(self):
        # Several ``:=`` in the key and value of a dict comprehension; the
        # expected result shows the key being evaluated before the value.
        a, b = 1, 2
        fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
        self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})
 | |
| 
 | |
| 
 | |
class NamedExpressionScopeTest(unittest.TestCase):
    """Check which scope a ``:=`` target is bound in.

    Most tests bind a name inside a comprehension or call and then read it
    from the enclosing test method; the last group combines ``:=`` in a
    comprehension with ``global`` / ``nonlocal`` declarations.  Two tests use
    the module-level name GLOBAL_VAR.
    """

    def test_named_expression_scope_01(self):
        # ``a`` is bound inside spam() (which is never called) and read at
        # the top level of the exec'd code.
        code = """def spam():
    (a := 5)
print(a)"""

        with self.assertRaisesRegex(NameError, "name 'a' is not defined"):
            exec(code, {}, {})

    def test_named_expression_scope_02(self):
        # The comprehension both reads and rebinds the method-local ``total``.
        total = 0
        partial_sums = [total := total + v for v in range(5)]

        self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
        self.assertEqual(total, 10)

    def test_named_expression_scope_03(self):
        # any() stops at the first match, so lastNum keeps the value 1.
        containsOne = any((lastNum := num) == 1 for num in [1, 2, 3])

        self.assertTrue(containsOne)
        self.assertEqual(lastNum, 1)

    def test_named_expression_scope_04(self):
        # ``y`` remains visible after the comprehension, holding the value
        # from the last iteration.
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(y, 4)

    def test_named_expression_scope_05(self):
        # Same as 04, with the binding in the comprehension filter.
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
        self.assertEqual(y, 3)

    def test_named_expression_scope_06(self):
        # Binding from inside a nested comprehension reaches the method scope.
        res = [[spam := i for i in range(3)] for j in range(2)]

        self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
        self.assertEqual(spam, 2)

    def test_named_expression_scope_07(self):
        # Binding made in a call argument whose result is discarded.
        len(lines := [1, 2])

        self.assertEqual(lines, [1, 2])

    def test_named_expression_scope_08(self):
        # Two different targets bound in nested call arguments.
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(b := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)
        self.assertEqual(b, 1)

    def test_named_expression_scope_09(self):
        # Same target bound twice in one expression; the outer binding wins.
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(a := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)

    def test_named_expression_scope_10(self):
        # Targets bound at two comprehension nesting levels.
        res = [b := [a := 1 for i in range(2)] for j in range(2)]

        self.assertEqual(res, [[1, 1], [1, 1]])
        self.assertEqual(a, 1)
        self.assertEqual(b, [1, 1])

    def test_named_expression_scope_11(self):
        # Unparenthesized ``:=`` as the comprehension element.
        res = [j := i for i in range(5)]

        self.assertEqual(res, [0, 1, 2, 3, 4])
        self.assertEqual(j, 4)

    def test_named_expression_scope_17(self):
        # Running sum: ``b`` is read and rebound on each iteration.
        b = 0
        res = [b := i + b for i in range(5)]

        self.assertEqual(res, [0, 1, 3, 6, 10])
        self.assertEqual(b, 10)

    def test_named_expression_scope_18(self):
        # Unparenthesized ``:=`` as a positional argument.
        def spam(a):
            return a

        res = spam(b := 2)

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_19(self):
        # Parenthesized ``:=`` as a positional argument.
        def spam(a):
            return a

        res = spam((b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_20(self):
        # Parenthesized ``:=`` as a keyword-argument value.
        def spam(a):
            return a

        res = spam(a=(b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_21(self):
        # Unparenthesized ``:=`` positional argument followed by a keyword.
        def spam(a, b):
            return a + b

        res = spam(c := 2, b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_22(self):
        # Parenthesized variant of 21.
        def spam(a, b):
            return a + b

        res = spam((c := 2), b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_23(self):
        # ``:=`` inside the first of two keyword arguments.
        def spam(a, b):
            return a + b

        res = spam(b=(c := 2), a=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_24(self):
        # ``:=`` honours a ``nonlocal`` declaration in the same function.
        a = 10
        def spam():
            nonlocal a
            (a := 20)
        spam()

        self.assertEqual(a, 20)

    def test_named_expression_scope_25(self):
        # ``:=`` honours a ``global`` declaration: the result is looked up in
        # ``ns``, the globals mapping passed to exec().
        ns = {}
        code = """a = 10
def spam():
    global a
    (a := 20)
spam()"""

        exec(code, ns, {})

        self.assertEqual(ns["a"], 20)

    def test_named_expression_variable_reuse_in_comprehensions(self):
        # The compiler is expected to raise syntax error for comprehension
        # iteration variables, but should be fine with rebinding of other
        # names (e.g. globals, nonlocals, other assignment expressions)

        # The cases are all defined to produce the same expected result
        # Each comprehension is checked at both function scope and module scope
        rebinding = "[x := i for i in range(3) if (x := i) or not x]"
        filter_ref = "[x := i for i in range(3) if x or not x]"
        body_ref = "[x for i in range(3) if (x := i) or not x]"
        nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
        cases = [
            ("Rebind global", f"x = 1; result = {rebinding}"),
            ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
            ("Filter global", f"x = 1; result = {filter_ref}"),
            ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
            ("Body global", f"x = 1; result = {body_ref}"),
            ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
            ("Nested global", f"x = 1; result = {nested_ref}"),
            ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
        ]
        for case, code in cases:
            with self.subTest(case=case):
                # Fresh namespace per case so results cannot leak between them.
                ns = {}
                exec(code, ns)
                self.assertEqual(ns["x"], 2)
                self.assertEqual(ns["result"], [0, 1, 2])

    def test_named_expression_global_scope(self):
        # With ``global`` declared in f(), a ``:=`` inside a comprehension
        # rebinds the module-level GLOBAL_VAR.
        sentinel = object()
        global GLOBAL_VAR
        def f():
            global GLOBAL_VAR
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        try:
            f()
            self.assertEqual(GLOBAL_VAR, sentinel)
        finally:
            # Restore the module-level value for the other tests.
            GLOBAL_VAR = None

    def test_named_expression_global_scope_no_global_keyword(self):
        # Without ``global``, the ``:=`` rebinds f()'s local GLOBAL_VAR and
        # the module-level one stays None.
        sentinel = object()
        def f():
            GLOBAL_VAR = None
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        f()
        self.assertEqual(GLOBAL_VAR, None)

    def test_named_expression_nonlocal_scope(self):
        # With ``nonlocal`` declared in g(), a ``:=`` inside a comprehension
        # rebinds f()'s variable.
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                nonlocal nonlocal_var
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, sentinel)
        f()

    def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self):
        # Without ``nonlocal``, the ``:=`` binds a name local to g() and
        # f()'s variable stays None.
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, None)
        f()
 | |
| 
 | |
| 
 | |
# Run every test case defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()
 |