"""Tests for inspecting a remote Python process via the _remote_debugging module."""

import unittest
import os
import textwrap
import contextlib
import importlib
import sys
import socket
import threading
import time
import subprocess
from contextlib import contextmanager

from asyncio import staggered, taskgroups, base_events, tasks
from unittest.mock import ANY

from test.support import (
    os_helper,
    SHORT_TIMEOUT,
    busy_retry,
    requires_gil_enabled,
)
from test.support.script_helper import make_script
from test.support.socket_helper import find_unused_port

# Profiling mode constants
PROFILING_MODE_WALL = 0
PROFILING_MODE_CPU = 1
PROFILING_MODE_GIL = 2
PROFILING_MODE_ALL = 3

# Thread status flags
THREAD_STATUS_HAS_GIL = 1 << 0
THREAD_STATUS_ON_CPU = 1 << 1
THREAD_STATUS_UNKNOWN = 1 << 2

# Maximum number of retry attempts for operations that may fail transiently
MAX_TRIES = 10

try:
    from concurrent import interpreters
except ImportError:
    # Subinterpreter support is optional; tests that need it are skipped.
    interpreters = None

PROCESS_VM_READV_SUPPORTED = False
try:
    from _remote_debugging import PROCESS_VM_READV_SUPPORTED
    from _remote_debugging import RemoteUnwinder
    from _remote_debugging import FrameInfo, CoroInfo, TaskInfo
except ImportError:
    raise unittest.SkipTest(
        "Test only runs when _remote_debugging is available"
    )


# ============================================================================
# Module-level helper functions
# ============================================================================


def _make_test_script(script_dir, script_basename, source):
    """Write *source* to a test script file and invalidate import caches."""
    to_return = make_script(script_dir, script_basename, source)
    importlib.invalidate_caches()
    return to_return


def _create_server_socket(port, backlog=1):
    """Create and configure a server socket for test communication."""
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_socket.bind(("localhost", port))
    server_socket.settimeout(SHORT_TIMEOUT)
    server_socket.listen(backlog)
    return server_socket


def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
    """
    Wait for expected signal(s) from a socket with proper timeout and EOF
    handling.

    Args:
        sock: Connected socket to read from
        expected_signals: Single bytes object or list of bytes objects to
            wait for
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed before signal received or timeout
    """
    if isinstance(expected_signals, bytes):
        expected_signals = [expected_signals]

    sock.settimeout(timeout)
    buffer = b""

    while True:
        # Check if all expected signals are in buffer
        if all(sig in buffer for sig in expected_signals):
            return buffer

        try:
            chunk = sock.recv(4096)
            if not chunk:
                # EOF - connection closed
                raise RuntimeError(
                    f"Connection closed before receiving expected signals. "
                    f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
                )
            buffer += chunk
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for signals. "
                f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
            )


def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT):
    """
    Wait for N occurrences of a signal pattern.

    Args:
        sock: Connected socket to read from
        signal_pattern: bytes pattern to count (e.g., b"ready")
        count: Number of occurrences expected
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed or timeout before receiving all
            signals
    """
    sock.settimeout(timeout)
    buffer = b""
    found_count = 0

    while found_count < count:
        try:
            chunk = sock.recv(4096)
            if not chunk:
                raise RuntimeError(
                    f"Connection closed after {found_count}/{count} signals. "
                    f"Last 200 bytes: {buffer[-200:]!r}"
                )
            buffer += chunk
            # Count occurrences in entire buffer
            found_count = buffer.count(signal_pattern)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for {count} signals (found {found_count}). "
                f"Last 200 bytes: {buffer[-200:]!r}"
            )

    return buffer


@contextmanager
def _managed_subprocess(args, timeout=SHORT_TIMEOUT):
    """
    Context manager for subprocess lifecycle management.

    Ensures process is properly terminated and cleaned up even on exceptions.
    Uses graceful termination first, then forceful kill if needed.
    """
    p = subprocess.Popen(args)
    try:
        yield p
    finally:
        try:
            p.terminate()
            try:
                p.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                p.kill()
                try:
                    p.wait(timeout=timeout)
                except subprocess.TimeoutExpired:
                    pass  # Process refuses to die, nothing more we can do
        except OSError:
            pass  # Process already dead


def _cleanup_sockets(*sockets):
    """Safely close multiple sockets, ignoring errors."""
    for sock in sockets:
        if sock is not None:
            try:
                sock.close()
            except OSError:
                pass


# ============================================================================
# Decorators and skip conditions
# ============================================================================

skip_if_not_supported = unittest.skipIf(
    (
        sys.platform != "darwin"
        and sys.platform != "linux"
        and sys.platform != "win32"
    ),
    "Test only runs on Linux, Windows and MacOS",
)


def requires_subinterpreters(meth):
    """Decorator to skip a test if subinterpreters are not supported."""
    return unittest.skipIf(interpreters is None, "subinterpreters required")(
        meth
    )


# ============================================================================
# Simple wrapper functions for RemoteUnwinder
# ============================================================================

# NOTE: test_self_trace asserts the exact line offset of the return statement
# below (co_firstlineno + 4); do not add lines inside this function.
def get_stack_trace(pid):
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
            return unwinder.get_stack_trace()
        except RuntimeError:
            continue
    raise RuntimeError("Failed to get stack trace after retries")


def get_async_stack_trace(pid):
    """Return the async stack trace for *pid*, retrying transient failures."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_async_stack_trace()
        except RuntimeError:
            continue
    raise RuntimeError("Failed to get async stack trace after retries")


def get_all_awaited_by(pid):
    """Return the awaited-by graph for *pid*, retrying transient failures."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_all_awaited_by()
        except RuntimeError:
            continue
    raise RuntimeError("Failed to get all awaited_by after retries")


# ============================================================================
# Base test class with shared infrastructure
# ============================================================================


