diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py index 2e6e6eaad06..f664e8ac53f 100644 --- a/Lib/test/test_external_inspection.py +++ b/Lib/test/test_external_inspection.py @@ -7,6 +7,7 @@ import socket import threading import time +from contextlib import contextmanager from asyncio import staggered, taskgroups, base_events, tasks from unittest.mock import ANY from test.support import ( @@ -27,9 +28,12 @@ PROFILING_MODE_ALL = 3 # Thread status flags -THREAD_STATUS_HAS_GIL = (1 << 0) -THREAD_STATUS_ON_CPU = (1 << 1) -THREAD_STATUS_UNKNOWN = (1 << 2) +THREAD_STATUS_HAS_GIL = 1 << 0 +THREAD_STATUS_ON_CPU = 1 << 1 +THREAD_STATUS_UNKNOWN = 1 << 2 + +# Maximum number of retry attempts for operations that may fail transiently +MAX_TRIES = 10 try: from concurrent import interpreters @@ -48,12 +52,149 @@ ) +# ============================================================================ +# Module-level helper functions +# ============================================================================ + + def _make_test_script(script_dir, script_basename, source): to_return = make_script(script_dir, script_basename, source) importlib.invalidate_caches() return to_return +def _create_server_socket(port, backlog=1): + """Create and configure a server socket for test communication.""" + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind(("localhost", port)) + server_socket.settimeout(SHORT_TIMEOUT) + server_socket.listen(backlog) + return server_socket + + +def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT): + """ + Wait for expected signal(s) from a socket with proper timeout and EOF handling. + + Args: + sock: Connected socket to read from + expected_signals: Single bytes object or list of bytes objects to wait for + timeout: Socket timeout in seconds + + Returns: + bytes: Complete accumulated response buffer + + Raises: + RuntimeError: If connection closed before signal received or timeout + """ + if isinstance(expected_signals, bytes): + expected_signals = [expected_signals] + + sock.settimeout(timeout) + buffer = b"" + + while True: + # Check if all expected signals are in buffer + if all(sig in buffer for sig in expected_signals): + return buffer + + try: + chunk = sock.recv(4096) + if not chunk: + # EOF - connection closed + raise RuntimeError( + f"Connection closed before receiving expected signals. " + f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" + ) + buffer += chunk + except socket.timeout: + raise RuntimeError( + f"Timeout waiting for signals. " + f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" + ) + + +def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT): + """ + Wait for N occurrences of a signal pattern. + + Args: + sock: Connected socket to read from + signal_pattern: bytes pattern to count (e.g., b"ready") + count: Number of occurrences expected + timeout: Socket timeout in seconds + + Returns: + bytes: Complete accumulated response buffer + + Raises: + RuntimeError: If connection closed or timeout before receiving all signals + """ + sock.settimeout(timeout) + buffer = b"" + found_count = 0 + + while found_count < count: + try: + chunk = sock.recv(4096) + if not chunk: + raise RuntimeError( + f"Connection closed after {found_count}/{count} signals. " + f"Last 200 bytes: {buffer[-200:]!r}" + ) + buffer += chunk + # Count occurrences in entire buffer + found_count = buffer.count(signal_pattern) + except socket.timeout: + raise RuntimeError( + f"Timeout waiting for {count} signals (found {found_count}). " + f"Last 200 bytes: {buffer[-200:]!r}" + ) + + return buffer + + +@contextmanager +def _managed_subprocess(args, timeout=SHORT_TIMEOUT): + """ + Context manager for subprocess lifecycle management. + + Ensures process is properly terminated and cleaned up even on exceptions. + Uses graceful termination first, then forceful kill if needed. + """ + p = subprocess.Popen(args) + try: + yield p + finally: + try: + p.terminate() + try: + p.wait(timeout=timeout) + except subprocess.TimeoutExpired: + p.kill() + try: + p.wait(timeout=timeout) + except subprocess.TimeoutExpired: + pass # Process refuses to die, nothing more we can do + except OSError: + pass # Process already dead + + +def _cleanup_sockets(*sockets): + """Safely close multiple sockets, ignoring errors.""" + for sock in sockets: + if sock is not None: + try: + sock.close() + except OSError: + pass + + +# ============================================================================ +# Decorators and skip conditions +# ============================================================================ + skip_if_not_supported = unittest.skipIf( ( sys.platform != "darwin" @@ -66,40 +207,220 @@ def _make_test_script(script_dir, script_basename, source): def requires_subinterpreters(meth): """Decorator to skip a test if subinterpreters are not supported.""" - return unittest.skipIf(interpreters is None, - 'subinterpreters required')(meth) + return unittest.skipIf(interpreters is None, "subinterpreters required")( + meth + ) + + +# ============================================================================ +# Simple wrapper functions for RemoteUnwinder +# ============================================================================ + +# Errors that can occur transiently when reading process memory without synchronization +RETRIABLE_ERRORS = ( + "Task list appears corrupted", + "Invalid linked list structure reading remote memory", + "Unknown error reading memory", + "Unhandled frame owner", + "Failed to parse initial frame", + "Failed to process frame chain", + "Failed to unwind stack", +) + + +def _is_retriable_error(exc): + """Check if an exception is a transient error that should be retried.""" + msg = str(exc) + return any(msg.startswith(err) or err in msg for err in RETRIABLE_ERRORS) def get_stack_trace(pid): - unwinder = RemoteUnwinder(pid, all_threads=True, debug=True) - return unwinder.get_stack_trace() + for _ in busy_retry(SHORT_TIMEOUT): + try: + unwinder = RemoteUnwinder(pid, all_threads=True, debug=True) + return unwinder.get_stack_trace() + except RuntimeError as e: + if _is_retriable_error(e): + continue + raise + raise RuntimeError("Failed to get stack trace after retries") def get_async_stack_trace(pid): - unwinder = RemoteUnwinder(pid, debug=True) - return unwinder.get_async_stack_trace() + for _ in busy_retry(SHORT_TIMEOUT): + try: + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_async_stack_trace() + except RuntimeError as e: + if _is_retriable_error(e): + continue + raise + raise RuntimeError("Failed to get async stack trace after retries") def get_all_awaited_by(pid): - unwinder = RemoteUnwinder(pid, debug=True) - return unwinder.get_all_awaited_by() + for _ in busy_retry(SHORT_TIMEOUT): + try: + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_all_awaited_by() + except RuntimeError as e: + if _is_retriable_error(e): + continue + raise + raise RuntimeError("Failed to get all awaited_by after retries") -class TestGetStackTrace(unittest.TestCase): +# ============================================================================ +# Base test class with shared infrastructure +# ============================================================================ + + +class RemoteInspectionTestBase(unittest.TestCase): + """Base class for remote inspection tests with common helpers.""" + maxDiff = None + def _run_script_and_get_trace( + self, + script, + trace_func, + wait_for_signals=None, + port=None, + backlog=1, + ): + """ + Common pattern: run a script, wait for signals, get trace. + + Args: + script: Script content (will be formatted with port if {port} present) + trace_func: Function to call with pid to get trace (e.g., get_stack_trace) + wait_for_signals: Signal(s) to wait for before getting trace + port: Port to use (auto-selected if None) + backlog: Socket listen backlog + + Returns: + tuple: (trace_result, script_name) + """ + if port is None: + port = find_unused_port() + + # Format script with port if needed + if "{port}" in script or "{{port}}" in script: + script = script.replace("{{port}}", "{port}").format(port=port) + + with os_helper.temp_dir() as work_dir: + script_dir = os.path.join(work_dir, "script_pkg") + os.mkdir(script_dir) + + server_socket = _create_server_socket(port, backlog) + script_name = _make_test_script(script_dir, "script", script) + client_socket = None + + try: + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None + + if wait_for_signals: + _wait_for_signal(client_socket, wait_for_signals) + + try: + trace = trace_func(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + return trace, script_name + finally: + _cleanup_sockets(client_socket, server_socket) + + def _find_frame_in_trace(self, stack_trace, predicate): + """ + Find a frame matching predicate in stack trace. + + Args: + stack_trace: List of InterpreterInfo objects + predicate: Function(frame) -> bool + + Returns: + FrameInfo or None + """ + for interpreter_info in stack_trace: + for thread_info in interpreter_info.threads: + for frame in thread_info.frame_info: + if predicate(frame): + return frame + return None + + def _find_thread_by_id(self, stack_trace, thread_id): + """Find a thread by its native thread ID.""" + for interpreter_info in stack_trace: + for thread_info in interpreter_info.threads: + if thread_info.thread_id == thread_id: + return thread_info + return None + + def _find_thread_with_frame(self, stack_trace, frame_predicate): + """Find a thread containing a frame matching predicate.""" + for interpreter_info in stack_trace: + for thread_info in interpreter_info.threads: + for frame in thread_info.frame_info: + if frame_predicate(frame): + return thread_info + return None + + def _get_thread_statuses(self, stack_trace): + """Extract thread_id -> status mapping from stack trace.""" + statuses = {} + for interpreter_info in stack_trace: + for thread_info in interpreter_info.threads: + statuses[thread_info.thread_id] = thread_info.status + return statuses + + def _get_task_id_map(self, stack_trace): + """Create task_id -> task mapping from async stack trace.""" + return {task.task_id: task for task in stack_trace[0].awaited_by} + + def _get_awaited_by_relationships(self, stack_trace): + """Extract task name to awaited_by set mapping.""" + id_to_task = self._get_task_id_map(stack_trace) + return { + task.task_name: set( + id_to_task[awaited.task_name].task_name + for awaited in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + + def _extract_coroutine_stacks(self, stack_trace): + """Extract and format coroutine stacks from tasks.""" + return { + task.task_name: sorted( + tuple(tuple(frame) for frame in coro.call_stack) + for coro in task.coroutine_stack + ) + for task in stack_trace[0].awaited_by + } + + +# ============================================================================ +# Test classes +# ============================================================================ + + +class TestGetStackTrace(RemoteInspectionTestBase): @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_remote_stack_trace(self): - # Spawn a process with some realistic Python code port = find_unused_port() script = textwrap.dedent( f"""\ import time, sys, socket, threading - # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) @@ -112,80 +433,78 @@ def baz(): foo() def foo(): - sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number + sock.sendall(b"ready:thread\\n"); time.sleep(10_000) t = threading.Thread(target=bar) t.start() - sock.sendall(b"ready:main\\n"); t.join() # same line number + sock.sendall(b"ready:main\\n"); t.join() """ ) - stack_trace = None + with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) - + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = b"" - while ( - b"ready:main" not in response - or b"ready:thread" not in response - ): - response += client_socket.recv(1024) - stack_trace = get_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None + + _wait_for_signal( + client_socket, [b"ready:main", b"ready:thread"] + ) + + try: + stack_trace = get_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + + thread_expected_stack_trace = [ + FrameInfo([script_name, 15, "foo"]), + FrameInfo([script_name, 12, "baz"]), + FrameInfo([script_name, 9, "bar"]), + FrameInfo([threading.__file__, ANY, "Thread.run"]), + FrameInfo( + [ + threading.__file__, + ANY, + "Thread._bootstrap_inner", + ] + ), + FrameInfo( + [threading.__file__, ANY, "Thread._bootstrap"] + ), + ] + + # Find expected thread stack + found_thread = self._find_thread_with_frame( + stack_trace, + lambda f: f.funcname == "foo" and f.lineno == 15, + ) + self.assertIsNotNone( + found_thread, "Expected thread stack trace not found" + ) + self.assertEqual( + found_thread.frame_info, thread_expected_stack_trace + ) + + # Check main thread + main_frame = FrameInfo([script_name, 19, ""]) + found_main = self._find_frame_in_trace( + stack_trace, lambda f: f == main_frame + ) + self.assertIsNotNone( + found_main, "Main thread stack trace not found" + ) finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) - - thread_expected_stack_trace = [ - FrameInfo([script_name, 15, "foo"]), - FrameInfo([script_name, 12, "baz"]), - FrameInfo([script_name, 9, "bar"]), - FrameInfo([threading.__file__, ANY, "Thread.run"]), - FrameInfo([threading.__file__, ANY, "Thread._bootstrap_inner"]), - FrameInfo([threading.__file__, ANY, "Thread._bootstrap"]), - ] - # Is possible that there are more threads, so we check that the - # expected stack traces are in the result (looking at you Windows!) - found_expected_stack = False - for interpreter_info in stack_trace: - for thread_info in interpreter_info.threads: - if thread_info.frame_info == thread_expected_stack_trace: - found_expected_stack = True - break - if found_expected_stack: - break - self.assertTrue(found_expected_stack, "Expected thread stack trace not found") - - # Check that the main thread stack trace is in the result - frame = FrameInfo([script_name, 19, ""]) - main_thread_found = False - for interpreter_info in stack_trace: - for thread_info in interpreter_info.threads: - if frame in thread_info.frame_info: - main_thread_found = True - break - if main_thread_found: - break - self.assertTrue(main_thread_found, "Main thread stack trace not found in result") + _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -193,7 +512,6 @@ def foo(): "Test only runs on Linux with process_vm_readv support", ) def test_async_remote_stack_trace(self): - # Spawn a process with some realistic Python code port = find_unused_port() script = textwrap.dedent( f"""\ @@ -201,12 +519,12 @@ def test_async_remote_stack_trace(self): import time import sys import socket - # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def c5(): - sock.sendall(b"ready"); time.sleep(10_000) # same line number + sock.sendall(b"ready"); time.sleep(10_000) async def c4(): await asyncio.sleep(0) @@ -237,7 +555,7 @@ def new_eager_loop(): asyncio.run(main(), loop_factory={{TASK_FACTORY}}) """ ) - stack_trace = None + for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop": with ( self.subTest(task_factory_variant=task_factory_variant), @@ -245,195 +563,203 @@ def new_eager_loop(): ): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - server_socket = socket.socket( - socket.AF_INET, socket.SOCK_STREAM - ) - server_socket.setsockopt( - socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 - ) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) + + server_socket = _create_server_socket(port) script_name = _make_test_script( script_dir, "script", script.format(TASK_FACTORY=task_factory_variant), ) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") - stack_trace = get_async_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) - finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + with _managed_subprocess( + [sys.executable, script_name] + ) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - # First check all the tasks are present - tasks_names = [ - task.task_name for task in stack_trace[0].awaited_by - ] - for task_name in ["c2_root", "sub_main_1", "sub_main_2"]: - self.assertIn(task_name, tasks_names) + response = _wait_for_signal(client_socket, b"ready") + self.assertIn(b"ready", response) - # Now ensure that the awaited_by_relationships are correct - id_to_task = { - task.task_id: task for task in stack_trace[0].awaited_by - } - task_name_to_awaited_by = { - task.task_name: set( - id_to_task[awaited.task_name].task_name - for awaited in task.awaited_by - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - task_name_to_awaited_by, - { - "c2_root": {"Task-1", "sub_main_1", "sub_main_2"}, - "Task-1": set(), - "sub_main_1": {"Task-1"}, - "sub_main_2": {"Task-1"}, - }, - ) - - # Now ensure that the coroutine stacks are correct - coroutine_stacks = { - task.task_name: sorted( - tuple(tuple(frame) for frame in coro.call_stack) - for coro in task.coroutine_stack - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - coroutine_stacks, - { - "Task-1": [ - ( - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] - ), - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] - ), - tuple([script_name, 26, "main"]), + try: + stack_trace = get_async_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" ) - ], - "c2_root": [ - ( - tuple([script_name, 10, "c5"]), - tuple([script_name, 14, "c4"]), - tuple([script_name, 17, "c3"]), - tuple([script_name, 20, "c2"]), - ) - ], - "sub_main_1": [(tuple([script_name, 23, "c1"]),)], - "sub_main_2": [(tuple([script_name, 23, "c1"]),)], - }, - ) - # Now ensure the coroutine stacks for the awaited_by relationships are correct. - awaited_by_coroutine_stacks = { - task.task_name: sorted( - ( - id_to_task[coro.task_name].task_name, - tuple(tuple(frame) for frame in coro.call_stack), + # Check all tasks are present + tasks_names = [ + task.task_name + for task in stack_trace[0].awaited_by + ] + for task_name in [ + "c2_root", + "sub_main_1", + "sub_main_2", + ]: + self.assertIn(task_name, tasks_names) + + # Check awaited_by relationships + relationships = self._get_awaited_by_relationships( + stack_trace ) - for coro in task.awaited_by - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - awaited_by_coroutine_stacks, - { - "Task-1": [], - "c2_root": [ - ( - "Task-1", + self.assertEqual( + relationships, + { + "c2_root": { + "Task-1", + "sub_main_1", + "sub_main_2", + }, + "Task-1": set(), + "sub_main_1": {"Task-1"}, + "sub_main_2": {"Task-1"}, + }, + ) + + # Check coroutine stacks + coroutine_stacks = self._extract_coroutine_stacks( + stack_trace + ) + self.assertEqual( + coroutine_stacks, + { + "Task-1": [ + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ) + ], + "c2_root": [ + ( + tuple([script_name, 10, "c5"]), + tuple([script_name, 14, "c4"]), + tuple([script_name, 17, "c3"]), + tuple([script_name, 20, "c2"]), + ) + ], + "sub_main_1": [ + (tuple([script_name, 23, "c1"]),) + ], + "sub_main_2": [ + (tuple([script_name, 23, "c1"]),) + ], + }, + ) + + # Check awaited_by coroutine stacks + id_to_task = self._get_task_id_map(stack_trace) + awaited_by_coroutine_stacks = { + task.task_name: sorted( ( + id_to_task[coro.task_name].task_name, tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] + tuple(frame) + for frame in coro.call_stack ), - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] - ), - tuple([script_name, 26, "main"]), - ), - ), - ("sub_main_1", (tuple([script_name, 23, "c1"]),)), - ("sub_main_2", (tuple([script_name, 23, "c1"]),)), - ], - "sub_main_1": [ - ( - "Task-1", - ( - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] - ), - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] - ), - tuple([script_name, 26, "main"]), - ), + ) + for coro in task.awaited_by ) - ], - "sub_main_2": [ - ( - "Task-1", - ( - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] + for task in stack_trace[0].awaited_by + } + self.assertEqual( + awaited_by_coroutine_stacks, + { + "Task-1": [], + "c2_root": [ + ( + "Task-1", + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ), ), - tuple( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] + ( + "sub_main_1", + (tuple([script_name, 23, "c1"]),), ), - tuple([script_name, 26, "main"]), - ), - ) - ], - }, - ) + ( + "sub_main_2", + (tuple([script_name, 23, "c1"]),), + ), + ], + "sub_main_1": [ + ( + "Task-1", + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ), + ) + ], + "sub_main_2": [ + ( + "Task-1", + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ), + ) + ], + }, + ) + finally: + _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -441,7 +767,6 @@ def new_eager_loop(): "Test only runs on Linux with process_vm_readv support", ) def test_asyncgen_remote_stack_trace(self): - # Spawn a process with some realistic Python code port = find_unused_port() script = textwrap.dedent( f"""\ @@ -449,12 +774,12 @@ def test_asyncgen_remote_stack_trace(self): import time import sys import socket - # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) async def gen_nested_call(): - sock.sendall(b"ready"); time.sleep(10_000) # same line number + sock.sendall(b"ready"); time.sleep(10_000) async def gen(): for num in range(2): @@ -469,59 +794,56 @@ async def main(): asyncio.run(main()) """ ) - stack_trace = None + with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) + + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") - stack_trace = get_async_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) - finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - # For this simple asyncgen test, we only expect one task with the full coroutine stack - self.assertEqual(len(stack_trace[0].awaited_by), 1) - task = stack_trace[0].awaited_by[0] - self.assertEqual(task.task_name, "Task-1") + response = _wait_for_signal(client_socket, b"ready") + self.assertIn(b"ready", response) - # Check the coroutine stack - based on actual output, only shows main - coroutine_stack = sorted( - tuple(tuple(frame) for frame in coro.call_stack) - for coro in task.coroutine_stack - ) - self.assertEqual( - coroutine_stack, - [ - ( - tuple([script_name, 10, "gen_nested_call"]), - tuple([script_name, 16, "gen"]), - tuple([script_name, 19, "main"]), + try: + stack_trace = get_async_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + + # For this simple asyncgen test, we only expect one task + self.assertEqual(len(stack_trace[0].awaited_by), 1) + task = stack_trace[0].awaited_by[0] + self.assertEqual(task.task_name, "Task-1") + + # Check the coroutine stack + coroutine_stack = sorted( + tuple(tuple(frame) for frame in coro.call_stack) + for coro in task.coroutine_stack + ) + self.assertEqual( + coroutine_stack, + [ + ( + tuple([script_name, 10, "gen_nested_call"]), + tuple([script_name, 16, "gen"]), + tuple([script_name, 19, "main"]), + ) + ], ) - ], - ) - # No awaited_by relationships expected for this simple case - self.assertEqual(task.awaited_by, []) + # No awaited_by relationships expected + self.assertEqual(task.awaited_by, []) + finally: + _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -529,7 +851,6 @@ async def main(): "Test only runs on Linux with process_vm_readv support", ) def test_async_gather_remote_stack_trace(self): - # Spawn a process with some realistic Python code port = find_unused_port() script = textwrap.dedent( f"""\ @@ -537,13 +858,13 @@ def test_async_gather_remote_stack_trace(self): import time import sys import socket - # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready"); time.sleep(10_000) # same line number + sock.sendall(b"ready"); time.sleep(10_000) async def c1(): await asyncio.sleep(0) @@ -558,103 +879,92 @@ async def main(): asyncio.run(main()) """ ) - stack_trace = None + with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) + + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") - stack_trace = get_async_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) - finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - # First check all the tasks are present - tasks_names = [ - task.task_name for task in stack_trace[0].awaited_by - ] - for task_name in ["Task-1", "Task-2"]: - self.assertIn(task_name, tasks_names) + response = _wait_for_signal(client_socket, b"ready") + self.assertIn(b"ready", response) - # Now ensure that the awaited_by_relationships are correct - id_to_task = { - task.task_id: task for task in stack_trace[0].awaited_by - } - task_name_to_awaited_by = { - task.task_name: set( - id_to_task[awaited.task_name].task_name - for awaited in task.awaited_by - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - task_name_to_awaited_by, - { - "Task-1": set(), - "Task-2": {"Task-1"}, - }, - ) - - # Now ensure that the coroutine stacks are correct - coroutine_stacks = { - task.task_name: sorted( - tuple(tuple(frame) for frame in coro.call_stack) - for coro in task.coroutine_stack - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - coroutine_stacks, - { - "Task-1": [(tuple([script_name, 21, "main"]),)], - "Task-2": [ - ( - tuple([script_name, 11, "deep"]), - tuple([script_name, 15, "c1"]), + try: + stack_trace = get_async_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" ) - ], - }, - ) - # Now ensure the coroutine stacks for the awaited_by relationships are correct. - awaited_by_coroutine_stacks = { - task.task_name: sorted( - ( - id_to_task[coro.task_name].task_name, - tuple(tuple(frame) for frame in coro.call_stack), + # Check all tasks are present + tasks_names = [ + task.task_name for task in stack_trace[0].awaited_by + ] + for task_name in ["Task-1", "Task-2"]: + self.assertIn(task_name, tasks_names) + + # Check awaited_by relationships + relationships = self._get_awaited_by_relationships( + stack_trace ) - for coro in task.awaited_by - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - awaited_by_coroutine_stacks, - { - "Task-1": [], - "Task-2": [ - ("Task-1", (tuple([script_name, 21, "main"]),)) - ], - }, - ) + self.assertEqual( + relationships, + { + "Task-1": set(), + "Task-2": {"Task-1"}, + }, + ) + + # Check coroutine stacks + coroutine_stacks = self._extract_coroutine_stacks( + stack_trace + ) + self.assertEqual( + coroutine_stacks, + { + "Task-1": [(tuple([script_name, 21, "main"]),)], + "Task-2": [ + ( + tuple([script_name, 11, "deep"]), + tuple([script_name, 15, "c1"]), + ) + ], + }, + ) + + # Check awaited_by coroutine stacks + id_to_task = self._get_task_id_map(stack_trace) + awaited_by_coroutine_stacks = { + task.task_name: sorted( + ( + id_to_task[coro.task_name].task_name, + tuple( + tuple(frame) for frame in coro.call_stack + ), + ) + for coro in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + awaited_by_coroutine_stacks, + { + "Task-1": [], + "Task-2": [ + ("Task-1", (tuple([script_name, 21, "main"]),)) + ], + }, + ) + finally: + _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -662,7 +972,6 @@ async def main(): "Test only runs on Linux with process_vm_readv support", ) def test_async_staggered_race_remote_stack_trace(self): - # Spawn a process with some realistic Python code port = find_unused_port() script = textwrap.dedent( f"""\ @@ -670,13 +979,13 @@ def test_async_staggered_race_remote_stack_trace(self): import time import sys import socket - # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready"); time.sleep(10_000) # same line number + sock.sendall(b"ready"); time.sleep(10_000) async def c1(): await asyncio.sleep(0) @@ -694,123 +1003,122 @@ async def main(): asyncio.run(main()) """ ) - stack_trace = None + with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) + + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") - stack_trace = get_async_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) - finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - # First check all the tasks are present - tasks_names = [ - task.task_name for task in stack_trace[0].awaited_by - ] - for task_name in ["Task-1", "Task-2"]: - self.assertIn(task_name, tasks_names) + response = _wait_for_signal(client_socket, b"ready") + self.assertIn(b"ready", response) - # Now ensure that the awaited_by_relationships are correct - id_to_task = { - task.task_id: task for task in stack_trace[0].awaited_by - } - task_name_to_awaited_by = { - task.task_name: set( - id_to_task[awaited.task_name].task_name - for awaited in task.awaited_by - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - task_name_to_awaited_by, - { - "Task-1": set(), - "Task-2": {"Task-1"}, - }, - ) - - # Now ensure that the coroutine stacks are correct - coroutine_stacks = { - task.task_name: sorted( - tuple(tuple(frame) for frame in coro.call_stack) - for coro in task.coroutine_stack - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - coroutine_stacks, - { - "Task-1": [ - ( - tuple([staggered.__file__, ANY, "staggered_race"]), - tuple([script_name, 21, "main"]), + try: + stack_trace = get_async_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" ) - ], - "Task-2": [ - ( - tuple([script_name, 11, "deep"]), - tuple([script_name, 15, "c1"]), - tuple( - [ - staggered.__file__, - ANY, - "staggered_race..run_one_coro", - ] - ), - ) - ], - }, - ) - # Now ensure the coroutine stacks for the awaited_by relationships are correct. - awaited_by_coroutine_stacks = { - task.task_name: sorted( - ( - id_to_task[coro.task_name].task_name, - tuple(tuple(frame) for frame in coro.call_stack), + # Check all tasks are present + tasks_names = [ + task.task_name for task in stack_trace[0].awaited_by + ] + for task_name in ["Task-1", "Task-2"]: + self.assertIn(task_name, tasks_names) + + # Check awaited_by relationships + relationships = self._get_awaited_by_relationships( + stack_trace ) - for coro in task.awaited_by - ) - for task in stack_trace[0].awaited_by - } - self.assertEqual( - awaited_by_coroutine_stacks, - { - "Task-1": [], - "Task-2": [ - ( - "Task-1", + self.assertEqual( + relationships, + { + "Task-1": set(), + "Task-2": {"Task-1"}, + }, + ) + + # Check coroutine stacks + coroutine_stacks = self._extract_coroutine_stacks( + stack_trace + ) + self.assertEqual( + coroutine_stacks, + { + "Task-1": [ + ( + tuple( + [ + staggered.__file__, + ANY, + "staggered_race", + ] + ), + tuple([script_name, 21, "main"]), + ) + ], + "Task-2": [ + ( + tuple([script_name, 11, "deep"]), + tuple([script_name, 15, "c1"]), + tuple( + [ + staggered.__file__, + ANY, + "staggered_race..run_one_coro", + ] + ), + ) + ], + }, + ) + + # Check awaited_by coroutine stacks + id_to_task = self._get_task_id_map(stack_trace) + awaited_by_coroutine_stacks = { + task.task_name: sorted( ( + id_to_task[coro.task_name].task_name, tuple( - [staggered.__file__, ANY, "staggered_race"] + tuple(frame) for frame in coro.call_stack ), - tuple([script_name, 21, "main"]), - ), + ) + for coro in task.awaited_by ) - ], - }, - ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + awaited_by_coroutine_stacks, + { + "Task-1": [], + "Task-2": [ + ( + "Task-1", + ( + tuple( + [ + staggered.__file__, + ANY, + "staggered_race", + ] + ), + tuple([script_name, 21, "main"]), + ), + ) + ], + }, + ) + finally: + _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -818,6 +1126,10 @@ async def main(): "Test only runs on Linux with process_vm_readv support", ) def test_async_global_awaited_by(self): + # Reduced from 1000 to 100 to avoid file descriptor exhaustion + # when running tests in parallel (e.g., -j 20) + NUM_TASKS = 100 + port = find_unused_port() script = textwrap.dedent( f"""\ @@ -833,7 +1145,6 @@ def test_async_global_awaited_by(self): PORT = socket_helper.find_unused_port() connections = 0 - # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) @@ -856,23 +1167,16 @@ async def echo_client(message): assert message == data.decode() writer.close() await writer.wait_closed() - # Signal we are ready to sleep sock.sendall(b"ready") await asyncio.sleep(SHORT_TIMEOUT) async def echo_client_spam(server): async with asyncio.TaskGroup() as tg: - while connections < 1000: + while connections < {NUM_TASKS}: msg = list(ascii_lowercase + digits) random.shuffle(msg) tg.create_task(echo_client("".join(msg))) await asyncio.sleep(0) - # at least a 1000 tasks created. Each task will signal - # when is ready to avoid the race caused by the fact that - # tasks are waited on tg.__exit__ and we cannot signal when - # that happens otherwise - # at this point all client tasks completed without assertion errors - # let's wrap up the test server.close() await server.wait_closed() @@ -887,231 +1191,216 @@ async def main(): asyncio.run(main()) """ ) - stack_trace = None + with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) + + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - for _ in range(1000): - expected_response = b"ready" - response = client_socket.recv(len(expected_response)) - self.assertEqual(response, expected_response) - for _ in busy_retry(SHORT_TIMEOUT): + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None + + # Wait for NUM_TASKS "ready" signals + try: + _wait_for_n_signals(client_socket, b"ready", NUM_TASKS) + except RuntimeError as e: + self.fail(str(e)) + try: all_awaited_by = get_all_awaited_by(p.pid) - except RuntimeError as re: - # This call reads a linked list in another process with - # no synchronization. That occasionally leads to invalid - # reads. Here we avoid making the test flaky. - msg = str(re) - if msg.startswith("Task list appears corrupted"): - continue - elif msg.startswith( - "Invalid linked list structure reading remote memory" - ): - continue - elif msg.startswith("Unknown error reading memory"): - continue - elif msg.startswith("Unhandled frame owner"): - continue - raise # Unrecognized exception, safest not to ignore it - else: - break - # expected: a list of two elements: 1 thread, 1 interp - self.assertEqual(len(all_awaited_by), 2) - # expected: a tuple with the thread ID and the awaited_by list - self.assertEqual(len(all_awaited_by[0]), 2) - # expected: no tasks in the fallback per-interp task list - self.assertEqual(all_awaited_by[1], (0, [])) - entries = all_awaited_by[0][1] - # expected: at least 1000 pending tasks - self.assertGreaterEqual(len(entries), 1000) - # the first three tasks stem from the code structure - main_stack = [ - FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]), - FrameInfo( - [taskgroups.__file__, ANY, "TaskGroup.__aexit__"] - ), - FrameInfo([script_name, 60, "main"]), - ] - self.assertIn( - TaskInfo( - [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []] - ), - entries, - ) - self.assertIn( - TaskInfo( - [ - ANY, - "server task", - [ - CoroInfo( - [ - [ - FrameInfo( - [ - base_events.__file__, - ANY, - "Server.serve_forever", - ] - ) - ], - ANY, - ] - ) - ], - [ - CoroInfo( - [ - [ - FrameInfo( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] - ), - FrameInfo( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] - ), - FrameInfo( - [script_name, ANY, "main"] - ), - ], - ANY, - ] - ) - ], - ] - ), - entries, - ) - self.assertIn( - TaskInfo( - [ - ANY, - "Task-4", - [ - CoroInfo( - [ - [ - FrameInfo( - [tasks.__file__, ANY, "sleep"] - ), - FrameInfo( - [ - script_name, - 38, - "echo_client", - ] - ), - ], - ANY, - ] - ) - ], - [ - CoroInfo( - [ - [ - FrameInfo( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] - ), - FrameInfo( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] - ), - FrameInfo( - [ - script_name, - 41, - "echo_client_spam", - ] - ), - ], - ANY, - ] - ) - ], - ] - ), - entries, - ) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) - expected_awaited_by = [ - CoroInfo( - [ - [ - FrameInfo( - [ - taskgroups.__file__, - ANY, - "TaskGroup._aexit", - ] - ), - FrameInfo( - [ - taskgroups.__file__, - ANY, - "TaskGroup.__aexit__", - ] - ), - FrameInfo( - [script_name, 41, "echo_client_spam"] - ), - ], - ANY, - ] - ) - ] - tasks_with_awaited = [ - task - for task in entries - if task.awaited_by == expected_awaited_by - ] - self.assertGreaterEqual(len(tasks_with_awaited), 1000) + # Expected: a list of two elements: 1 thread, 1 interp + self.assertEqual(len(all_awaited_by), 2) + # Expected: a tuple with the thread ID and the awaited_by list + self.assertEqual(len(all_awaited_by[0]), 2) + # Expected: no tasks in the fallback per-interp task list + self.assertEqual(all_awaited_by[1], (0, [])) - # the final task will have some random number, but it should for - # sure be one of the echo client spam horde (In windows this is not true - # for some reason) - if sys.platform != "win32": - self.assertEqual( - tasks_with_awaited[-1].awaited_by, - entries[-1].awaited_by, + entries = all_awaited_by[0][1] + # Expected: at least NUM_TASKS pending tasks + self.assertGreaterEqual(len(entries), NUM_TASKS) + + # Check the main task structure + main_stack = [ + FrameInfo( + [taskgroups.__file__, ANY, "TaskGroup._aexit"] + ), + FrameInfo( + [taskgroups.__file__, ANY, "TaskGroup.__aexit__"] + ), + FrameInfo([script_name, 52, "main"]), + ] + self.assertIn( + TaskInfo( + [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []] + ), + entries, ) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) + self.assertIn( + TaskInfo( + [ + ANY, + "server task", + [ + CoroInfo( + [ + [ + FrameInfo( + [ + base_events.__file__, + ANY, + "Server.serve_forever", + ] + ) + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, ANY, "main"] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) + self.assertIn( + TaskInfo( + [ + ANY, + "Task-4", + [ + CoroInfo( + [ + [ + FrameInfo( + [ + tasks.__file__, + ANY, + "sleep", + ] + ), + FrameInfo( + [ + script_name, + 36, + "echo_client", + ] + ), + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [ + script_name, + 39, + "echo_client_spam", + ] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) + + expected_awaited_by = [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, 39, "echo_client_spam"] + ), + ], + ANY, + ] + ) + ] + tasks_with_awaited = [ + task + for task in entries + if task.awaited_by == expected_awaited_by + ] + self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS) + + # Final task should be from echo client spam (not on Windows) + if sys.platform != "win32": + self.assertEqual( + tasks_with_awaited[-1].awaited_by, + entries[-1].awaited_by, + ) finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -1120,25 +1409,24 @@ async def main(): ) def test_self_trace(self): stack_trace = get_stack_trace(os.getpid()) - # Is possible that there are more threads, so we check that the - # expected stack traces are in the result (looking at you Windows!) - this_tread_stack = None - # New format: [InterpreterInfo(interpreter_id, [ThreadInfo(...)])] + + this_thread_stack = None for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: if thread_info.thread_id == threading.get_native_id(): - this_tread_stack = thread_info.frame_info + this_thread_stack = thread_info.frame_info break - if this_tread_stack: + if this_thread_stack: break - self.assertIsNotNone(this_tread_stack) + + self.assertIsNotNone(this_thread_stack) self.assertEqual( - this_tread_stack[:2], + this_thread_stack[:2], [ FrameInfo( [ __file__, - get_stack_trace.__code__.co_firstlineno + 2, + get_stack_trace.__code__.co_firstlineno + 4, "get_stack_trace", ] ), @@ -1159,12 +1447,11 @@ def test_self_trace(self): ) @requires_subinterpreters def test_subinterpreter_stack_trace(self): - # Test that subinterpreters are correctly handled port = find_unused_port() - # Calculate subinterpreter code separately and pickle it to avoid f-string issues import pickle - subinterp_code = textwrap.dedent(f''' + + subinterp_code = textwrap.dedent(f""" import socket import time @@ -1177,9 +1464,8 @@ def nested_func(): nested_func() sub_worker() - ''').strip() + """).strip() - # Pickle the subinterpreter code pickled_code = pickle.dumps(subinterp_code) script = textwrap.dedent( @@ -1190,33 +1476,26 @@ def nested_func(): import socket import threading - # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def main_worker(): - # Function running in main interpreter sock.sendall(b"ready:main\\n") time.sleep(10_000) def run_subinterp(): - # Create and run subinterpreter subinterp = interpreters.create() - import pickle pickled_code = {pickled_code!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) - # Start subinterpreter in thread sub_thread = threading.Thread(target=run_subinterp) sub_thread.start() - # Start main thread work main_thread = threading.Thread(target=main_worker) main_thread.start() - # Keep main thread alive main_thread.join() sub_thread.join() """ @@ -1226,85 +1505,74 @@ def run_subinterp(): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) - + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_sockets = [] + try: - p = subprocess.Popen([sys.executable, script_name]) - - # Accept connections from both main and subinterpreter - responses = set() - while len(responses) < 2: # Wait for both "ready:main" and "ready:sub" - try: - client_socket, _ = server_socket.accept() - client_sockets.append(client_socket) - - # Read the response from this connection - response = client_socket.recv(1024) - if b"ready:main" in response: - responses.add("main") - if b"ready:sub" in response: - responses.add("sub") - except socket.timeout: - break - - server_socket.close() - stack_trace = get_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) - finally: - for client_socket in client_sockets: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) - - # Verify we have multiple interpreters - self.assertGreaterEqual(len(stack_trace), 1, "Should have at least one interpreter") - - # Look for main interpreter (ID 0) and subinterpreter (ID > 0) - main_interp = None - sub_interp = None - - for interpreter_info in stack_trace: - if interpreter_info.interpreter_id == 0: - main_interp = interpreter_info - elif interpreter_info.interpreter_id > 0: - sub_interp = interpreter_info - - self.assertIsNotNone(main_interp, "Main interpreter should be present") - - # Check main interpreter has expected stack trace - main_found = False - for thread_info in main_interp.threads: - for frame in thread_info.frame_info: - if frame.funcname == "main_worker": - main_found = True - break - if main_found: - break - self.assertTrue(main_found, "Main interpreter should have main_worker in stack") - - # If subinterpreter is present, check its stack trace - if sub_interp: - sub_found = False - for thread_info in sub_interp.threads: - for frame in thread_info.frame_info: - if frame.funcname in ("sub_worker", "nested_func"): - sub_found = True + with _managed_subprocess([sys.executable, script_name]) as p: + # Accept connections from both main and subinterpreter + responses = set() + while len(responses) < 2: + try: + client_socket, _ = server_socket.accept() + client_sockets.append(client_socket) + response = client_socket.recv(1024) + if b"ready:main" in response: + responses.add("main") + if b"ready:sub" in response: + responses.add("sub") + except socket.timeout: break - if sub_found: - break - self.assertTrue(sub_found, "Subinterpreter should have sub_worker or nested_func in stack") + + server_socket.close() + server_socket = None + + try: + stack_trace = get_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + + # Verify we have at least one interpreter + self.assertGreaterEqual(len(stack_trace), 1) + + # Look for main interpreter (ID 0) and subinterpreter (ID > 0) + main_interp = None + sub_interp = None + for interpreter_info in stack_trace: + if interpreter_info.interpreter_id == 0: + main_interp = interpreter_info + elif interpreter_info.interpreter_id > 0: + sub_interp = interpreter_info + + self.assertIsNotNone( + main_interp, "Main interpreter should be present" + ) + + # Check main interpreter has expected stack trace + main_found = self._find_frame_in_trace( + [main_interp], lambda f: f.funcname == "main_worker" + ) + self.assertIsNotNone( + main_found, + "Main interpreter should have main_worker in stack", + ) + + # If subinterpreter is present, check its stack trace + if sub_interp: + sub_found = self._find_frame_in_trace( + [sub_interp], + lambda f: f.funcname + in ("sub_worker", "nested_func"), + ) + self.assertIsNotNone( + sub_found, + "Subinterpreter should have sub_worker or nested_func in stack", + ) + finally: + _cleanup_sockets(*client_sockets, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -1313,14 +1581,11 @@ def run_subinterp(): ) @requires_subinterpreters def test_multiple_subinterpreters_with_threads(self): - # Test multiple subinterpreters, each with multiple threads port = find_unused_port() - # Calculate subinterpreter codes separately and pickle them import pickle - # Code for first subinterpreter with 2 threads - subinterp1_code = textwrap.dedent(f''' + subinterp1_code = textwrap.dedent(f""" import socket import time import threading @@ -1347,10 +1612,9 @@ def nested_func(): t2.start() t1.join() t2.join() - ''').strip() + """).strip() - # Code for second subinterpreter with 2 threads - subinterp2_code = textwrap.dedent(f''' + subinterp2_code = textwrap.dedent(f""" import socket import time import threading @@ -1377,9 +1641,8 @@ def nested_func(): t2.start() t1.join() t2.join() - ''').strip() + """).strip() - # Pickle the subinterpreter codes pickled_code1 = pickle.dumps(subinterp1_code) pickled_code2 = pickle.dumps(subinterp2_code) @@ -1391,44 +1654,35 @@ def nested_func(): import socket import threading - # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def main_worker(): - # Function running in main interpreter sock.sendall(b"ready:main\\n") time.sleep(10_000) def run_subinterp1(): - # Create and run first subinterpreter subinterp = interpreters.create() - import pickle pickled_code = {pickled_code1!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) def run_subinterp2(): - # Create and run second subinterpreter subinterp = interpreters.create() - import pickle pickled_code = {pickled_code2!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) - # Start subinterpreters in threads sub1_thread = threading.Thread(target=run_subinterp1) sub2_thread = threading.Thread(target=run_subinterp2) sub1_thread.start() sub2_thread.start() - # Start main thread work main_thread = threading.Thread(target=main_worker) main_thread.start() - # Keep main thread alive main_thread.join() sub1_thread.join() sub2_thread.join() @@ -1439,72 +1693,80 @@ def run_subinterp2(): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(5) # Allow multiple connections - + server_socket = _create_server_socket(port, backlog=5) script_name = _make_test_script(script_dir, "script", script) client_sockets = [] + try: - p = subprocess.Popen([sys.executable, script_name]) + with _managed_subprocess([sys.executable, script_name]) as p: + # Accept connections from main and all subinterpreter threads + expected_responses = { + "ready:main", + "ready:sub1-t1", + "ready:sub1-t2", + "ready:sub2-t1", + "ready:sub2-t2", + } + responses = set() - # Accept connections from main and all subinterpreter threads - expected_responses = {"ready:main", "ready:sub1-t1", "ready:sub1-t2", "ready:sub2-t1", "ready:sub2-t2"} - responses = set() + while len(responses) < 5: + try: + client_socket, _ = server_socket.accept() + client_sockets.append(client_socket) + response = client_socket.recv(1024) + response_str = response.decode().strip() + if response_str in expected_responses: + responses.add(response_str) + except socket.timeout: + break + + server_socket.close() + server_socket = None - while len(responses) < 5: # Wait for all 5 ready signals try: - client_socket, _ = server_socket.accept() - client_sockets.append(client_socket) + stack_trace = get_stack_trace(p.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) - # Read the response from this connection - response = client_socket.recv(1024) - response_str = response.decode().strip() - if response_str in expected_responses: - responses.add(response_str) - except socket.timeout: - break + # Verify we have multiple interpreters + self.assertGreaterEqual(len(stack_trace), 2) - server_socket.close() - stack_trace = get_stack_trace(p.pid) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) + # Count interpreters by ID + interpreter_ids = { + interp.interpreter_id for interp in stack_trace + } + self.assertIn( + 0, + interpreter_ids, + "Main interpreter should be present", + ) + self.assertGreaterEqual(len(interpreter_ids), 3) + + # Count total threads + total_threads = sum( + len(interp.threads) for interp in stack_trace + ) + self.assertGreaterEqual(total_threads, 5) + + # Look for expected function names + all_funcnames = set() + for interpreter_info in stack_trace: + for thread_info in interpreter_info.threads: + for frame in thread_info.frame_info: + all_funcnames.add(frame.funcname) + + expected_funcs = { + "main_worker", + "worker1", + "worker2", + "nested_func", + } + found_funcs = expected_funcs.intersection(all_funcnames) + self.assertGreater(len(found_funcs), 0) finally: - for client_socket in client_sockets: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) - - # Verify we have multiple interpreters - self.assertGreaterEqual(len(stack_trace), 2, "Should have at least two interpreters") - - # Count interpreters by ID - interpreter_ids = {interp.interpreter_id for interp in stack_trace} - self.assertIn(0, interpreter_ids, "Main interpreter should be present") - self.assertGreaterEqual(len(interpreter_ids), 3, "Should have main + at least 2 subinterpreters") - - # Count total threads across all interpreters - total_threads = sum(len(interp.threads) for interp in stack_trace) - self.assertGreaterEqual(total_threads, 5, "Should have at least 5 threads total") - - # Look for expected function names in stack traces - all_funcnames = set() - for interpreter_info in stack_trace: - for thread_info in interpreter_info.threads: - for frame in thread_info.frame_info: - all_funcnames.add(frame.funcname) - - # Should find functions from different interpreters and threads - expected_funcs = {"main_worker", "worker1", "worker2", "nested_func"} - found_funcs = expected_funcs.intersection(all_funcnames) - self.assertGreater(len(found_funcs), 0, f"Should find some expected functions, got: {all_funcnames}") + _cleanup_sockets(*client_sockets, server_socket) @skip_if_not_supported @unittest.skipIf( @@ -1513,54 +1775,41 @@ def run_subinterp2(): ) @requires_gil_enabled("Free threaded builds don't have an 'active thread'") def test_only_active_thread(self): - # Test that only_active_thread parameter works correctly port = find_unused_port() script = textwrap.dedent( f"""\ import time, sys, socket, threading - # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def worker_thread(name, barrier, ready_event): - barrier.wait() # Synchronize thread start - ready_event.wait() # Wait for main thread signal - # Sleep to keep thread alive + barrier.wait() + ready_event.wait() time.sleep(10_000) def main_work(): - # Do busy work to hold the GIL sock.sendall(b"working\\n") count = 0 while count < 100000000: count += 1 if count % 10000000 == 0: - pass # Keep main thread busy + pass sock.sendall(b"done\\n") - # Create synchronization primitives num_threads = 3 - barrier = threading.Barrier(num_threads + 1) # +1 for main thread + barrier = threading.Barrier(num_threads + 1) ready_event = threading.Event() - # Start worker threads threads = [] for i in range(num_threads): t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event)) t.start() threads.append(t) - # Wait for all threads to be ready barrier.wait() - - # Signal ready to parent process sock.sendall(b"ready\\n") - - # Signal threads to start waiting ready_event.set() - - # Now do busy work to hold the GIL main_work() """ ) @@ -1569,104 +1818,76 @@ def main_work(): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - # Create a socket server to communicate with the target process - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) - + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - # Wait for ready signal - response = b"" - while b"ready" not in response: - response += client_socket.recv(1024) + # Wait for ready and working signals + _wait_for_signal(client_socket, [b"ready", b"working"]) - # Wait for the main thread to start its busy work - while b"working" not in response: - response += client_socket.recv(1024) - - # Get stack trace with all threads - unwinder_all = RemoteUnwinder(p.pid, all_threads=True) - for _ in range(10): - # Wait for the main thread to start its busy work - all_traces = unwinder_all.get_stack_trace() - found = False - # New format: [InterpreterInfo(interpreter_id, [ThreadInfo(...)])] - for interpreter_info in all_traces: - for thread_info in interpreter_info.threads: - if not thread_info.frame_info: - continue - current_frame = thread_info.frame_info[0] - if ( - current_frame.funcname == "main_work" - and current_frame.lineno > 15 - ): - found = True + try: + # Get stack trace with all threads + unwinder_all = RemoteUnwinder(p.pid, all_threads=True) + for _ in range(MAX_TRIES): + all_traces = unwinder_all.get_stack_trace() + found = self._find_frame_in_trace( + all_traces, + lambda f: f.funcname == "main_work" + and f.lineno > 12, + ) + if found: break - if found: + time.sleep(0.1) + else: + self.fail( + "Main thread did not start its busy work on time" + ) + + # Get stack trace with only GIL holder + unwinder_gil = RemoteUnwinder( + p.pid, only_active_thread=True + ) + gil_traces = unwinder_gil.get_stack_trace() + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + + # Count threads + total_threads = sum( + len(interp.threads) for interp in all_traces + ) + self.assertGreater(total_threads, 1) + + total_gil_threads = sum( + len(interp.threads) for interp in gil_traces + ) + self.assertEqual(total_gil_threads, 1) + + # Get the GIL holder thread ID + gil_thread_id = None + for interpreter_info in gil_traces: + if interpreter_info.threads: + gil_thread_id = interpreter_info.threads[ + 0 + ].thread_id break - if found: - break - # Give a bit of time to take the next sample - time.sleep(0.1) - else: - self.fail( - "Main thread did not start its busy work on time" - ) + # Get all thread IDs + all_thread_ids = [] + for interpreter_info in all_traces: + for thread_info in interpreter_info.threads: + all_thread_ids.append(thread_info.thread_id) - # Get stack trace with only GIL holder - unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True) - gil_traces = unwinder_gil.get_stack_trace() - - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" - ) + self.assertIn(gil_thread_id, all_thread_ids) finally: - if client_socket is not None: - client_socket.close() - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) - - # Count total threads across all interpreters in all_traces - total_threads = sum(len(interpreter_info.threads) for interpreter_info in all_traces) - self.assertGreater( - total_threads, 1, "Should have multiple threads" - ) - - # Count total threads across all interpreters in gil_traces - total_gil_threads = sum(len(interpreter_info.threads) for interpreter_info in gil_traces) - self.assertEqual( - total_gil_threads, 1, "Should have exactly one GIL holder" - ) - - # Get the GIL holder thread ID - gil_thread_id = None - for interpreter_info in gil_traces: - if interpreter_info.threads: - gil_thread_id = interpreter_info.threads[0].thread_id - break - - # Get all thread IDs from all_traces - all_thread_ids = [] - for interpreter_info in all_traces: - for thread_info in interpreter_info.threads: - all_thread_ids.append(thread_info.thread_id) - - self.assertIn( - gil_thread_id, - all_thread_ids, - "GIL holder should be among all threads", - ) + _cleanup_sockets(client_socket, server_socket) class TestUnsupportedPlatformHandling(unittest.TestCase): @@ -1674,23 +1895,28 @@ class TestUnsupportedPlatformHandling(unittest.TestCase): sys.platform in ("linux", "darwin", "win32"), "Test only runs on unsupported platforms (not Linux, macOS, or Windows)", ) - @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception") + @unittest.skipIf( + sys.platform == "android", "Android raises Linux-specific exception" + ) def test_unsupported_platform_error(self): with self.assertRaises(RuntimeError) as cm: RemoteUnwinder(os.getpid()) self.assertIn( "Reading the PyRuntime section is not supported on this platform", - str(cm.exception) + str(cm.exception), ) -class TestDetectionOfThreadStatus(unittest.TestCase): - @unittest.skipIf( - sys.platform not in ("linux", "darwin", "win32"), - "Test only runs on unsupported platforms (not Linux, macOS, or Windows)", - ) - @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception") - def test_thread_status_detection(self): + +class TestDetectionOfThreadStatus(RemoteInspectionTestBase): + def _run_thread_status_test(self, mode, check_condition): + """ + Common pattern for thread status detection tests. + + Args: + mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.) + check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool + """ port = find_unused_port() script = textwrap.dedent( f"""\ @@ -1723,203 +1949,146 @@ def busy(): sock.close() """ ) + with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) - script_name = _make_test_script(script_dir, "thread_status_script", script) + server_socket = _create_server_socket(port) + script_name = _make_test_script( + script_dir, "thread_status_script", script + ) client_socket = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = b"" - sleeper_tid = None - busy_tid = None - while True: - chunk = client_socket.recv(1024) - response += chunk - if b"ready:main" in response and b"ready:sleeper" in response and b"ready:busy" in response: - # Parse TIDs from the response - for line in response.split(b"\n"): - if line.startswith(b"ready:sleeper:"): - try: - sleeper_tid = int(line.split(b":")[-1]) - except Exception: - pass - elif line.startswith(b"ready:busy:"): - try: - busy_tid = int(line.split(b":")[-1]) - except Exception: - pass - break + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - attempts = 10 - statuses = {} - try: - unwinder = RemoteUnwinder(p.pid, all_threads=True, mode=PROFILING_MODE_CPU, - skip_non_matching_threads=False) - for _ in range(attempts): - traces = unwinder.get_stack_trace() - # Find threads and their statuses - statuses = {} - for interpreter_info in traces: - for thread_info in interpreter_info.threads: - statuses[thread_info.thread_id] = thread_info.status - - # Check if sleeper thread is off CPU and busy thread is on CPU - # In the new flags system: - # - sleeper should NOT have ON_CPU flag (off CPU) - # - busy should have ON_CPU flag - if (sleeper_tid in statuses and - busy_tid in statuses and - not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU) and - (statuses[busy_tid] & THREAD_STATUS_ON_CPU)): - break - time.sleep(0.5) # Give a bit of time to let threads settle - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" + # Wait for all ready signals and parse TIDs + response = _wait_for_signal( + client_socket, + [b"ready:main", b"ready:sleeper", b"ready:busy"], ) - self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received") - self.assertIsNotNone(busy_tid, "Busy thread id not received") - self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads") - self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads") - self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, "Sleeper thread should be off CPU") - self.assertTrue(statuses[busy_tid] & THREAD_STATUS_ON_CPU, "Busy thread should be on CPU") + sleeper_tid = None + busy_tid = None + for line in response.split(b"\n"): + if line.startswith(b"ready:sleeper:"): + try: + sleeper_tid = int(line.split(b":")[-1]) + except (ValueError, IndexError): + pass + elif line.startswith(b"ready:busy:"): + try: + busy_tid = int(line.split(b":")[-1]) + except (ValueError, IndexError): + pass - finally: - if client_socket is not None: - client_socket.close() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) - - @unittest.skipIf( - sys.platform not in ("linux", "darwin", "win32"), - "Test only runs on unsupported platforms (not Linux, macOS, or Windows)", - ) - @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception") - def test_thread_status_gil_detection(self): - port = find_unused_port() - script = textwrap.dedent( - f"""\ - import time, sys, socket, threading - import os - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(('localhost', {port})) - - def sleeper(): - tid = threading.get_native_id() - sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode()) - time.sleep(10000) - - def busy(): - tid = threading.get_native_id() - sock.sendall(f'ready:busy:{{tid}}\\n'.encode()) - x = 0 - while True: - x = x + 1 - time.sleep(0.5) - - t1 = threading.Thread(target=sleeper) - t2 = threading.Thread(target=busy) - t1.start() - t2.start() - sock.sendall(b'ready:main\\n') - t1.join() - t2.join() - sock.close() - """ - ) - with os_helper.temp_dir() as work_dir: - script_dir = os.path.join(work_dir, "script_pkg") - os.mkdir(script_dir) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) - - script_name = _make_test_script(script_dir, "thread_status_script", script) - client_socket = None - try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() - response = b"" - sleeper_tid = None - busy_tid = None - while True: - chunk = client_socket.recv(1024) - response += chunk - if b"ready:main" in response and b"ready:sleeper" in response and b"ready:busy" in response: - # Parse TIDs from the response - for line in response.split(b"\n"): - if line.startswith(b"ready:sleeper:"): - try: - sleeper_tid = int(line.split(b":")[-1]) - except Exception: - pass - elif line.startswith(b"ready:busy:"): - try: - busy_tid = int(line.split(b":")[-1]) - except Exception: - pass - break - - attempts = 10 - statuses = {} - try: - unwinder = RemoteUnwinder(p.pid, all_threads=True, mode=PROFILING_MODE_GIL, - skip_non_matching_threads=False) - for _ in range(attempts): - traces = unwinder.get_stack_trace() - # Find threads and their statuses - statuses = {} - for interpreter_info in traces: - for thread_info in interpreter_info.threads: - statuses[thread_info.thread_id] = thread_info.status - - # Check if sleeper thread doesn't have GIL and busy thread has GIL - # In the new flags system: - # - sleeper should NOT have HAS_GIL flag (waiting for GIL) - # - busy should have HAS_GIL flag - if (sleeper_tid in statuses and - busy_tid in statuses and - not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL) and - (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)): - break - time.sleep(0.5) # Give a bit of time to let threads settle - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" + self.assertIsNotNone( + sleeper_tid, "Sleeper thread id not received" + ) + self.assertIsNotNone( + busy_tid, "Busy thread id not received" ) - self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received") - self.assertIsNotNone(busy_tid, "Busy thread id not received") - self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads") - self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads") - self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, "Sleeper thread should not have GIL") - self.assertTrue(statuses[busy_tid] & THREAD_STATUS_HAS_GIL, "Busy thread should have GIL") + # Sample until we see expected thread states + statuses = {} + try: + unwinder = RemoteUnwinder( + p.pid, + all_threads=True, + mode=mode, + skip_non_matching_threads=False, + ) + for _ in range(MAX_TRIES): + traces = unwinder.get_stack_trace() + statuses = self._get_thread_statuses(traces) + if check_condition( + statuses, sleeper_tid, busy_tid + ): + break + time.sleep(0.5) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + + return statuses, sleeper_tid, busy_tid finally: - if client_socket is not None: - client_socket.close() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + _cleanup_sockets(client_socket, server_socket) @unittest.skipIf( sys.platform not in ("linux", "darwin", "win32"), "Test only runs on supported platforms (Linux, macOS, or Windows)", ) - @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception") + @unittest.skipIf( + sys.platform == "android", "Android raises Linux-specific exception" + ) + def test_thread_status_detection(self): + def check_cpu_status(statuses, sleeper_tid, busy_tid): + return ( + sleeper_tid in statuses + and busy_tid in statuses + and not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU) + and (statuses[busy_tid] & THREAD_STATUS_ON_CPU) + ) + + statuses, sleeper_tid, busy_tid = self._run_thread_status_test( + PROFILING_MODE_CPU, check_cpu_status + ) + + self.assertIn(sleeper_tid, statuses) + self.assertIn(busy_tid, statuses) + self.assertFalse( + statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, + "Sleeper thread should be off CPU", + ) + self.assertTrue( + statuses[busy_tid] & THREAD_STATUS_ON_CPU, + "Busy thread should be on CPU", + ) + + @unittest.skipIf( + sys.platform not in ("linux", "darwin", "win32"), + "Test only runs on supported platforms (Linux, macOS, or Windows)", + ) + @unittest.skipIf( + sys.platform == "android", "Android raises Linux-specific exception" + ) + def test_thread_status_gil_detection(self): + def check_gil_status(statuses, sleeper_tid, busy_tid): + return ( + sleeper_tid in statuses + and busy_tid in statuses + and not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL) + and (statuses[busy_tid] & THREAD_STATUS_HAS_GIL) + ) + + statuses, sleeper_tid, busy_tid = self._run_thread_status_test( + PROFILING_MODE_GIL, check_gil_status + ) + + self.assertIn(sleeper_tid, statuses) + self.assertIn(busy_tid, statuses) + self.assertFalse( + statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, + "Sleeper thread should not have GIL", + ) + self.assertTrue( + statuses[busy_tid] & THREAD_STATUS_HAS_GIL, + "Busy thread should have GIL", + ) + + @unittest.skipIf( + sys.platform not in ("linux", "darwin", "win32"), + "Test only runs on supported platforms (Linux, macOS, or Windows)", + ) + @unittest.skipIf( + sys.platform == "android", "Android raises Linux-specific exception" + ) def test_thread_status_all_mode_detection(self): port = find_unused_port() script = textwrap.dedent( @@ -1952,104 +2121,112 @@ def busy_thread(): with os_helper.temp_dir() as tmp_dir: script_file = make_script(tmp_dir, "script", script) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.listen(2) - server_socket.settimeout(SHORT_TIMEOUT) - - p = subprocess.Popen( - [sys.executable, script_file], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - + server_socket = _create_server_socket(port, backlog=2) client_sockets = [] + try: - sleeper_tid = None - busy_tid = None + with _managed_subprocess( + [sys.executable, script_file], + ) as p: + sleeper_tid = None + busy_tid = None - # Receive thread IDs from the child process - for _ in range(2): - client_socket, _ = server_socket.accept() - client_sockets.append(client_socket) - line = client_socket.recv(1024) - if line: - if line.startswith(b"sleeper:"): - try: - sleeper_tid = int(line.split(b":")[-1]) - except Exception: - pass - elif line.startswith(b"busy:"): - try: - busy_tid = int(line.split(b":")[-1]) - except Exception: - pass + # Receive thread IDs from the child process + for _ in range(2): + client_socket, _ = server_socket.accept() + client_sockets.append(client_socket) + line = client_socket.recv(1024) + if line: + if line.startswith(b"sleeper:"): + try: + sleeper_tid = int(line.split(b":")[-1]) + except (ValueError, IndexError): + pass + elif line.startswith(b"busy:"): + try: + busy_tid = int(line.split(b":")[-1]) + except (ValueError, IndexError): + pass - server_socket.close() + server_socket.close() + server_socket = None - attempts = 10 - statuses = {} - try: - unwinder = RemoteUnwinder(p.pid, all_threads=True, mode=PROFILING_MODE_ALL, - skip_non_matching_threads=False) - for _ in range(attempts): - traces = unwinder.get_stack_trace() - # Find threads and their statuses - statuses = {} - for interpreter_info in traces: - for thread_info in interpreter_info.threads: - statuses[thread_info.thread_id] = thread_info.status + statuses = {} + try: + unwinder = RemoteUnwinder( + p.pid, + all_threads=True, + mode=PROFILING_MODE_ALL, + skip_non_matching_threads=False, + ) + for _ in range(MAX_TRIES): + traces = unwinder.get_stack_trace() + statuses = self._get_thread_statuses(traces) - # Check ALL mode provides both GIL and CPU info - # - sleeper should NOT have ON_CPU and NOT have HAS_GIL - # - busy should have ON_CPU and have HAS_GIL - if (sleeper_tid in statuses and - busy_tid in statuses and - not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU) and - not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL) and - (statuses[busy_tid] & THREAD_STATUS_ON_CPU) and - (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)): - break - time.sleep(0.5) - except PermissionError: - self.skipTest( - "Insufficient permissions to read the stack trace" + # Check ALL mode provides both GIL and CPU info + if ( + sleeper_tid in statuses + and busy_tid in statuses + and not ( + statuses[sleeper_tid] + & THREAD_STATUS_ON_CPU + ) + and not ( + statuses[sleeper_tid] + & THREAD_STATUS_HAS_GIL + ) + and (statuses[busy_tid] & THREAD_STATUS_ON_CPU) + and ( + statuses[busy_tid] & THREAD_STATUS_HAS_GIL + ) + ): + break + time.sleep(0.5) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + + self.assertIsNotNone( + sleeper_tid, "Sleeper thread id not received" + ) + self.assertIsNotNone( + busy_tid, "Busy thread id not received" + ) + self.assertIn(sleeper_tid, statuses) + self.assertIn(busy_tid, statuses) + + # Sleeper: off CPU, no GIL + self.assertFalse( + statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, + "Sleeper should be off CPU", + ) + self.assertFalse( + statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, + "Sleeper should not have GIL", ) - self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received") - self.assertIsNotNone(busy_tid, "Busy thread id not received") - self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads") - self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads") - - # Sleeper thread: off CPU, no GIL - self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, "Sleeper should be off CPU") - self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, "Sleeper should not have GIL") - - # Busy thread: on CPU, has GIL - self.assertTrue(statuses[busy_tid] & THREAD_STATUS_ON_CPU, "Busy should be on CPU") - self.assertTrue(statuses[busy_tid] & THREAD_STATUS_HAS_GIL, "Busy should have GIL") - + # Busy: on CPU, has GIL + self.assertTrue( + statuses[busy_tid] & THREAD_STATUS_ON_CPU, + "Busy should be on CPU", + ) + self.assertTrue( + statuses[busy_tid] & THREAD_STATUS_HAS_GIL, + "Busy should have GIL", + ) finally: - for client_socket in client_sockets: - client_socket.close() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) - p.stdout.close() - p.stderr.close() + _cleanup_sockets(*client_sockets, server_socket) -class TestFrameCaching(unittest.TestCase): +class TestFrameCaching(RemoteInspectionTestBase): """Test that frame caching produces correct results. Uses socket-based synchronization for deterministic testing. All tests verify cache reuse via object identity checks (assertIs). """ - maxDiff = None - MAX_TRIES = 10 - - @contextlib.contextmanager + @contextmanager def _target_process(self, script_body): """Context manager for running a target process with socket sync.""" port = find_unused_port() @@ -2064,61 +2241,62 @@ def _target_process(self, script_body): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(("localhost", port)) - server_socket.settimeout(SHORT_TIMEOUT) - server_socket.listen(1) - + server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None - p = None + try: - p = subprocess.Popen([sys.executable, script_name]) - client_socket, _ = server_socket.accept() - server_socket.close() + with _managed_subprocess([sys.executable, script_name]) as p: + client_socket, _ = server_socket.accept() + server_socket.close() + server_socket = None - def make_unwinder(cache_frames=True): - return RemoteUnwinder(p.pid, all_threads=True, cache_frames=cache_frames) + def make_unwinder(cache_frames=True): + return RemoteUnwinder( + p.pid, all_threads=True, cache_frames=cache_frames + ) - yield p, client_socket, make_unwinder + yield p, client_socket, make_unwinder except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: - if client_socket: - client_socket.close() - if p: - p.kill() - p.terminate() - p.wait(timeout=SHORT_TIMEOUT) + _cleanup_sockets(client_socket, server_socket) - def _wait_for_signal(self, client_socket, signal): - """Block until signal received from target.""" - response = b"" - while signal not in response: - chunk = client_socket.recv(64) - if not chunk: - break - response += chunk - return response - - def _get_frames(self, unwinder, required_funcs): - """Sample and return frame_info list for thread containing required_funcs.""" - traces = unwinder.get_stack_trace() - for interp in traces: - for thread in interp.threads: - funcs = [f.funcname for f in thread.frame_info] - if required_funcs.issubset(set(funcs)): - return thread.frame_info + def _get_frames_with_retry(self, unwinder, required_funcs): + """Get frames containing required_funcs, with retry for transient errors.""" + for _ in range(MAX_TRIES): + try: + traces = unwinder.get_stack_trace() + for interp in traces: + for thread in interp.threads: + funcs = {f.funcname for f in thread.frame_info} + if required_funcs.issubset(funcs): + return thread.frame_info + except RuntimeError as e: + if _is_retriable_error(e): + pass + else: + raise + time.sleep(0.1) return None - def _sample_frames(self, client_socket, unwinder, wait_signal, send_ack, required_funcs, expected_frames=1): - """Wait for signal, sample frames, send ack. Returns frame_info list.""" - self._wait_for_signal(client_socket, wait_signal) - # Give at least MAX_TRIES tries for the process to arrive to a steady state - for _ in range(self.MAX_TRIES): - frames = self._get_frames(unwinder, required_funcs) + def _sample_frames( + self, + client_socket, + unwinder, + wait_signal, + send_ack, + required_funcs, + expected_frames=1, + ): + """Wait for signal, sample frames with retry until required funcs present, send ack.""" + _wait_for_signal(client_socket, wait_signal) + frames = None + for _ in range(MAX_TRIES): + frames = self._get_frames_with_retry(unwinder, required_funcs) if frames and len(frames) >= expected_frames: break time.sleep(0.1) @@ -2155,13 +2333,23 @@ def level1(): level1() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) expected = {"level1", "level2", "level3"} - frames1 = self._sample_frames(client_socket, unwinder, b"sync1", b"ack", expected) - frames2 = self._sample_frames(client_socket, unwinder, b"sync2", b"ack", expected) - frames3 = self._sample_frames(client_socket, unwinder, b"sync3", b"done", expected) + frames1 = self._sample_frames( + client_socket, unwinder, b"sync1", b"ack", expected + ) + frames2 = self._sample_frames( + client_socket, unwinder, b"sync2", b"ack", expected + ) + frames3 = self._sample_frames( + client_socket, unwinder, b"sync3", b"done", expected + ) self.assertIsNotNone(frames1) self.assertIsNotNone(frames2) @@ -2176,8 +2364,12 @@ def level1(): # Parent frames (index 1+) must be identical objects (cache reuse) for i in range(1, len(frames1)): f1, f2, f3 = frames1[i], frames2[i], frames3[i] - self.assertIs(f1, f2, f"Frame {i}: samples 1-2 must be same object") - self.assertIs(f2, f3, f"Frame {i}: samples 2-3 must be same object") + self.assertIs( + f1, f2, f"Frame {i}: samples 1-2 must be same object" + ) + self.assertIs( + f2, f3, f"Frame {i}: samples 2-3 must be same object" + ) @skip_if_not_supported @unittest.skipIf( @@ -2203,13 +2395,25 @@ def inner(): outer() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) - frames_a = self._sample_frames(client_socket, unwinder, b"line_a", b"ack", {"inner"}) - frames_b = self._sample_frames(client_socket, unwinder, b"line_b", b"ack", {"inner"}) - frames_c = self._sample_frames(client_socket, unwinder, b"line_c", b"ack", {"inner"}) - frames_d = self._sample_frames(client_socket, unwinder, b"line_d", b"done", {"inner"}) + frames_a = self._sample_frames( + client_socket, unwinder, b"line_a", b"ack", {"inner"} + ) + frames_b = self._sample_frames( + client_socket, unwinder, b"line_b", b"ack", {"inner"} + ) + frames_c = self._sample_frames( + client_socket, unwinder, b"line_c", b"ack", {"inner"} + ) + frames_d = self._sample_frames( + client_socket, unwinder, b"line_d", b"done", {"inner"} + ) self.assertIsNotNone(frames_a) self.assertIsNotNone(frames_b) @@ -2228,12 +2432,15 @@ def inner(): self.assertEqual(inner_d.funcname, "inner") # Line numbers must be different and increasing (execution moves forward) - self.assertLess(inner_a.lineno, inner_b.lineno, - "Line B should be after line A") - self.assertLess(inner_b.lineno, inner_c.lineno, - "Line C should be after line B") - self.assertLess(inner_c.lineno, inner_d.lineno, - "Line D should be after line C") + self.assertLess( + inner_a.lineno, inner_b.lineno, "Line B should be after line A" + ) + self.assertLess( + inner_b.lineno, inner_c.lineno, "Line C should be after line B" + ) + self.assertLess( + inner_c.lineno, inner_d.lineno, "Line D should be after line C" + ) @skip_if_not_supported @unittest.skipIf( @@ -2255,13 +2462,23 @@ def outer(): outer() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) frames_deep = self._sample_frames( - client_socket, unwinder, b"at_inner", b"ack", {"inner", "outer"}) + client_socket, + unwinder, + b"at_inner", + b"ack", + {"inner", "outer"}, + ) frames_shallow = self._sample_frames( - client_socket, unwinder, b"at_outer", b"done", {"outer"}) + client_socket, unwinder, b"at_outer", b"done", {"outer"} + ) self.assertIsNotNone(frames_deep) self.assertIsNotNone(frames_shallow) @@ -2297,13 +2514,27 @@ def top(): top() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) frames_before = self._sample_frames( - client_socket, unwinder, b"at_middle", b"ack", {"middle", "top"}) + client_socket, + unwinder, + b"at_middle", + b"ack", + {"middle", "top"}, + ) frames_after = self._sample_frames( - client_socket, unwinder, b"at_deeper", b"done", {"deeper", "middle", "top"}) + client_socket, + unwinder, + b"at_deeper", + b"done", + {"deeper", "middle", "top"}, + ) self.assertIsNotNone(frames_before) self.assertIsNotNone(frames_after) @@ -2345,15 +2576,29 @@ def func_a(): func_a() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) # Sample at C: stack is A→B→C frames_c = self._sample_frames( - client_socket, unwinder, b"at_c", b"ack", {"func_a", "func_b", "func_c"}) + client_socket, + unwinder, + b"at_c", + b"ack", + {"func_a", "func_b", "func_c"}, + ) # Sample at D: stack is A→B→D (C returned, D called) frames_d = self._sample_frames( - client_socket, unwinder, b"at_d", b"done", {"func_a", "func_b", "func_d"}) + client_socket, + unwinder, + b"at_d", + b"done", + {"func_a", "func_b", "func_d"}, + ) self.assertIsNotNone(frames_c) self.assertIsNotNone(frames_d) @@ -2376,8 +2621,16 @@ def find_frame(frames, funcname): self.assertIsNotNone(frame_b_in_d) # The bottom frames (A, B) should be the SAME objects (cache reuse) - self.assertIs(frame_a_in_c, frame_a_in_d, "func_a frame should be reused from cache") - self.assertIs(frame_b_in_c, frame_b_in_d, "func_b frame should be reused from cache") + self.assertIs( + frame_a_in_c, + frame_a_in_d, + "func_a frame should be reused from cache", + ) + self.assertIs( + frame_b_in_c, + frame_b_in_d, + "func_b frame should be reused from cache", + ) @skip_if_not_supported @unittest.skipIf( @@ -2399,13 +2652,19 @@ def recurse(n): recurse(5) """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) frames1 = self._sample_frames( - client_socket, unwinder, b"sync1", b"ack", {"recurse"}) + client_socket, unwinder, b"sync1", b"ack", {"recurse"} + ) frames2 = self._sample_frames( - client_socket, unwinder, b"sync2", b"done", {"recurse"}) + client_socket, unwinder, b"sync2", b"done", {"recurse"} + ) self.assertIsNotNone(frames1) self.assertIsNotNone(frames2) @@ -2421,8 +2680,11 @@ def recurse(n): # Parent frames (index 1+) should be identical objects (cache reuse) for i in range(1, len(frames1)): - self.assertIs(frames1[i], frames2[i], - f"Frame {i}: recursive frames must be same object") + self.assertIs( + frames1[i], + frames2[i], + f"Frame {i}: recursive frames must be same object", + ) @skip_if_not_supported @unittest.skipIf( @@ -2444,16 +2706,24 @@ def level1(): level1() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): - self._wait_for_signal(client_socket, b"ready") + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): + _wait_for_signal(client_socket, b"ready") # Sample with cache unwinder_cache = make_unwinder(cache_frames=True) - frames_cached = self._get_frames(unwinder_cache, {"level1", "level2", "level3"}) + frames_cached = self._get_frames_with_retry( + unwinder_cache, {"level1", "level2", "level3"} + ) # Sample without cache unwinder_no_cache = make_unwinder(cache_frames=False) - frames_no_cache = self._get_frames(unwinder_no_cache, {"level1", "level2", "level3"}) + frames_no_cache = self._get_frames_with_retry( + unwinder_no_cache, {"level1", "level2", "level3"} + ) client_socket.sendall(b"done") @@ -2526,7 +2796,11 @@ def foo2(): t2.join() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder = make_unwinder(cache_frames=True) buffer = b"" @@ -2624,16 +2898,24 @@ def get_thread_frames(target_funcs): # Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated) self.assertIn("blech1", t1_blech) self.assertIn("foo1", t1_blech) - self.assertNotIn("bar1", t1_blech, "Cache not invalidated: bar1 still present") - self.assertNotIn("baz1", t1_blech, "Cache not invalidated: baz1 still present") + self.assertNotIn( + "bar1", t1_blech, "Cache not invalidated: bar1 still present" + ) + self.assertNotIn( + "baz1", t1_blech, "Cache not invalidated: baz1 still present" + ) # No cross-contamination self.assertNotIn("blech2", t1_blech) # Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated) self.assertIn("blech2", t2_blech) self.assertIn("foo2", t2_blech) - self.assertNotIn("bar2", t2_blech, "Cache not invalidated: bar2 still present") - self.assertNotIn("baz2", t2_blech, "Cache not invalidated: baz2 still present") + self.assertNotIn( + "bar2", t2_blech, "Cache not invalidated: bar2 still present" + ) + self.assertNotIn( + "baz2", t2_blech, "Cache not invalidated: baz2 still present" + ) # No cross-contamination self.assertNotIn("blech1", t2_blech) @@ -2663,17 +2945,25 @@ def level1(): level1() """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): expected = {"level1", "level2", "level3", "level4"} # First unwinder samples - this sets last_profiled_frame in target unwinder1 = make_unwinder(cache_frames=True) - frames1 = self._sample_frames(client_socket, unwinder1, b"sync1", b"ack", expected) + frames1 = self._sample_frames( + client_socket, unwinder1, b"sync1", b"ack", expected + ) # Create NEW unwinder (empty cache) and sample # The target still has last_profiled_frame set from unwinder1 unwinder2 = make_unwinder(cache_frames=True) - frames2 = self._sample_frames(client_socket, unwinder2, b"sync2", b"done", expected) + frames2 = self._sample_frames( + client_socket, unwinder2, b"sync2", b"done", expected + ) self.assertIsNotNone(frames1) self.assertIsNotNone(frames2) @@ -2687,8 +2977,11 @@ def level1(): self.assertIn(level, funcs2, f"{level} missing from second sample") # Should have same stack depth - self.assertEqual(len(frames1), len(frames2), - "New unwinder should return complete stack despite stale last_profiled_frame") + self.assertEqual( + len(frames1), + len(frames2), + "New unwinder should return complete stack despite stale last_profiled_frame", + ) @skip_if_not_supported @unittest.skipIf( @@ -2719,16 +3012,30 @@ def recurse(n): recurse({depth}) """ - with self._target_process(script_body) as (p, client_socket, make_unwinder): + with self._target_process(script_body) as ( + p, + client_socket, + make_unwinder, + ): unwinder_cache = make_unwinder(cache_frames=True) unwinder_no_cache = make_unwinder(cache_frames=False) frames_cached = self._sample_frames( - client_socket, unwinder_cache, b"ready", b"ack", {"recurse"}, expected_frames=1102 + client_socket, + unwinder_cache, + b"ready", + b"ack", + {"recurse"}, + expected_frames=1102, ) # Sample again with no cache for comparison frames_no_cache = self._sample_frames( - client_socket, unwinder_no_cache, b"ready2", b"done", {"recurse"}, expected_frames=1102 + client_socket, + unwinder_no_cache, + b"ready2", + b"done", + {"recurse"}, + expected_frames=1102, ) self.assertIsNotNone(frames_cached) @@ -2738,12 +3045,19 @@ def recurse(n): cached_count = [f.funcname for f in frames_cached].count("recurse") no_cache_count = [f.funcname for f in frames_no_cache].count("recurse") - self.assertGreater(cached_count, 1000, "Should have >1000 recurse frames") - self.assertGreater(no_cache_count, 1000, "Should have >1000 recurse frames") + self.assertGreater( + cached_count, 1000, "Should have >1000 recurse frames" + ) + self.assertGreater( + no_cache_count, 1000, "Should have >1000 recurse frames" + ) # Both modes should produce same frame count - self.assertEqual(len(frames_cached), len(frames_no_cache), - "Cache exhaustion should not affect stack completeness") + self.assertEqual( + len(frames_cached), + len(frames_no_cache), + "Cache exhaustion should not affect stack completeness", + ) @skip_if_not_supported @unittest.skipIf( @@ -2759,7 +3073,7 @@ def test_get_stats(self): with self._target_process(script_body) as (p, client_socket, _): unwinder = RemoteUnwinder(p.pid, all_threads=True, stats=True) - self._wait_for_signal(client_socket, b"ready") + _wait_for_signal(client_socket, b"ready") # Take a sample unwinder.get_stack_trace() @@ -2769,14 +3083,18 @@ def test_get_stats(self): # Verify expected keys exist expected_keys = [ - 'total_samples', 'frame_cache_hits', 'frame_cache_misses', - 'frame_cache_partial_hits', 'frames_read_from_cache', - 'frames_read_from_memory', 'frame_cache_hit_rate' + "total_samples", + "frame_cache_hits", + "frame_cache_misses", + "frame_cache_partial_hits", + "frames_read_from_cache", + "frames_read_from_memory", + "frame_cache_hit_rate", ] for key in expected_keys: self.assertIn(key, stats) - self.assertEqual(stats['total_samples'], 1) + self.assertEqual(stats["total_samples"], 1) @skip_if_not_supported @unittest.skipIf( @@ -2791,8 +3109,10 @@ def test_get_stats_disabled_raises(self): """ with self._target_process(script_body) as (p, client_socket, _): - unwinder = RemoteUnwinder(p.pid, all_threads=True) # stats=False by default - self._wait_for_signal(client_socket, b"ready") + unwinder = RemoteUnwinder( + p.pid, all_threads=True + ) # stats=False by default + _wait_for_signal(client_socket, b"ready") with self.assertRaises(RuntimeError): unwinder.get_stats()