class RemoteInspectionTestBase(unittest.TestCase):
    """Base class for remote inspection tests with common helpers."""

    maxDiff = None

    def _run_script_and_get_trace(
        self,
        script,
        trace_func,
        wait_for_signals=None,
        port=None,
        backlog=1,
    ):
        """
        Common pattern: run a script, wait for signals, get trace.

        Args:
            script: Script content (will be formatted with port if {port}
                present)
            trace_func: Function to call with pid to get trace
                (e.g., get_stack_trace)
            wait_for_signals: Signal(s) to wait for before getting trace
            port: Port to use (auto-selected if None)
            backlog: Socket listen backlog

        Returns:
            tuple: (trace_result, script_name)
        """
        if port is None:
            port = find_unused_port()

        # Format script with port if needed
        if "{port}" in script or "{{port}}" in script:
            script = script.replace("{{port}}", "{port}").format(port=port)

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port, backlog)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    if wait_for_signals:
                        _wait_for_signal(client_socket, wait_for_signals)

                    try:
                        trace = trace_func(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    return trace, script_name
            finally:
                _cleanup_sockets(client_socket, server_socket)

    def _find_frame_in_trace(self, stack_trace, predicate):
        """
        Find a frame matching predicate in stack trace.

        Args:
            stack_trace: List of InterpreterInfo objects
            predicate: Function(frame) -> bool

        Returns:
            FrameInfo or None
        """
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                for frame in thread_info.frame_info:
                    if predicate(frame):
                        return frame
        return None

    def _find_thread_by_id(self, stack_trace, thread_id):
        """Find a thread by its native thread ID."""
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                if thread_info.thread_id == thread_id:
                    return thread_info
        return None

    def _find_thread_with_frame(self, stack_trace, frame_predicate):
        """Find a thread containing a frame matching predicate."""
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                for frame in thread_info.frame_info:
                    if frame_predicate(frame):
                        return thread_info
        return None

    def _get_thread_statuses(self, stack_trace):
        """Extract thread_id -> status mapping from stack trace."""
        statuses = {}
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                statuses[thread_info.thread_id] = thread_info.status
        return statuses

    def _get_task_id_map(self, stack_trace):
        """Create task_id -> task mapping from async stack trace."""
        return {task.task_id: task for task in stack_trace[0].awaited_by}

    def _get_awaited_by_relationships(self, stack_trace):
        """Extract task name to awaited_by set mapping."""
        # NOTE: CoroInfo.task_name on an awaited_by entry holds the awaiting
        # task's id, hence the lookup through id_to_task here.
        id_to_task = self._get_task_id_map(stack_trace)
        return {
            task.task_name: set(
                id_to_task[awaited.task_name].task_name
                for awaited in task.awaited_by
            )
            for task in stack_trace[0].awaited_by
        }

    def _extract_coroutine_stacks(self, stack_trace):
        """Extract and format coroutine stacks from tasks."""
        return {
            task.task_name: sorted(
                tuple(tuple(frame) for frame in coro.call_stack)
                for coro in task.coroutine_stack
            )
            for task in stack_trace[0].awaited_by
        }
# ============================================================================
# Test classes
# ============================================================================


class TestGetStackTrace(RemoteInspectionTestBase):
    """Tests for synchronous and asynchronous remote stack traces."""

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_remote_stack_trace(self):
        """Stack traces are reported for both the main and a worker thread."""
        port = find_unused_port()
        # The blank lines in the script matter: the assertions below pin
        # exact line numbers (bar=9, baz=12, foo=15, module=19).
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def bar():
                for x in range(100):
                    if x == 50:
                        baz()

            def baz():
                foo()

            def foo():
                sock.sendall(b"ready:thread\\n"); time.sleep(10_000)

            t = threading.Thread(target=bar)
            t.start()
            sock.sendall(b"ready:main\\n"); t.join()
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    _wait_for_signal(
                        client_socket, [b"ready:main", b"ready:thread"]
                    )

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    thread_expected_stack_trace = [
                        FrameInfo([script_name, 15, "foo"]),
                        FrameInfo([script_name, 12, "baz"]),
                        FrameInfo([script_name, 9, "bar"]),
                        FrameInfo([threading.__file__, ANY, "Thread.run"]),
                        FrameInfo(
                            [
                                threading.__file__,
                                ANY,
                                "Thread._bootstrap_inner",
                            ]
                        ),
                        FrameInfo(
                            [threading.__file__, ANY, "Thread._bootstrap"]
                        ),
                    ]

                    # Find expected thread stack
                    found_thread = self._find_thread_with_frame(
                        stack_trace,
                        lambda f: f.funcname == "foo" and f.lineno == 15,
                    )
                    self.assertIsNotNone(
                        found_thread, "Expected thread stack trace not found"
                    )
                    self.assertEqual(
                        found_thread.frame_info, thread_expected_stack_trace
                    )

                    # Check main thread (top-level script code runs in the
                    # "<module>" pseudo-frame)
                    main_frame = FrameInfo([script_name, 19, "<module>"])
                    found_main = self._find_frame_in_trace(
                        stack_trace, lambda f: f == main_frame
                    )
                    self.assertIsNotNone(
                        found_main, "Main thread stack trace not found"
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_remote_stack_trace(self):
        """Async task tree is reported for lazy and eager task factories."""
        port = find_unused_port()
        # Line numbers asserted below: c5=10, c4=14, c3=17, c2=20, c1=23,
        # main=26.
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def c5():
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c4():
                await asyncio.sleep(0)
                c5()

            async def c3():
                await c4()

            async def c2():
                await c3()

            async def c1(task):
                await task

            async def main():
                async with asyncio.TaskGroup() as tg:
                    task = tg.create_task(c2(), name="c2_root")
                    tg.create_task(c1(task), name="sub_main_1")
                    tg.create_task(c1(task), name="sub_main_2")

            def new_eager_loop():
                loop = asyncio.new_event_loop()
                eager_task_factory = asyncio.create_eager_task_factory(
                    asyncio.Task)
                loop.set_task_factory(eager_task_factory)
                return loop

            asyncio.run(main(), loop_factory={{TASK_FACTORY}})
            """
        )
        for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
            with (
                self.subTest(task_factory_variant=task_factory_variant),
                os_helper.temp_dir() as work_dir,
            ):
                script_dir = os.path.join(work_dir, "script_pkg")
                os.mkdir(script_dir)
                server_socket = _create_server_socket(port)
                script_name = _make_test_script(
                    script_dir,
                    "script",
                    script.format(TASK_FACTORY=task_factory_variant),
                )
                client_socket = None
                try:
                    with _managed_subprocess(
                        [sys.executable, script_name]
                    ) as p:
                        client_socket, _ = server_socket.accept()
                        server_socket.close()
                        server_socket = None

                        response = _wait_for_signal(client_socket, b"ready")
                        self.assertIn(b"ready", response)

                        try:
                            stack_trace = get_async_stack_trace(p.pid)
                        except PermissionError:
                            self.skipTest(
                                "Insufficient permissions to read the stack trace"
                            )

                        # Check all tasks are present
                        tasks_names = [
                            task.task_name
                            for task in stack_trace[0].awaited_by
                        ]
                        for task_name in [
                            "c2_root",
                            "sub_main_1",
                            "sub_main_2",
                        ]:
                            self.assertIn(task_name, tasks_names)

                        # Check awaited_by relationships
                        relationships = self._get_awaited_by_relationships(
                            stack_trace
                        )
                        self.assertEqual(
                            relationships,
                            {
                                "c2_root": {
                                    "Task-1",
                                    "sub_main_1",
                                    "sub_main_2",
                                },
                                "Task-1": set(),
                                "sub_main_1": {"Task-1"},
                                "sub_main_2": {"Task-1"},
                            },
                        )

                        # Check coroutine stacks
                        coroutine_stacks = self._extract_coroutine_stacks(
                            stack_trace
                        )
                        self.assertEqual(
                            coroutine_stacks,
                            {
                                "Task-1": [
                                    (
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup._aexit",
                                            ]
                                        ),
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup.__aexit__",
                                            ]
                                        ),
                                        tuple([script_name, 26, "main"]),
                                    )
                                ],
                                "c2_root": [
                                    (
                                        tuple([script_name, 10, "c5"]),
                                        tuple([script_name, 14, "c4"]),
                                        tuple([script_name, 17, "c3"]),
                                        tuple([script_name, 20, "c2"]),
                                    )
                                ],
                                "sub_main_1": [
                                    (tuple([script_name, 23, "c1"]),)
                                ],
                                "sub_main_2": [
                                    (tuple([script_name, 23, "c1"]),)
                                ],
                            },
                        )

                        # Check awaited_by coroutine stacks
                        id_to_task = self._get_task_id_map(stack_trace)
                        awaited_by_coroutine_stacks = {
                            task.task_name: sorted(
                                (
                                    id_to_task[coro.task_name].task_name,
                                    tuple(
                                        tuple(frame)
                                        for frame in coro.call_stack
                                    ),
                                )
                                for coro in task.awaited_by
                            )
                            for task in stack_trace[0].awaited_by
                        }
                        self.assertEqual(
                            awaited_by_coroutine_stacks,
                            {
                                "Task-1": [],
                                "c2_root": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    ),
                                    (
                                        "sub_main_1",
                                        (tuple([script_name, 23, "c1"]),),
                                    ),
                                    (
                                        "sub_main_2",
                                        (tuple([script_name, 23, "c1"]),),
                                    ),
                                ],
                                "sub_main_1": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    )
                                ],
                                "sub_main_2": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    )
                                ],
                            },
                        )
                finally:
                    _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_asyncgen_remote_stack_trace(self):
        """A coroutine stack through an async generator is reported."""
        port = find_unused_port()
        # Line numbers asserted below: gen_nested_call=10, gen=16, main=19.
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def gen_nested_call():
                sock.sendall(b"ready"); time.sleep(10_000)

            async def gen():
                for num in range(2):
                    yield num
                    if num == 1:
                        await gen_nested_call()

            async def main():
                async for el in gen():
                    pass

            asyncio.run(main())
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # For this simple asyncgen test, we only expect one task
                    self.assertEqual(len(stack_trace[0].awaited_by), 1)
                    task = stack_trace[0].awaited_by[0]
                    self.assertEqual(task.task_name, "Task-1")

                    # Check the coroutine stack
                    coroutine_stack = sorted(
                        tuple(tuple(frame) for frame in coro.call_stack)
                        for coro in task.coroutine_stack
                    )
                    self.assertEqual(
                        coroutine_stack,
                        [
                            (
                                tuple([script_name, 10, "gen_nested_call"]),
                                tuple([script_name, 16, "gen"]),
                                tuple([script_name, 19, "main"]),
                            )
                        ],
                    )

                    # No awaited_by relationships expected
                    self.assertEqual(task.awaited_by, [])
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_gather_remote_stack_trace(self):
        """Tasks spawned by asyncio.gather are linked to their parent task."""
        port = find_unused_port()
        # Line numbers asserted below: deep=11, c1=15, main=21.
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def deep():
                await asyncio.sleep(0)
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c1():
                await asyncio.sleep(0)
                await deep()

            async def c2():
                await asyncio.sleep(0)

            async def main():
                await asyncio.gather(c1(), c2())

            asyncio.run(main())
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Check all tasks are present
                    tasks_names = [
                        task.task_name for task in stack_trace[0].awaited_by
                    ]
                    for task_name in ["Task-1", "Task-2"]:
                        self.assertIn(task_name, tasks_names)

                    # Check awaited_by relationships
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "Task-1": set(),
                            "Task-2": {"Task-1"},
                        },
                    )

                    # Check coroutine stacks
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [(tuple([script_name, 21, "main"]),)],
                            "Task-2": [
                                (
                                    tuple([script_name, 11, "deep"]),
                                    tuple([script_name, 15, "c1"]),
                                )
                            ],
                        },
                    )

                    # Check awaited_by coroutine stacks
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame)
                                    for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "Task-2": [
                                ("Task-1", (tuple([script_name, 21, "main"]),))
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_staggered_race_remote_stack_trace(self):
        """Tasks created by asyncio.staggered.staggered_race are reported."""
        port = find_unused_port()
        # Line numbers asserted below: deep=11, c1=15, main=21.
        script = textwrap.dedent(
            f"""\
            import asyncio.staggered
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def deep():
                await asyncio.sleep(0)
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c1():
                await asyncio.sleep(0)
                await deep()

            async def c2():
                await asyncio.sleep(10_000)

            async def main():
                await asyncio.staggered.staggered_race(
                    [c1, c2],
                    delay=None,
                )

            asyncio.run(main())
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Check all tasks are present
                    tasks_names = [
                        task.task_name for task in stack_trace[0].awaited_by
                    ]
                    for task_name in ["Task-1", "Task-2"]:
                        self.assertIn(task_name, tasks_names)

                    # Check awaited_by relationships
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "Task-1": set(),
                            "Task-2": {"Task-1"},
                        },
                    )

                    # Check coroutine stacks
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [
                                (
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race",
                                        ]
                                    ),
                                    tuple([script_name, 21, "main"]),
                                )
                            ],
                            "Task-2": [
                                (
                                    tuple([script_name, 11, "deep"]),
                                    tuple([script_name, 15, "c1"]),
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race.<locals>.run_one_coro",
                                        ]
                                    ),
                                )
                            ],
                        },
                    )

                    # Check awaited_by coroutine stacks
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame)
                                    for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "Task-2": [
                                (
                                    "Task-1",
                                    (
                                        tuple(
                                            [
                                                staggered.__file__,
                                                ANY,
                                                "staggered_race",
                                            ]
                                        ),
                                        tuple([script_name, 21, "main"]),
                                    ),
                                )
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_global_awaited_by(self):
        """The global awaited-by graph covers many concurrently pending tasks."""
        # Reduced from 1000 to 100 to avoid file descriptor exhaustion
        # when running tests in parallel (e.g., -j 20)
        NUM_TASKS = 100
        port = find_unused_port()
        # Line numbers asserted below: echo_client sleep=36,
        # echo_client_spam TaskGroup=39, main TaskGroup=52.
        script = textwrap.dedent(
            f"""\
            import asyncio
            import os
            import random
            import sys
            import socket
            from string import ascii_lowercase, digits
            from test.support import socket_helper, SHORT_TIMEOUT

            HOST = '127.0.0.1'
            PORT = socket_helper.find_unused_port()
            connections = 0

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            class EchoServerProtocol(asyncio.Protocol):
                def connection_made(self, transport):
                    global connections
                    connections += 1
                    self.transport = transport

                def data_received(self, data):
                    self.transport.write(data)
                    self.transport.close()

            async def echo_client(message):
                reader, writer = await asyncio.open_connection(HOST, PORT)
                writer.write(message.encode())
                await writer.drain()

                data = await reader.read(100)
                assert message == data.decode()
                writer.close()
                await writer.wait_closed()
                sock.sendall(b"ready")
                await asyncio.sleep(SHORT_TIMEOUT)

            async def echo_client_spam(server):
                async with asyncio.TaskGroup() as tg:
                    while connections < {NUM_TASKS}:
                        msg = list(ascii_lowercase + digits)
                        random.shuffle(msg)
                        tg.create_task(echo_client("".join(msg)))
                        await asyncio.sleep(0)
                    server.close()
                    await server.wait_closed()

            async def main():
                loop = asyncio.get_running_loop()
                server = await loop.create_server(EchoServerProtocol, HOST, PORT)
                async with server:
                    async with asyncio.TaskGroup() as tg:
                        tg.create_task(server.serve_forever(), name="server task")
                        tg.create_task(echo_client_spam(server), name="echo client spam")

            asyncio.run(main())
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for NUM_TASKS "ready" signals
                    try:
                        _wait_for_n_signals(client_socket, b"ready", NUM_TASKS)
                    except RuntimeError as e:
                        self.fail(str(e))

                    try:
                        all_awaited_by = get_all_awaited_by(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Expected: a list of two elements: 1 thread, 1 interp
                    self.assertEqual(len(all_awaited_by), 2)
                    # Expected: a tuple with the thread ID and the awaited_by list
                    self.assertEqual(len(all_awaited_by[0]), 2)
                    # Expected: no tasks in the fallback per-interp task list
                    self.assertEqual(all_awaited_by[1], (0, []))
                    entries = all_awaited_by[0][1]
                    # Expected: at least NUM_TASKS pending tasks
                    self.assertGreaterEqual(len(entries), NUM_TASKS)

                    # Check the main task structure
                    main_stack = [
                        FrameInfo(
                            [taskgroups.__file__, ANY, "TaskGroup._aexit"]
                        ),
                        FrameInfo(
                            [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
                        ),
                        FrameInfo([script_name, 52, "main"]),
                    ]
                    self.assertIn(
                        TaskInfo(
                            [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
                        ),
                        entries,
                    )
                    self.assertIn(
                        TaskInfo(
                            [
                                ANY,
                                "server task",
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        base_events.__file__,
                                                        ANY,
                                                        "Server.serve_forever",
                                                    ]
                                                )
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup._aexit",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup.__aexit__",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [script_name, ANY, "main"]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                            ]
                        ),
                        entries,
                    )
                    self.assertIn(
                        TaskInfo(
                            [
                                ANY,
                                "Task-4",
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        tasks.__file__,
                                                        ANY,
                                                        "sleep",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        script_name,
                                                        36,
                                                        "echo_client",
                                                    ]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup._aexit",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup.__aexit__",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        script_name,
                                                        39,
                                                        "echo_client_spam",
                                                    ]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                            ]
                        ),
                        entries,
                    )

                    expected_awaited_by = [
                        CoroInfo(
                            [
                                [
                                    FrameInfo(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup._aexit",
                                        ]
                                    ),
                                    FrameInfo(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup.__aexit__",
                                        ]
                                    ),
                                    FrameInfo(
                                        [script_name, 39, "echo_client_spam"]
                                    ),
                                ],
                                ANY,
                            ]
                        )
                    ]
                    tasks_with_awaited = [
                        task
                        for task in entries
                        if task.awaited_by == expected_awaited_by
                    ]
                    self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS)

                    # Final task should be from echo client spam (not on Windows)
                    if sys.platform != "win32":
                        self.assertEqual(
                            tasks_with_awaited[-1].awaited_by,
                            entries[-1].awaited_by,
                        )
            finally:
                _cleanup_sockets(client_socket, server_socket)
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_self_trace(self):
        stack_trace = get_stack_trace(os.getpid())
        # NOTE: no docstring above and a fixed decorator layout — the
        # assertions below pin exact line offsets: the get_stack_trace()
        # call must sit at co_firstlineno + 6 of this (decorated) method,
        # and the return inside get_stack_trace at its co_firstlineno + 4.
        this_thread_stack = None
        # Locate this thread's own frames in the self-inspection result.
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                if thread_info.thread_id == threading.get_native_id():
                    this_thread_stack = thread_info.frame_info
                    break
            if this_thread_stack:
                break
        self.assertIsNotNone(this_thread_stack)
        # The two innermost frames must be the unwinder call and this test.
        self.assertEqual(
            this_thread_stack[:2],
            [
                FrameInfo(
                    [
                        __file__,
                        get_stack_trace.__code__.co_firstlineno + 4,
                        "get_stack_trace",
                    ]
                ),
                FrameInfo(
                    [
                        __file__,
                        self.test_self_trace.__code__.co_firstlineno + 6,
                        "TestGetStackTrace.test_self_trace",
                    ]
                ),
            ],
        )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_subinterpreters
    def test_subinterpreter_stack_trace(self):
        """Frames from code running in a subinterpreter are reported."""
        port = find_unused_port()
        import pickle

        # Code executed inside the subinterpreter; it signals readiness over
        # the test socket and then blocks so its frames stay observable.
        subinterp_code = textwrap.dedent(f"""
            import socket
            import time

            def sub_worker():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub\\n")
                    time.sleep(10_000)
                nested_func()

            sub_worker()
            """).strip()

        # The code is smuggled into the child script via pickle so it can be
        # embedded as a plain repr() literal.
        pickled_code = pickle.dumps(subinterp_code)

        script = textwrap.dedent(
            f"""
            from concurrent import interpreters
            import time
            import sys
            import socket
            import threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def main_worker():
                sock.sendall(b"ready:main\\n")
                time.sleep(10_000)

            def run_subinterp():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            sub_thread = threading.Thread(target=run_subinterp)
            sub_thread.start()

            main_thread = threading.Thread(target=main_worker)
            main_thread.start()

            main_thread.join()
            sub_thread.join()
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_sockets = []
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # Accept connections from both main and subinterpreter
                    responses = set()
                    while len(responses) < 2:
                        try:
                            client_socket, _ = server_socket.accept()
                            client_sockets.append(client_socket)
                            response = client_socket.recv(1024)
                            if b"ready:main" in response:
                                responses.add("main")
                            if b"ready:sub" in response:
                                responses.add("sub")
                        except socket.timeout:
                            break
                    server_socket.close()
                    server_socket = None

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Verify we have at least one interpreter
                    self.assertGreaterEqual(len(stack_trace), 1)

                    # Look for main interpreter (ID 0) and subinterpreter (ID > 0)
                    main_interp = None
                    sub_interp = None
                    for interpreter_info in stack_trace:
                        if interpreter_info.interpreter_id == 0:
                            main_interp = interpreter_info
                        elif interpreter_info.interpreter_id > 0:
                            sub_interp = interpreter_info

                    self.assertIsNotNone(
                        main_interp, "Main interpreter should be present"
                    )

                    # Check main interpreter has expected stack trace
                    main_found = self._find_frame_in_trace(
                        [main_interp], lambda f: f.funcname == "main_worker"
                    )
                    self.assertIsNotNone(
                        main_found,
                        "Main interpreter should have main_worker in stack",
                    )

                    # If subinterpreter is present, check its stack trace
                    if sub_interp:
                        sub_found = self._find_frame_in_trace(
                            [sub_interp],
                            lambda f: f.funcname
                            in ("sub_worker", "nested_func"),
                        )
                        self.assertIsNotNone(
                            sub_found,
                            "Subinterpreter should have sub_worker or nested_func in stack",
                        )
            finally:
                _cleanup_sockets(*client_sockets, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_subinterpreters
    def test_multiple_subinterpreters_with_threads(self):
        port = find_unused_port()
        import pickle
        # NOTE(review): the body of this test continues below this point.
subinterp1_code = textwrap.dedent(f""" import socket import time import threading def worker1(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub1-t1\\n") time.sleep(10_000) nested_func() def worker2(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub1-t2\\n") time.sleep(10_000) nested_func() t1 = threading.Thread(target=worker1) t2 = threading.Thread(target=worker2) t1.start() t2.start() t1.join() t2.join() """).strip() subinterp2_code = textwrap.dedent(f""" import socket import time import threading def worker1(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub2-t1\\n") time.sleep(10_000) nested_func() def worker2(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub2-t2\\n") time.sleep(10_000) nested_func() t1 = threading.Thread(target=worker1) t2 = threading.Thread(target=worker2) t1.start() t2.start() t1.join() t2.join() """).strip() pickled_code1 = pickle.dumps(subinterp1_code) pickled_code2 = pickle.dumps(subinterp2_code) script = textwrap.dedent( f""" from concurrent import interpreters import time import sys import socket import threading sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def main_worker(): sock.sendall(b"ready:main\\n") time.sleep(10_000) def run_subinterp1(): subinterp = interpreters.create() import pickle pickled_code = {pickled_code1!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) def run_subinterp2(): subinterp = interpreters.create() import pickle pickled_code = {pickled_code2!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) sub1_thread = threading.Thread(target=run_subinterp1) sub2_thread 
= threading.Thread(target=run_subinterp2) sub1_thread.start() sub2_thread.start() main_thread = threading.Thread(target=main_worker) main_thread.start() main_thread.join() sub1_thread.join() sub2_thread.join() """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port, backlog=5) script_name = _make_test_script(script_dir, "script", script) client_sockets = [] try: with _managed_subprocess([sys.executable, script_name]) as p: # Accept connections from main and all subinterpreter threads expected_responses = { "ready:main", "ready:sub1-t1", "ready:sub1-t2", "ready:sub2-t1", "ready:sub2-t2", } responses = set() while len(responses) < 5: try: client_socket, _ = server_socket.accept() client_sockets.append(client_socket) response = client_socket.recv(1024) response_str = response.decode().strip() if response_str in expected_responses: responses.add(response_str) except socket.timeout: break server_socket.close() server_socket = None try: stack_trace = get_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Verify we have multiple interpreters self.assertGreaterEqual(len(stack_trace), 2) # Count interpreters by ID interpreter_ids = { interp.interpreter_id for interp in stack_trace } self.assertIn( 0, interpreter_ids, "Main interpreter should be present", ) self.assertGreaterEqual(len(interpreter_ids), 3) # Count total threads total_threads = sum( len(interp.threads) for interp in stack_trace ) self.assertGreaterEqual(total_threads, 5) # Look for expected function names all_funcnames = set() for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: for frame in thread_info.frame_info: all_funcnames.add(frame.funcname) expected_funcs = { "main_worker", "worker1", "worker2", "nested_func", } found_funcs = expected_funcs.intersection(all_funcnames) self.assertGreater(len(found_funcs), 0) 
            finally:
                # Close every accepted client socket plus the listener (if
                # still open); _cleanup_sockets tolerates None entries.
                _cleanup_sockets(*client_sockets, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_gil_enabled("Free threaded builds don't have an 'active thread'")
    def test_only_active_thread(self):
        """With only_active_thread=True the unwinder must report exactly one
        thread (the GIL holder), and that thread must also appear in the
        all_threads=True report of the same process."""
        port = find_unused_port()
        # Child script: three workers park on an Event while the main thread
        # spins in main_work(); the busy loop keeps the main thread holding
        # the GIL so it is the expected "active" thread.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def worker_thread(name, barrier, ready_event):
                barrier.wait()
                ready_event.wait()
                time.sleep(10_000)

            def main_work():
                sock.sendall(b"working\\n")
                count = 0
                while count < 100000000:
                    count += 1
                    if count % 10000000 == 0:
                        pass
                sock.sendall(b"done\\n")

            num_threads = 3
            barrier = threading.Barrier(num_threads + 1)
            ready_event = threading.Event()

            threads = []
            for i in range(num_threads):
                t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
                t.start()
                threads.append(t)

            barrier.wait()
            sock.sendall(b"ready\\n")
            ready_event.set()
            main_work()
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    # Listener no longer needed once the child connected.
                    server_socket.close()
                    server_socket = None

                    # Wait for ready and working signals
                    _wait_for_signal(client_socket, [b"ready", b"working"])

                    try:
                        # Get stack trace with all threads
                        unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
                        for _ in range(MAX_TRIES):
                            all_traces = unwinder_all.get_stack_trace()
                            # lineno > 12 selects the busy-loop lines of
                            # main_work in the child script above, proving the
                            # main thread reached its CPU-bound phase.
                            found = self._find_frame_in_trace(
                                all_traces,
                                lambda f: f.funcname == "main_work"
                                and f.lineno > 12,
                            )
                            if found:
                                break
                            time.sleep(0.1)
                        else:
                            self.fail(
                                "Main thread did not start its busy work on time"
                            )

                        # Get stack trace with only GIL holder
                        unwinder_gil = RemoteUnwinder(
                            p.pid, only_active_thread=True
                        )
                        gil_traces = unwinder_gil.get_stack_trace()
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Count threads
                    total_threads = sum(
                        len(interp.threads) for interp in all_traces
                    )
                    self.assertGreater(total_threads, 1)

                    total_gil_threads = sum(
                        len(interp.threads) for interp in gil_traces
                    )
                    self.assertEqual(total_gil_threads, 1)

                    # Get the GIL holder thread ID
                    gil_thread_id = None
                    for interpreter_info in gil_traces:
                        if interpreter_info.threads:
                            gil_thread_id = interpreter_info.threads[
                                0
                            ].thread_id
                            break

                    # Get all thread IDs
                    all_thread_ids = []
                    for interpreter_info in all_traces:
                        for thread_info in interpreter_info.threads:
                            all_thread_ids.append(thread_info.thread_id)

                    self.assertIn(gil_thread_id, all_thread_ids)
            finally:
                _cleanup_sockets(client_socket, server_socket)


class TestUnsupportedPlatformHandling(unittest.TestCase):
    """Verify the error raised on platforms without remote-unwinding support."""

    @unittest.skipIf(
        sys.platform in ("linux", "darwin", "win32"),
        "Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_unsupported_platform_error(self):
        # Constructing an unwinder (even for our own pid) must fail fast
        # with a clear RuntimeError on unsupported platforms.
        with self.assertRaises(RuntimeError) as cm:
            RemoteUnwinder(os.getpid())

        self.assertIn(
            "Reading the PyRuntime section is not supported on this platform",
            str(cm.exception),
        )


class TestDetectionOfThreadStatus(RemoteInspectionTestBase):
    def _run_thread_status_test(self, mode, check_condition):
        """
        Common pattern for thread status detection tests.

        Args:
            mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.)
            check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool
        """
        port = find_unused_port()
        # Child script: one thread blocked in time.sleep (off CPU, no GIL)
        # and one thread spinning in a pure-Python loop (on CPU, holds GIL).
        # Each reports its native thread id over the socket before settling.
        # NOTE(review): the time.sleep(0.5) below sits after the infinite
        # busy loop (unreachable) as reconstructed here; placing it inside
        # the loop would defeat the ON_CPU/HAS_GIL assertions — confirm
        # against the original layout.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading
            import os

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def sleeper():
                tid = threading.get_native_id()
                sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode())
                time.sleep(10000)

            def busy():
                tid = threading.get_native_id()
                sock.sendall(f'ready:busy:{{tid}}\\n'.encode())
                x = 0
                while True:
                    x = x + 1
                time.sleep(0.5)

            t1 = threading.Thread(target=sleeper)
            t2 = threading.Thread(target=busy)
            t1.start()
            t2.start()
            sock.sendall(b'ready:main\\n')
            t1.join()
            t2.join()
            sock.close()
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(
                script_dir, "thread_status_script", script
            )
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for all ready signals and parse TIDs
                    response = _wait_for_signal(
                        client_socket,
                        [b"ready:main", b"ready:sleeper", b"ready:busy"],
                    )

                    sleeper_tid = None
                    busy_tid = None
                    # Messages look like b"ready:sleeper:<tid>"; a malformed
                    # line is ignored and caught by the assertions below.
                    for line in response.split(b"\n"):
                        if line.startswith(b"ready:sleeper:"):
                            try:
                                sleeper_tid = int(line.split(b":")[-1])
                            except (ValueError, IndexError):
                                pass
                        elif line.startswith(b"ready:busy:"):
                            try:
                                busy_tid = int(line.split(b":")[-1])
                            except (ValueError, IndexError):
                                pass

                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )

                    # Sample until we see expected thread states
                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=mode,
                            skip_non_matching_threads=False,
                        )
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)
                            if check_condition(
                                statuses, sleeper_tid, busy_tid
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Return the last sample; callers assert on its contents.
                    return statuses, sleeper_tid, busy_tid
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_detection(self):
        """CPU mode must flag the spinning thread as on-CPU and the
        sleeping thread as off-CPU."""

        def check_cpu_status(statuses, sleeper_tid, busy_tid):
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU)
                and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
            )

        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_CPU, check_cpu_status
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
            "Sleeper thread should be off CPU",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_ON_CPU,
            "Busy thread should be on CPU",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_gil_detection(self):
        """GIL mode must flag the spinning thread as the GIL holder and the
        sleeping thread as not holding the GIL."""

        def check_gil_status(statuses, sleeper_tid, busy_tid):
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL)
                and (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)
            )

        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_GIL, check_gil_status
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
            "Sleeper thread should not have GIL",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
            "Busy thread should have GIL",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_all_mode_detection(self):
        """PROFILING_MODE_ALL must report both CPU and GIL status bits for
        every thread in a single sample."""
        port = find_unused_port()
        # Child script: each worker connects on its own socket and reports
        # its native thread id, then either sleeps forever or spins on a
        # CPU-bound sum() loop.
        script = textwrap.dedent(
            f"""\
            import socket
            import threading
            import time
            import sys

            def sleeper_thread():
                conn = socket.create_connection(("localhost", {port}))
                conn.sendall(b"sleeper:" + str(threading.get_native_id()).encode())
                while True:
                    time.sleep(1)

            def busy_thread():
                conn = socket.create_connection(("localhost", {port}))
                conn.sendall(b"busy:" + str(threading.get_native_id()).encode())
                while True:
                    sum(range(100000))

            t1 = threading.Thread(target=sleeper_thread)
            t2 = threading.Thread(target=busy_thread)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """
        )
        with os_helper.temp_dir() as tmp_dir:
            script_file = make_script(tmp_dir, "script", script)
            # backlog=2: one pending connection per worker thread.
            server_socket = _create_server_socket(port, backlog=2)
            client_sockets = []
            try:
                with _managed_subprocess(
                    [sys.executable, script_file],
                ) as p:
                    sleeper_tid = None
                    busy_tid = None

                    # Receive thread IDs from the child process
                    for _ in range(2):
                        client_socket, _ = server_socket.accept()
                        client_sockets.append(client_socket)
                        line = client_socket.recv(1024)
                        if line:
                            if line.startswith(b"sleeper:"):
                                try:
                                    sleeper_tid = int(line.split(b":")[-1])
                                except (ValueError, IndexError):
                                    pass
                            elif line.startswith(b"busy:"):
                                try:
                                    busy_tid = int(line.split(b":")[-1])
                                except (ValueError, IndexError):
                                    pass

                    server_socket.close()
                    server_socket = None

                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=PROFILING_MODE_ALL,
                            skip_non_matching_threads=False,
                        )
                        # Retry until one sample shows the expected combined
                        # state for both threads, or give up after MAX_TRIES.
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)
                            # Check ALL mode provides both GIL and CPU info
                            if (
                                sleeper_tid in statuses
                                and busy_tid in statuses
                                and not (
                                    statuses[sleeper_tid]
                                    & THREAD_STATUS_ON_CPU
                                )
                                and not (
                                    statuses[sleeper_tid]
                                    & THREAD_STATUS_HAS_GIL
                                )
                                and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
                                and (
                                    statuses[busy_tid] & THREAD_STATUS_HAS_GIL
                                )
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )
                    self.assertIn(sleeper_tid, statuses)
                    self.assertIn(busy_tid, statuses)

                    # Sleeper: off CPU, no GIL
                    self.assertFalse(
                        statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
                        "Sleeper should be off CPU",
                    )
                    self.assertFalse(
                        statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
                        "Sleeper should not have GIL",
                    )
                    # Busy: on CPU, has GIL
                    self.assertTrue(
                        statuses[busy_tid] & THREAD_STATUS_ON_CPU,
                        "Busy should be on CPU",
                    )
                    self.assertTrue(
                        statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
                        "Busy should have GIL",
                    )
            finally:
                _cleanup_sockets(*client_sockets, server_socket)


class TestFrameCaching(RemoteInspectionTestBase):
    """Test that frame caching produces correct results.

    Uses socket-based synchronization for deterministic testing.
    All tests verify cache reuse via object identity checks (assertIs).
    """

    @contextmanager
    def _target_process(self, script_body):
        """Context manager for running a target process with socket sync.

        Yields (process, client_socket, make_unwinder) where make_unwinder
        builds a RemoteUnwinder for the child with cache_frames toggled.
        """
        port = find_unused_port()
        # The preamble is unindented so that the dedented script_body can be
        # appended verbatim; the child connects back before running it.
        script = f"""\
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
{textwrap.dedent(script_body)}
"""
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    def make_unwinder(cache_frames=True):
                        return RemoteUnwinder(
                            p.pid, all_threads=True, cache_frames=cache_frames
                        )

                    yield p, client_socket, make_unwinder
            except PermissionError:
                self.skipTest(
                    "Insufficient permissions to read the stack trace"
                )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    def _get_frames_with_retry(self, unwinder, required_funcs):
        """Get frames containing required_funcs, with retry for transient errors.

        Returns the frame list of the first thread whose stack contains all
        of required_funcs, or None after MAX_TRIES attempts.
        """
        for _ in range(MAX_TRIES):
            # OSError/RuntimeError can occur transiently while the target
            # is mid-operation; suppress and retry.
            with contextlib.suppress(OSError, RuntimeError):
                traces = unwinder.get_stack_trace()
                for interp in traces:
                    for thread in interp.threads:
                        funcs = {f.funcname for f in thread.frame_info}
                        if required_funcs.issubset(funcs):
                            return thread.frame_info
            time.sleep(0.1)
        return None

    def _sample_frames(
        self,
        client_socket,
        unwinder,
        wait_signal,
        send_ack,
        required_funcs,
        expected_frames=1,
    ):
        """Wait for signal, sample frames with retry until required funcs present, send ack."""
        _wait_for_signal(client_socket, wait_signal)
        frames = None
        for _ in range(MAX_TRIES):
            frames = self._get_frames_with_retry(unwinder, required_funcs)
            if frames and len(frames) >= expected_frames:
                break
            time.sleep(0.1)
        # Release the target even if sampling failed, so it cannot deadlock
        # waiting for the ack; callers assert on the returned frames.
        client_socket.sendall(send_ack)
        return frames

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_hit_same_stack(self):
        """Test that consecutive samples reuse cached parent frame objects.

        The current frame (index 0) is always re-read from memory to get
        updated line numbers, so it may be a different object. Parent frames
        (index 1+) should be identical objects from cache.
        """
        # The child blocks at three sync points inside level3, so the same
        # level1->level2->level3 stack is sampled three times.
        script_body = """\
        def level3():
            sock.sendall(b"sync1")
            sock.recv(16)
            sock.sendall(b"sync2")
            sock.recv(16)
            sock.sendall(b"sync3")
            sock.recv(16)

        def level2():
            level3()

        def level1():
            level2()

        level1()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)
            expected = {"level1", "level2", "level3"}

            frames1 = self._sample_frames(
                client_socket, unwinder, b"sync1", b"ack", expected
            )
            frames2 = self._sample_frames(
                client_socket, unwinder, b"sync2", b"ack", expected
            )
            frames3 = self._sample_frames(
                client_socket, unwinder, b"sync3", b"done", expected
            )

            self.assertIsNotNone(frames1)
            self.assertIsNotNone(frames2)
            self.assertIsNotNone(frames3)
            self.assertEqual(len(frames1), len(frames2))
            self.assertEqual(len(frames2), len(frames3))

            # Current frame (index 0) is always re-read, so check value equality
            self.assertEqual(frames1[0].funcname, frames2[0].funcname)
            self.assertEqual(frames2[0].funcname, frames3[0].funcname)

            # Parent frames (index 1+) must be identical objects (cache reuse)
            for i in range(1, len(frames1)):
                f1, f2, f3 = frames1[i], frames2[i], frames3[i]
                self.assertIs(
                    f1, f2, f"Frame {i}: samples 1-2 must be same object"
                )
                self.assertIs(
                    f2, f3, f"Frame {i}: samples 2-3 must be same object"
                )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_line_number_updates_in_same_frame(self):
        """Test that line numbers are correctly updated when execution moves
        within a function.
        When the profiler samples at different points within the same
        function, it must report the correct line number for each sample,
        not stale cached values.
        """
        # Each sync point sits on its own source line inside inner(), so the
        # four samples must see strictly increasing line numbers.
        script_body = """\
        def outer():
            inner()

        def inner():
            sock.sendall(b"line_a"); sock.recv(16)
            sock.sendall(b"line_b"); sock.recv(16)
            sock.sendall(b"line_c"); sock.recv(16)
            sock.sendall(b"line_d"); sock.recv(16)

        outer()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            frames_a = self._sample_frames(
                client_socket, unwinder, b"line_a", b"ack", {"inner"}
            )
            frames_b = self._sample_frames(
                client_socket, unwinder, b"line_b", b"ack", {"inner"}
            )
            frames_c = self._sample_frames(
                client_socket, unwinder, b"line_c", b"ack", {"inner"}
            )
            frames_d = self._sample_frames(
                client_socket, unwinder, b"line_d", b"done", {"inner"}
            )

            self.assertIsNotNone(frames_a)
            self.assertIsNotNone(frames_b)
            self.assertIsNotNone(frames_c)
            self.assertIsNotNone(frames_d)

            # Get the 'inner' frame from each sample (should be index 0)
            inner_a = frames_a[0]
            inner_b = frames_b[0]
            inner_c = frames_c[0]
            inner_d = frames_d[0]

            self.assertEqual(inner_a.funcname, "inner")
            self.assertEqual(inner_b.funcname, "inner")
            self.assertEqual(inner_c.funcname, "inner")
            self.assertEqual(inner_d.funcname, "inner")

            # Line numbers must be different and increasing (execution moves forward)
            self.assertLess(
                inner_a.lineno, inner_b.lineno, "Line B should be after line A"
            )
            self.assertLess(
                inner_b.lineno, inner_c.lineno, "Line C should be after line B"
            )
            self.assertLess(
                inner_c.lineno, inner_d.lineno, "Line D should be after line C"
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_invalidation_on_return(self):
        """Test cache invalidation when stack shrinks (function returns)."""
        # First sync point is inside inner(); second is after inner() has
        # returned, so the second sample must no longer contain it.
        script_body = """\
        def inner():
            sock.sendall(b"at_inner")
            sock.recv(16)

        def outer():
            inner()
            sock.sendall(b"at_outer")
            sock.recv(16)

        outer()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            frames_deep = self._sample_frames(
                client_socket,
                unwinder,
                b"at_inner",
                b"ack",
                {"inner", "outer"},
            )
            frames_shallow = self._sample_frames(
                client_socket, unwinder, b"at_outer", b"done", {"outer"}
            )

            self.assertIsNotNone(frames_deep)
            self.assertIsNotNone(frames_shallow)

            funcs_deep = [f.funcname for f in frames_deep]
            funcs_shallow = [f.funcname for f in frames_shallow]

            self.assertIn("inner", funcs_deep)
            self.assertIn("outer", funcs_deep)
            self.assertNotIn("inner", funcs_shallow)
            self.assertIn("outer", funcs_shallow)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_invalidation_on_call(self):
        """Test cache invalidation when stack grows (new function called)."""
        # First sample taken in middle(); second after middle() has called
        # deeper(), so the second stack must be one frame deeper.
        script_body = """\
        def deeper():
            sock.sendall(b"at_deeper")
            sock.recv(16)

        def middle():
            sock.sendall(b"at_middle")
            sock.recv(16)
            deeper()

        def top():
            middle()

        top()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            frames_before = self._sample_frames(
                client_socket,
                unwinder,
                b"at_middle",
                b"ack",
                {"middle", "top"},
            )
            frames_after = self._sample_frames(
                client_socket,
                unwinder,
                b"at_deeper",
                b"done",
                {"deeper", "middle", "top"},
            )

            self.assertIsNotNone(frames_before)
            self.assertIsNotNone(frames_after)

            funcs_before = [f.funcname for f in frames_before]
            funcs_after = [f.funcname for f in frames_after]

            self.assertIn("middle", funcs_before)
            self.assertIn("top", funcs_before)
            self.assertNotIn("deeper", funcs_before)

            self.assertIn("deeper", funcs_after)
            self.assertIn("middle", funcs_after)
            self.assertIn("top", funcs_after)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux 
with process_vm_readv support",
    )
    def test_partial_stack_reuse(self):
        """Test that unchanged bottom frames are reused when top changes (A→B→C to A→B→D)."""
        # func_b calls func_c then func_d; between the two samples only the
        # top of the stack changes, so A and B frames must come from cache.
        script_body = """\
        def func_c():
            sock.sendall(b"at_c")
            sock.recv(16)

        def func_d():
            sock.sendall(b"at_d")
            sock.recv(16)

        def func_b():
            func_c()
            func_d()

        def func_a():
            func_b()

        func_a()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            # Sample at C: stack is A→B→C
            frames_c = self._sample_frames(
                client_socket,
                unwinder,
                b"at_c",
                b"ack",
                {"func_a", "func_b", "func_c"},
            )
            # Sample at D: stack is A→B→D (C returned, D called)
            frames_d = self._sample_frames(
                client_socket,
                unwinder,
                b"at_d",
                b"done",
                {"func_a", "func_b", "func_d"},
            )

            self.assertIsNotNone(frames_c)
            self.assertIsNotNone(frames_d)

            # Find func_a and func_b frames in both samples
            def find_frame(frames, funcname):
                for f in frames:
                    if f.funcname == funcname:
                        return f
                return None

            frame_a_in_c = find_frame(frames_c, "func_a")
            frame_b_in_c = find_frame(frames_c, "func_b")
            frame_a_in_d = find_frame(frames_d, "func_a")
            frame_b_in_d = find_frame(frames_d, "func_b")

            self.assertIsNotNone(frame_a_in_c)
            self.assertIsNotNone(frame_b_in_c)
            self.assertIsNotNone(frame_a_in_d)
            self.assertIsNotNone(frame_b_in_d)

            # The bottom frames (A, B) should be the SAME objects (cache reuse)
            self.assertIs(
                frame_a_in_c,
                frame_a_in_d,
                "func_a frame should be reused from cache",
            )
            self.assertIs(
                frame_b_in_c,
                frame_b_in_d,
                "func_b frame should be reused from cache",
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_recursive_frames(self):
        """Test caching with same function appearing multiple times (recursion)."""
        # recurse(5) builds six stacked recurse frames, then blocks at two
        # consecutive sync points with an unchanged stack.
        script_body = """\
        def recurse(n):
            if n <= 0:
                sock.sendall(b"sync1")
                sock.recv(16)
                sock.sendall(b"sync2")
                sock.recv(16)
            else:
                recurse(n - 1)

        recurse(5)
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            frames1 = self._sample_frames(
                client_socket, unwinder, b"sync1", b"ack", {"recurse"}
            )
            frames2 = self._sample_frames(
                client_socket, unwinder, b"sync2", b"done", {"recurse"}
            )

            self.assertIsNotNone(frames1)
            self.assertIsNotNone(frames2)

            # Should have multiple "recurse" frames (6 total: recurse(5) down to recurse(0))
            recurse_count = sum(
                1 for f in frames1 if f.funcname == "recurse"
            )
            self.assertEqual(recurse_count, 6, "Should have 6 recursive frames")
            self.assertEqual(len(frames1), len(frames2))

            # Current frame (index 0) is re-read, check value equality
            self.assertEqual(frames1[0].funcname, frames2[0].funcname)

            # Parent frames (index 1+) should be identical objects (cache reuse)
            for i in range(1, len(frames1)):
                self.assertIs(
                    frames1[i],
                    frames2[i],
                    f"Frame {i}: recursive frames must be same object",
                )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_vs_no_cache_equivalence(self):
        """Test that cache_frames=True and cache_frames=False produce equivalent results."""
        script_body = """\
        def level3():
            sock.sendall(b"ready"); sock.recv(16)

        def level2():
            level3()

        def level1():
            level2()

        level1()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            _wait_for_signal(client_socket, b"ready")

            # Sample with cache
            unwinder_cache = make_unwinder(cache_frames=True)
            frames_cached = self._get_frames_with_retry(
                unwinder_cache, {"level1", "level2", "level3"}
            )

            # Sample without cache
            unwinder_no_cache = make_unwinder(cache_frames=False)
            frames_no_cache = self._get_frames_with_retry(
                unwinder_no_cache, {"level1", "level2", "level3"}
            )

            client_socket.sendall(b"done")

            self.assertIsNotNone(frames_cached)
            self.assertIsNotNone(frames_no_cache)

            # Same number of frames
            self.assertEqual(len(frames_cached), 
len(frames_no_cache))

            # Same function names in same order
            funcs_cached = [f.funcname for f in frames_cached]
            funcs_no_cache = [f.funcname for f in frames_no_cache]
            self.assertEqual(funcs_cached, funcs_no_cache)

            # Same line numbers
            lines_cached = [f.lineno for f in frames_cached]
            lines_no_cache = [f.lineno for f in frames_no_cache]
            self.assertEqual(lines_cached, lines_no_cache)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_per_thread_isolation(self):
        """Test that frame cache is per-thread and cache invalidation works independently."""
        # Two child threads run mirrored call trees (foo1->bar1->baz1 then
        # foo1->blech1, likewise for *2). A lock serializes their sync
        # messages so the parent handles one sync point at a time.
        script_body = """\
        import threading

        lock = threading.Lock()

        def sync(msg):
            with lock:
                sock.sendall(msg + b"\\n")
                sock.recv(1)

        # Thread 1 functions
        def baz1():
            sync(b"t1:baz1")

        def bar1():
            baz1()

        def blech1():
            sync(b"t1:blech1")

        def foo1():
            bar1()  # Goes down to baz1, syncs
            blech1()  # Returns up, goes down to blech1, syncs

        # Thread 2 functions
        def baz2():
            sync(b"t2:baz2")

        def bar2():
            baz2()

        def blech2():
            sync(b"t2:blech2")

        def foo2():
            bar2()  # Goes down to baz2, syncs
            blech2()  # Returns up, goes down to blech2, syncs

        t1 = threading.Thread(target=foo1)
        t2 = threading.Thread(target=foo2)
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            # Message dispatch table: signal -> required functions for that thread
            dispatch = {
                b"t1:baz1": {"baz1", "bar1", "foo1"},
                b"t2:baz2": {"baz2", "bar2", "foo2"},
                b"t1:blech1": {"blech1", "foo1"},
                b"t2:blech2": {"blech2", "foo2"},
            }

            # Track results for each sync point
            results = {}

            # Process 4 sync points (order depends on thread scheduling)
            buffer = _wait_for_signal(client_socket, b"\n")
            for i in range(4):
                # Extract first message from buffer
                msg, sep, buffer = buffer.partition(b"\n")
                self.assertIn(msg, dispatch, f"Unexpected message: {msg!r}")

                # Sample frames for the thread at this sync point
                required_funcs = dispatch[msg]
                frames = self._get_frames_with_retry(unwinder, required_funcs)
                self.assertIsNotNone(frames, f"Thread not found for {msg!r}")
                results[msg] = [f.funcname for f in frames]

                # Release thread and wait for next message (if not last)
                client_socket.sendall(b"k")
                if i < 3:
                    buffer += _wait_for_signal(client_socket, b"\n")

            # Validate Phase 1: baz snapshots
            t1_baz = results.get(b"t1:baz1")
            t2_baz = results.get(b"t2:baz2")
            self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot")
            self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot")

            # Thread 1 at baz1: should have foo1->bar1->baz1
            self.assertIn("baz1", t1_baz)
            self.assertIn("bar1", t1_baz)
            self.assertIn("foo1", t1_baz)
            self.assertNotIn("blech1", t1_baz)
            # No cross-contamination
            self.assertNotIn("baz2", t1_baz)
            self.assertNotIn("bar2", t1_baz)
            self.assertNotIn("foo2", t1_baz)

            # Thread 2 at baz2: should have foo2->bar2->baz2
            self.assertIn("baz2", t2_baz)
            self.assertIn("bar2", t2_baz)
            self.assertIn("foo2", t2_baz)
            self.assertNotIn("blech2", t2_baz)
            # No cross-contamination
            self.assertNotIn("baz1", t2_baz)
            self.assertNotIn("bar1", t2_baz)
            self.assertNotIn("foo1", t2_baz)

            # Validate Phase 2: blech snapshots (cache invalidation test)
            t1_blech = results.get(b"t1:blech1")
            t2_blech = results.get(b"t2:blech2")
            self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot")
            self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot")

            # Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated)
            self.assertIn("blech1", t1_blech)
            self.assertIn("foo1", t1_blech)
            self.assertNotIn(
                "bar1", t1_blech, "Cache not invalidated: bar1 still present"
            )
            self.assertNotIn(
                "baz1", t1_blech, "Cache not invalidated: baz1 still present"
            )
            # No cross-contamination
            self.assertNotIn("blech2", t1_blech)

            # Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated)
            self.assertIn("blech2", t2_blech)
            self.assertIn("foo2", t2_blech)
            self.assertNotIn(
                "bar2", t2_blech, "Cache not invalidated: bar2 still present"
            )
            self.assertNotIn(
                "baz2", t2_blech, "Cache not invalidated: baz2 still present"
            )
            # No cross-contamination
            self.assertNotIn("blech1", t2_blech)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_new_unwinder_with_stale_last_profiled_frame(self):
        """Test that a new unwinder returns complete stack when cache lookup misses."""
        script_body = """\
        def level4():
            sock.sendall(b"sync1")
            sock.recv(16)
            sock.sendall(b"sync2")
            sock.recv(16)

        def level3():
            level4()

        def level2():
            level3()

        def level1():
            level2()

        level1()
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            expected = {"level1", "level2", "level3", "level4"}

            # First unwinder samples - this sets last_profiled_frame in target
            unwinder1 = make_unwinder(cache_frames=True)
            frames1 = self._sample_frames(
                client_socket, unwinder1, b"sync1", b"ack", expected
            )

            # Create NEW unwinder (empty cache) and sample
            # The target still has last_profiled_frame set from unwinder1
            unwinder2 = make_unwinder(cache_frames=True)
            frames2 = self._sample_frames(
                client_socket, unwinder2, b"sync2", b"done", expected
            )

            self.assertIsNotNone(frames1)
            self.assertIsNotNone(frames2)

            funcs1 = [f.funcname for f in frames1]
            funcs2 = [f.funcname for f in frames2]

            # Both should have all levels
            for level in ["level1", "level2", "level3", "level4"]:
                self.assertIn(level, funcs1, f"{level} missing from first sample")
                self.assertIn(level, funcs2, f"{level} missing from second sample")

            # Should have same stack depth
            self.assertEqual(
                len(frames1),
                len(frames2),
                "New unwinder should return complete stack despite stale last_profiled_frame",
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_exhaustion(self):
        """Test cache works when frame limit (1024) is exceeded.
        FRAME_CACHE_MAX_FRAMES=1024. With 1100 recursive frames, the cache
        can't store all of them but should still work.
        """
        # Use 1100 to exceed FRAME_CACHE_MAX_FRAMES=1024
        depth = 1100
        script_body = f"""\
        import sys
        sys.setrecursionlimit(2000)

        def recurse(n):
            if n <= 0:
                sock.sendall(b"ready")
                sock.recv(16)  # wait for ack
                sock.sendall(b"ready2")
                sock.recv(16)  # wait for done
                return
            recurse(n - 1)

        recurse({depth})
        """
        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder_cache = make_unwinder(cache_frames=True)
            unwinder_no_cache = make_unwinder(cache_frames=False)

            # expected_frames=1102: the 1101 recurse frames plus the module
            # frame, so sampling retries until the full stack is visible.
            frames_cached = self._sample_frames(
                client_socket,
                unwinder_cache,
                b"ready",
                b"ack",
                {"recurse"},
                expected_frames=1102,
            )

            # Sample again with no cache for comparison
            frames_no_cache = self._sample_frames(
                client_socket,
                unwinder_no_cache,
                b"ready2",
                b"done",
                {"recurse"},
                expected_frames=1102,
            )

            self.assertIsNotNone(frames_cached)
            self.assertIsNotNone(frames_no_cache)

            # Both should have many recurse frames (> 1024 limit)
            cached_count = [f.funcname for f in frames_cached].count(
                "recurse"
            )
            no_cache_count = [f.funcname for f in frames_no_cache].count(
                "recurse"
            )

            self.assertGreater(
                cached_count, 1000, "Should have >1000 recurse frames"
            )
            self.assertGreater(
                no_cache_count, 1000, "Should have >1000 recurse frames"
            )

            # Both modes should produce same frame count
            self.assertEqual(
                len(frames_cached),
                len(frames_no_cache),
                "Cache exhaustion should not affect stack completeness",
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_get_stats(self):
        """Test that get_stats() returns statistics when stats=True."""
        script_body = """\
        sock.sendall(b"ready")
        sock.recv(16)
        """
        with self._target_process(script_body) as (p, client_socket, _):
            unwinder = RemoteUnwinder(p.pid, all_threads=True, stats=True)
            _wait_for_signal(client_socket, b"ready")

            # Take a sample
            unwinder.get_stack_trace()

            stats = unwinder.get_stats()
            client_socket.sendall(b"done")

            # Verify expected keys exist
            expected_keys = [
                "total_samples",
                "frame_cache_hits",
                "frame_cache_misses",
                "frame_cache_partial_hits",
                "frames_read_from_cache",
                "frames_read_from_memory",
                "frame_cache_hit_rate",
            ]
            for key in expected_keys:
                self.assertIn(key, stats)

            # Exactly one get_stack_trace() call was made above.
            self.assertEqual(stats["total_samples"], 1)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_get_stats_disabled_raises(self):
        """Test that get_stats() raises RuntimeError when stats=False."""
        script_body = """\
        sock.sendall(b"ready")
        sock.recv(16)
        """
        with self._target_process(script_body) as (p, client_socket, _):
            unwinder = RemoteUnwinder(
                p.pid, all_threads=True
            )  # stats=False by default
            _wait_for_signal(client_socket, b"ready")

            with self.assertRaises(RuntimeError):
                unwinder.get_stats()

            client_socket.sendall(b"done")


if __name__ == "__main__":
    unittest.main()