import contextlib import unittest import os import textwrap import importlib import sys import socket import threading import time from contextlib import contextmanager from asyncio import staggered, taskgroups, base_events, tasks from unittest.mock import ANY from test.support import ( os_helper, SHORT_TIMEOUT, busy_retry, requires_gil_enabled, ) from test.support.script_helper import make_script from test.support.socket_helper import find_unused_port import subprocess # Profiling mode constants PROFILING_MODE_WALL = 0 PROFILING_MODE_CPU = 1 PROFILING_MODE_GIL = 2 PROFILING_MODE_ALL = 3 # Thread status flags THREAD_STATUS_HAS_GIL = 1 << 0 THREAD_STATUS_ON_CPU = 1 << 1 THREAD_STATUS_UNKNOWN = 1 << 2 # Maximum number of retry attempts for operations that may fail transiently MAX_TRIES = 10 try: from concurrent import interpreters except ImportError: interpreters = None PROCESS_VM_READV_SUPPORTED = False try: from _remote_debugging import PROCESS_VM_READV_SUPPORTED from _remote_debugging import RemoteUnwinder from _remote_debugging import FrameInfo, CoroInfo, TaskInfo except ImportError: raise unittest.SkipTest( "Test only runs when _remote_debugging is available" ) # ============================================================================ # Module-level helper functions # ============================================================================ def _make_test_script(script_dir, script_basename, source): to_return = make_script(script_dir, script_basename, source) importlib.invalidate_caches() return to_return def _create_server_socket(port, backlog=1): """Create and configure a server socket for test communication.""" server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(backlog) return server_socket def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT): """ Wait for expected signal(s) from a socket with proper timeout and EOF handling. Args: sock: Connected socket to read from expected_signals: Single bytes object or list of bytes objects to wait for timeout: Socket timeout in seconds Returns: bytes: Complete accumulated response buffer Raises: RuntimeError: If connection closed before signal received or timeout """ if isinstance(expected_signals, bytes): expected_signals = [expected_signals] sock.settimeout(timeout) buffer = b"" while True: # Check if all expected signals are in buffer if all(sig in buffer for sig in expected_signals): return buffer try: chunk = sock.recv(4096) if not chunk: # EOF - connection closed raise RuntimeError( f"Connection closed before receiving expected signals. " f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" ) buffer += chunk except socket.timeout: raise RuntimeError( f"Timeout waiting for signals. " f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" ) def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT): """ Wait for N occurrences of a signal pattern. Args: sock: Connected socket to read from signal_pattern: bytes pattern to count (e.g., b"ready") count: Number of occurrences expected timeout: Socket timeout in seconds Returns: bytes: Complete accumulated response buffer Raises: RuntimeError: If connection closed or timeout before receiving all signals """ sock.settimeout(timeout) buffer = b"" found_count = 0 while found_count < count: try: chunk = sock.recv(4096) if not chunk: raise RuntimeError( f"Connection closed after {found_count}/{count} signals. " f"Last 200 bytes: {buffer[-200:]!r}" ) buffer += chunk # Count occurrences in entire buffer found_count = buffer.count(signal_pattern) except socket.timeout: raise RuntimeError( f"Timeout waiting for {count} signals (found {found_count}). " f"Last 200 bytes: {buffer[-200:]!r}" ) return buffer @contextmanager def _managed_subprocess(args, timeout=SHORT_TIMEOUT): """ Context manager for subprocess lifecycle management. Ensures process is properly terminated and cleaned up even on exceptions. Uses graceful termination first, then forceful kill if needed. """ p = subprocess.Popen(args) try: yield p finally: try: p.terminate() try: p.wait(timeout=timeout) except subprocess.TimeoutExpired: p.kill() try: p.wait(timeout=timeout) except subprocess.TimeoutExpired: pass # Process refuses to die, nothing more we can do except OSError: pass # Process already dead def _cleanup_sockets(*sockets): """Safely close multiple sockets, ignoring errors.""" for sock in sockets: if sock is not None: try: sock.close() except OSError: pass # ============================================================================ # Decorators and skip conditions # ============================================================================ skip_if_not_supported = unittest.skipIf( ( sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32" ), "Test only runs on Linux, Windows and MacOS", ) def requires_subinterpreters(meth): """Decorator to skip a test if subinterpreters are not supported.""" return unittest.skipIf(interpreters is None, "subinterpreters required")( meth ) # ============================================================================ # Simple wrapper functions for RemoteUnwinder # ============================================================================ # Errors that can occur transiently when reading process memory without synchronization RETRIABLE_ERRORS = ( "Task list appears corrupted", "Invalid linked list structure reading remote memory", "Unknown error reading memory", "Unhandled frame owner", "Failed to parse initial frame", "Failed to process frame chain", "Failed to unwind stack", ) def _is_retriable_error(exc): """Check if an exception is a transient error that should be retried.""" msg = str(exc) return any(msg.startswith(err) or err in msg for err in RETRIABLE_ERRORS) def get_stack_trace(pid): for _ in busy_retry(SHORT_TIMEOUT): try: unwinder = RemoteUnwinder(pid, all_threads=True, debug=True) return unwinder.get_stack_trace() except RuntimeError as e: if _is_retriable_error(e): continue raise raise RuntimeError("Failed to get stack trace after retries") def get_async_stack_trace(pid): for _ in busy_retry(SHORT_TIMEOUT): try: unwinder = RemoteUnwinder(pid, debug=True) return unwinder.get_async_stack_trace() except RuntimeError as e: if _is_retriable_error(e): continue raise raise RuntimeError("Failed to get async stack trace after retries") def get_all_awaited_by(pid): for _ in busy_retry(SHORT_TIMEOUT): try: unwinder = RemoteUnwinder(pid, debug=True) return unwinder.get_all_awaited_by() except RuntimeError as e: if _is_retriable_error(e): continue raise raise RuntimeError("Failed to get all awaited_by after retries") # ============================================================================ # Base test class with shared infrastructure # ============================================================================ class RemoteInspectionTestBase(unittest.TestCase): """Base class for remote inspection tests with common helpers.""" maxDiff = None def _run_script_and_get_trace( self, script, trace_func, wait_for_signals=None, port=None, backlog=1, ): """ Common pattern: run a script, wait for signals, get trace. Args: script: Script content (will be formatted with port if {port} present) trace_func: Function to call with pid to get trace (e.g., get_stack_trace) wait_for_signals: Signal(s) to wait for before getting trace port: Port to use (auto-selected if None) backlog: Socket listen backlog Returns: tuple: (trace_result, script_name) """ if port is None: port = find_unused_port() # Format script with port if needed if "{port}" in script or "{{port}}" in script: script = script.replace("{{port}}", "{port}").format(port=port) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port, backlog) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None if wait_for_signals: _wait_for_signal(client_socket, wait_for_signals) try: trace = trace_func(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) return trace, script_name finally: _cleanup_sockets(client_socket, server_socket) def _find_frame_in_trace(self, stack_trace, predicate): """ Find a frame matching predicate in stack trace. Args: stack_trace: List of InterpreterInfo objects predicate: Function(frame) -> bool Returns: FrameInfo or None """ for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: for frame in thread_info.frame_info: if predicate(frame): return frame return None def _find_thread_by_id(self, stack_trace, thread_id): """Find a thread by its native thread ID.""" for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: if thread_info.thread_id == thread_id: return thread_info return None def _find_thread_with_frame(self, stack_trace, frame_predicate): """Find a thread containing a frame matching predicate.""" for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: for frame in thread_info.frame_info: if frame_predicate(frame): return thread_info return None def _get_thread_statuses(self, stack_trace): """Extract thread_id -> status mapping from stack trace.""" statuses = {} for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: statuses[thread_info.thread_id] = thread_info.status return statuses def _get_task_id_map(self, stack_trace): """Create task_id -> task mapping from async stack trace.""" return {task.task_id: task for task in stack_trace[0].awaited_by} def _get_awaited_by_relationships(self, stack_trace): """Extract task name to awaited_by set mapping.""" id_to_task = self._get_task_id_map(stack_trace) return { task.task_name: set( id_to_task[awaited.task_name].task_name for awaited in task.awaited_by ) for task in stack_trace[0].awaited_by } def _extract_coroutine_stacks(self, stack_trace): """Extract and format coroutine stacks from tasks.""" return { task.task_name: sorted( tuple(tuple(frame) for frame in coro.call_stack) for coro in task.coroutine_stack ) for task in stack_trace[0].awaited_by } # ============================================================================ # Test classes # ============================================================================ class TestGetStackTrace(RemoteInspectionTestBase): @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_remote_stack_trace(self): port = find_unused_port() script = textwrap.dedent( f"""\ import time, sys, socket, threading sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def bar(): for x in range(100): if x == 50: baz() def baz(): foo() def foo(): sock.sendall(b"ready:thread\\n"); time.sleep(10_000) t = threading.Thread(target=bar) t.start() sock.sendall(b"ready:main\\n"); t.join() """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None _wait_for_signal( client_socket, [b"ready:main", b"ready:thread"] ) try: stack_trace = get_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) thread_expected_stack_trace = [ FrameInfo([script_name, 15, "foo"]), FrameInfo([script_name, 12, "baz"]), FrameInfo([script_name, 9, "bar"]), FrameInfo([threading.__file__, ANY, "Thread.run"]), FrameInfo( [ threading.__file__, ANY, "Thread._bootstrap_inner", ] ), FrameInfo( [threading.__file__, ANY, "Thread._bootstrap"] ), ] # Find expected thread stack found_thread = self._find_thread_with_frame( stack_trace, lambda f: f.funcname == "foo" and f.lineno == 15, ) self.assertIsNotNone( found_thread, "Expected thread stack trace not found" ) self.assertEqual( found_thread.frame_info, thread_expected_stack_trace ) # Check main thread main_frame = FrameInfo([script_name, 19, ""]) found_main = self._find_frame_in_trace( stack_trace, lambda f: f == main_frame ) self.assertIsNotNone( found_main, "Main thread stack trace not found" ) finally: _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_async_remote_stack_trace(self): port = find_unused_port() script = textwrap.dedent( f"""\ import asyncio import time import sys import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def c5(): sock.sendall(b"ready"); time.sleep(10_000) async def c4(): await asyncio.sleep(0) c5() async def c3(): await c4() async def c2(): await c3() async def c1(task): await task async def main(): async with asyncio.TaskGroup() as tg: task = tg.create_task(c2(), name="c2_root") tg.create_task(c1(task), name="sub_main_1") tg.create_task(c1(task), name="sub_main_2") def new_eager_loop(): loop = asyncio.new_event_loop() eager_task_factory = asyncio.create_eager_task_factory( asyncio.Task) loop.set_task_factory(eager_task_factory) return loop asyncio.run(main(), loop_factory={{TASK_FACTORY}}) """ ) for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop": with ( self.subTest(task_factory_variant=task_factory_variant), os_helper.temp_dir() as work_dir, ): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script( script_dir, "script", script.format(TASK_FACTORY=task_factory_variant), ) client_socket = None try: with _managed_subprocess( [sys.executable, script_name] ) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None response = _wait_for_signal(client_socket, b"ready") self.assertIn(b"ready", response) try: stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Check all tasks are present tasks_names = [ task.task_name for task in stack_trace[0].awaited_by ] for task_name in [ "c2_root", "sub_main_1", "sub_main_2", ]: self.assertIn(task_name, tasks_names) # Check awaited_by relationships relationships = self._get_awaited_by_relationships( stack_trace ) self.assertEqual( relationships, { "c2_root": { "Task-1", "sub_main_1", "sub_main_2", }, "Task-1": set(), "sub_main_1": {"Task-1"}, "sub_main_2": {"Task-1"}, }, ) # Check coroutine stacks coroutine_stacks = self._extract_coroutine_stacks( stack_trace ) self.assertEqual( coroutine_stacks, { "Task-1": [ ( tuple( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), tuple( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), tuple([script_name, 26, "main"]), ) ], "c2_root": [ ( tuple([script_name, 10, "c5"]), tuple([script_name, 14, "c4"]), tuple([script_name, 17, "c3"]), tuple([script_name, 20, "c2"]), ) ], "sub_main_1": [ (tuple([script_name, 23, "c1"]),) ], "sub_main_2": [ (tuple([script_name, 23, "c1"]),) ], }, ) # Check awaited_by coroutine stacks id_to_task = self._get_task_id_map(stack_trace) awaited_by_coroutine_stacks = { task.task_name: sorted( ( id_to_task[coro.task_name].task_name, tuple( tuple(frame) for frame in coro.call_stack ), ) for coro in task.awaited_by ) for task in stack_trace[0].awaited_by } self.assertEqual( awaited_by_coroutine_stacks, { "Task-1": [], "c2_root": [ ( "Task-1", ( tuple( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), tuple( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), tuple([script_name, 26, "main"]), ), ), ( "sub_main_1", (tuple([script_name, 23, "c1"]),), ), ( "sub_main_2", (tuple([script_name, 23, "c1"]),), ), ], "sub_main_1": [ ( "Task-1", ( tuple( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), tuple( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), tuple([script_name, 26, "main"]), ), ) ], "sub_main_2": [ ( "Task-1", ( tuple( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), tuple( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), tuple([script_name, 26, "main"]), ), ) ], }, ) finally: _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_asyncgen_remote_stack_trace(self): port = find_unused_port() script = textwrap.dedent( f"""\ import asyncio import time import sys import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) async def gen_nested_call(): sock.sendall(b"ready"); time.sleep(10_000) async def gen(): for num in range(2): yield num if num == 1: await gen_nested_call() async def main(): async for el in gen(): pass asyncio.run(main()) """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None response = _wait_for_signal(client_socket, b"ready") self.assertIn(b"ready", response) try: stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # For this simple asyncgen test, we only expect one task self.assertEqual(len(stack_trace[0].awaited_by), 1) task = stack_trace[0].awaited_by[0] self.assertEqual(task.task_name, "Task-1") # Check the coroutine stack coroutine_stack = sorted( tuple(tuple(frame) for frame in coro.call_stack) for coro in task.coroutine_stack ) self.assertEqual( coroutine_stack, [ ( tuple([script_name, 10, "gen_nested_call"]), tuple([script_name, 16, "gen"]), tuple([script_name, 19, "main"]), ) ], ) # No awaited_by relationships expected self.assertEqual(task.awaited_by, []) finally: _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_async_gather_remote_stack_trace(self): port = find_unused_port() script = textwrap.dedent( f"""\ import asyncio import time import sys import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) async def deep(): await asyncio.sleep(0) sock.sendall(b"ready"); time.sleep(10_000) async def c1(): await asyncio.sleep(0) await deep() async def c2(): await asyncio.sleep(0) async def main(): await asyncio.gather(c1(), c2()) asyncio.run(main()) """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None response = _wait_for_signal(client_socket, b"ready") self.assertIn(b"ready", response) try: stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Check all tasks are present tasks_names = [ task.task_name for task in stack_trace[0].awaited_by ] for task_name in ["Task-1", "Task-2"]: self.assertIn(task_name, tasks_names) # Check awaited_by relationships relationships = self._get_awaited_by_relationships( stack_trace ) self.assertEqual( relationships, { "Task-1": set(), "Task-2": {"Task-1"}, }, ) # Check coroutine stacks coroutine_stacks = self._extract_coroutine_stacks( stack_trace ) self.assertEqual( coroutine_stacks, { "Task-1": [(tuple([script_name, 21, "main"]),)], "Task-2": [ ( tuple([script_name, 11, "deep"]), tuple([script_name, 15, "c1"]), ) ], }, ) # Check awaited_by coroutine stacks id_to_task = self._get_task_id_map(stack_trace) awaited_by_coroutine_stacks = { task.task_name: sorted( ( id_to_task[coro.task_name].task_name, tuple( tuple(frame) for frame in coro.call_stack ), ) for coro in task.awaited_by ) for task in stack_trace[0].awaited_by } self.assertEqual( awaited_by_coroutine_stacks, { "Task-1": [], "Task-2": [ ("Task-1", (tuple([script_name, 21, "main"]),)) ], }, ) finally: _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_async_staggered_race_remote_stack_trace(self): port = find_unused_port() script = textwrap.dedent( f"""\ import asyncio.staggered import time import sys import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) async def deep(): await asyncio.sleep(0) sock.sendall(b"ready"); time.sleep(10_000) async def c1(): await asyncio.sleep(0) await deep() async def c2(): await asyncio.sleep(10_000) async def main(): await asyncio.staggered.staggered_race( [c1, c2], delay=None, ) asyncio.run(main()) """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None response = _wait_for_signal(client_socket, b"ready") self.assertIn(b"ready", response) try: stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Check all tasks are present tasks_names = [ task.task_name for task in stack_trace[0].awaited_by ] for task_name in ["Task-1", "Task-2"]: self.assertIn(task_name, tasks_names) # Check awaited_by relationships relationships = self._get_awaited_by_relationships( stack_trace ) self.assertEqual( relationships, { "Task-1": set(), "Task-2": {"Task-1"}, }, ) # Check coroutine stacks coroutine_stacks = self._extract_coroutine_stacks( stack_trace ) self.assertEqual( coroutine_stacks, { "Task-1": [ ( tuple( [ staggered.__file__, ANY, "staggered_race", ] ), tuple([script_name, 21, "main"]), ) ], "Task-2": [ ( tuple([script_name, 11, "deep"]), tuple([script_name, 15, "c1"]), tuple( [ staggered.__file__, ANY, "staggered_race..run_one_coro", ] ), ) ], }, ) # Check awaited_by coroutine stacks id_to_task = self._get_task_id_map(stack_trace) awaited_by_coroutine_stacks = { task.task_name: sorted( ( id_to_task[coro.task_name].task_name, tuple( tuple(frame) for frame in coro.call_stack ), ) for coro in task.awaited_by ) for task in stack_trace[0].awaited_by } self.assertEqual( awaited_by_coroutine_stacks, { "Task-1": [], "Task-2": [ ( "Task-1", ( tuple( [ staggered.__file__, ANY, "staggered_race", ] ), tuple([script_name, 21, "main"]), ), ) ], }, ) finally: _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_async_global_awaited_by(self): # Reduced from 1000 to 100 to avoid file descriptor exhaustion # when running tests in parallel (e.g., -j 20) NUM_TASKS = 100 port = find_unused_port() script = textwrap.dedent( f"""\ import asyncio import os import random import sys import socket from string import ascii_lowercase, digits from test.support import socket_helper, SHORT_TIMEOUT HOST = '127.0.0.1' PORT = socket_helper.find_unused_port() connections = 0 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) class EchoServerProtocol(asyncio.Protocol): def connection_made(self, transport): global connections connections += 1 self.transport = transport def data_received(self, data): self.transport.write(data) self.transport.close() async def echo_client(message): reader, writer = await asyncio.open_connection(HOST, PORT) writer.write(message.encode()) await writer.drain() data = await reader.read(100) assert message == data.decode() writer.close() await writer.wait_closed() sock.sendall(b"ready") await asyncio.sleep(SHORT_TIMEOUT) async def echo_client_spam(server): async with asyncio.TaskGroup() as tg: while connections < {NUM_TASKS}: msg = list(ascii_lowercase + digits) random.shuffle(msg) tg.create_task(echo_client("".join(msg))) await asyncio.sleep(0) server.close() await server.wait_closed() async def main(): loop = asyncio.get_running_loop() server = await loop.create_server(EchoServerProtocol, HOST, PORT) async with server: async with asyncio.TaskGroup() as tg: tg.create_task(server.serve_forever(), name="server task") tg.create_task(echo_client_spam(server), name="echo client spam") asyncio.run(main()) """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None # Wait for NUM_TASKS "ready" signals try: _wait_for_n_signals(client_socket, b"ready", NUM_TASKS) except RuntimeError as e: self.fail(str(e)) try: all_awaited_by = get_all_awaited_by(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Expected: a list of two elements: 1 thread, 1 interp self.assertEqual(len(all_awaited_by), 2) # Expected: a tuple with the thread ID and the awaited_by list self.assertEqual(len(all_awaited_by[0]), 2) # Expected: no tasks in the fallback per-interp task list self.assertEqual(all_awaited_by[1], (0, [])) entries = all_awaited_by[0][1] # Expected: at least NUM_TASKS pending tasks self.assertGreaterEqual(len(entries), NUM_TASKS) # Check the main task structure main_stack = [ FrameInfo( [taskgroups.__file__, ANY, "TaskGroup._aexit"] ), FrameInfo( [taskgroups.__file__, ANY, "TaskGroup.__aexit__"] ), FrameInfo([script_name, 52, "main"]), ] self.assertIn( TaskInfo( [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []] ), entries, ) self.assertIn( TaskInfo( [ ANY, "server task", [ CoroInfo( [ [ FrameInfo( [ base_events.__file__, ANY, "Server.serve_forever", ] ) ], ANY, ] ) ], [ CoroInfo( [ [ FrameInfo( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), FrameInfo( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), FrameInfo( [script_name, ANY, "main"] ), ], ANY, ] ) ], ] ), entries, ) self.assertIn( TaskInfo( [ ANY, "Task-4", [ CoroInfo( [ [ FrameInfo( [ tasks.__file__, ANY, "sleep", ] ), FrameInfo( [ script_name, 36, "echo_client", ] ), ], ANY, ] ) ], [ CoroInfo( [ [ FrameInfo( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), FrameInfo( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), FrameInfo( [ script_name, 39, "echo_client_spam", ] ), ], ANY, ] ) ], ] ), entries, ) expected_awaited_by = [ CoroInfo( [ [ FrameInfo( [ taskgroups.__file__, ANY, "TaskGroup._aexit", ] ), FrameInfo( [ taskgroups.__file__, ANY, "TaskGroup.__aexit__", ] ), FrameInfo( [script_name, 39, "echo_client_spam"] ), ], ANY, ] ) ] tasks_with_awaited = [ task for task in entries if task.awaited_by == expected_awaited_by ] self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS) # Final task should be from echo client spam (not on Windows) if sys.platform != "win32": self.assertEqual( tasks_with_awaited[-1].awaited_by, entries[-1].awaited_by, ) finally: _cleanup_sockets(client_socket, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_self_trace(self): stack_trace = get_stack_trace(os.getpid()) this_thread_stack = None for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: if thread_info.thread_id == threading.get_native_id(): this_thread_stack = thread_info.frame_info break if this_thread_stack: break self.assertIsNotNone(this_thread_stack) self.assertEqual( this_thread_stack[:2], [ FrameInfo( [ __file__, get_stack_trace.__code__.co_firstlineno + 4, "get_stack_trace", ] ), FrameInfo( [ __file__, self.test_self_trace.__code__.co_firstlineno + 6, "TestGetStackTrace.test_self_trace", ] ), ], ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) @requires_subinterpreters def test_subinterpreter_stack_trace(self): port = find_unused_port() import pickle subinterp_code = textwrap.dedent(f""" import socket import time def sub_worker(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub\\n") time.sleep(10_000) nested_func() sub_worker() """).strip() pickled_code = pickle.dumps(subinterp_code) script = textwrap.dedent( f""" from concurrent import interpreters import time import sys import socket import threading sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def main_worker(): sock.sendall(b"ready:main\\n") time.sleep(10_000) def run_subinterp(): subinterp = interpreters.create() import pickle pickled_code = {pickled_code!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) sub_thread = threading.Thread(target=run_subinterp) sub_thread.start() main_thread = threading.Thread(target=main_worker) main_thread.start() main_thread.join() sub_thread.join() """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_sockets = [] try: with _managed_subprocess([sys.executable, script_name]) as p: # Accept connections from both main and subinterpreter responses = set() while len(responses) < 2: try: client_socket, _ = server_socket.accept() client_sockets.append(client_socket) response = client_socket.recv(1024) if b"ready:main" in response: responses.add("main") if b"ready:sub" in response: responses.add("sub") except socket.timeout: break server_socket.close() server_socket = None try: stack_trace = get_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Verify we have at least one interpreter self.assertGreaterEqual(len(stack_trace), 1) # Look for main interpreter (ID 0) and subinterpreter (ID > 0) main_interp = None sub_interp = None for interpreter_info in stack_trace: if interpreter_info.interpreter_id == 0: main_interp = interpreter_info elif interpreter_info.interpreter_id > 0: sub_interp = interpreter_info self.assertIsNotNone( main_interp, "Main interpreter should be present" ) # Check main interpreter has expected stack trace main_found = self._find_frame_in_trace( [main_interp], lambda f: f.funcname == "main_worker" ) self.assertIsNotNone( main_found, "Main interpreter should have main_worker in stack", ) # If subinterpreter is present, check its stack trace if sub_interp: sub_found = self._find_frame_in_trace( [sub_interp], lambda f: f.funcname in ("sub_worker", "nested_func"), ) self.assertIsNotNone( sub_found, "Subinterpreter should have sub_worker or nested_func in stack", ) finally: _cleanup_sockets(*client_sockets, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) @requires_subinterpreters def test_multiple_subinterpreters_with_threads(self): port = find_unused_port() import pickle subinterp1_code = textwrap.dedent(f""" import socket import time import threading def worker1(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub1-t1\\n") time.sleep(10_000) nested_func() def worker2(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub1-t2\\n") time.sleep(10_000) nested_func() t1 = threading.Thread(target=worker1) t2 = threading.Thread(target=worker2) t1.start() t2.start() t1.join() t2.join() """).strip() subinterp2_code = textwrap.dedent(f""" import socket import time import threading def worker1(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub2-t1\\n") time.sleep(10_000) nested_func() def worker2(): def nested_func(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) sock.sendall(b"ready:sub2-t2\\n") time.sleep(10_000) nested_func() t1 = threading.Thread(target=worker1) t2 = threading.Thread(target=worker2) t1.start() t2.start() t1.join() t2.join() """).strip() pickled_code1 = pickle.dumps(subinterp1_code) pickled_code2 = pickle.dumps(subinterp2_code) script = textwrap.dedent( f""" from concurrent import interpreters import time import sys import socket import threading sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def main_worker(): sock.sendall(b"ready:main\\n") time.sleep(10_000) def run_subinterp1(): subinterp = interpreters.create() import pickle pickled_code = {pickled_code1!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) def run_subinterp2(): subinterp = interpreters.create() import pickle pickled_code = {pickled_code2!r} subinterp_code = pickle.loads(pickled_code) subinterp.exec(subinterp_code) sub1_thread = threading.Thread(target=run_subinterp1) sub2_thread = threading.Thread(target=run_subinterp2) sub1_thread.start() sub2_thread.start() main_thread = threading.Thread(target=main_worker) main_thread.start() main_thread.join() sub1_thread.join() sub2_thread.join() """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port, backlog=5) script_name = _make_test_script(script_dir, "script", script) client_sockets = [] try: with _managed_subprocess([sys.executable, script_name]) as p: # Accept connections from main and all subinterpreter threads expected_responses = { "ready:main", "ready:sub1-t1", "ready:sub1-t2", "ready:sub2-t1", "ready:sub2-t2", } responses = set() while len(responses) < 5: try: client_socket, _ = server_socket.accept() client_sockets.append(client_socket) response = client_socket.recv(1024) response_str = response.decode().strip() if response_str in expected_responses: responses.add(response_str) except socket.timeout: break server_socket.close() server_socket = None try: stack_trace = get_stack_trace(p.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Verify we have multiple interpreters self.assertGreaterEqual(len(stack_trace), 2) # Count interpreters by ID interpreter_ids = { interp.interpreter_id for interp in stack_trace } self.assertIn( 0, interpreter_ids, "Main interpreter should be present", ) self.assertGreaterEqual(len(interpreter_ids), 3) # Count total threads total_threads = sum( len(interp.threads) for interp in stack_trace ) self.assertGreaterEqual(total_threads, 5) # Look for expected function names all_funcnames = set() for interpreter_info in stack_trace: for thread_info in interpreter_info.threads: for frame in thread_info.frame_info: all_funcnames.add(frame.funcname) expected_funcs = { "main_worker", "worker1", "worker2", "nested_func", } found_funcs = expected_funcs.intersection(all_funcnames) self.assertGreater(len(found_funcs), 0) finally: _cleanup_sockets(*client_sockets, server_socket) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) @requires_gil_enabled("Free threaded builds don't have an 'active thread'") def test_only_active_thread(self): port = find_unused_port() script = textwrap.dedent( f"""\ import time, sys, socket, threading sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def worker_thread(name, barrier, ready_event): barrier.wait() ready_event.wait() time.sleep(10_000) def main_work(): sock.sendall(b"working\\n") count = 0 while count < 100000000: count += 1 if count % 10000000 == 0: pass sock.sendall(b"done\\n") num_threads = 3 barrier = threading.Barrier(num_threads + 1) ready_event = threading.Event() threads = [] for i in range(num_threads): t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event)) t.start() threads.append(t) barrier.wait() sock.sendall(b"ready\\n") ready_event.set() main_work() """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None # Wait for ready and working signals _wait_for_signal(client_socket, [b"ready", b"working"]) try: # Get stack trace with all threads unwinder_all = RemoteUnwinder(p.pid, all_threads=True) for _ in range(MAX_TRIES): all_traces = unwinder_all.get_stack_trace() found = self._find_frame_in_trace( all_traces, lambda f: f.funcname == "main_work" and f.lineno > 12, ) if found: break time.sleep(0.1) else: self.fail( "Main thread did not start its busy work on time" ) # Get stack trace with only GIL holder unwinder_gil = RemoteUnwinder( p.pid, only_active_thread=True ) gil_traces = unwinder_gil.get_stack_trace() except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) # Count threads total_threads = sum( len(interp.threads) for interp in all_traces ) self.assertGreater(total_threads, 1) total_gil_threads = sum( len(interp.threads) for interp in gil_traces ) self.assertEqual(total_gil_threads, 1) # Get the GIL holder thread ID gil_thread_id = None for interpreter_info in gil_traces: if interpreter_info.threads: gil_thread_id = interpreter_info.threads[ 0 ].thread_id break # Get all thread IDs all_thread_ids = [] for interpreter_info in all_traces: for thread_info in interpreter_info.threads: all_thread_ids.append(thread_info.thread_id) self.assertIn(gil_thread_id, all_thread_ids) finally: _cleanup_sockets(client_socket, server_socket) class TestUnsupportedPlatformHandling(unittest.TestCase): @unittest.skipIf( sys.platform in ("linux", "darwin", "win32"), "Test only runs on unsupported platforms (not Linux, macOS, or Windows)", ) @unittest.skipIf( sys.platform == "android", "Android raises Linux-specific exception" ) def test_unsupported_platform_error(self): with self.assertRaises(RuntimeError) as cm: RemoteUnwinder(os.getpid()) self.assertIn( "Reading the PyRuntime section is not supported on this platform", str(cm.exception), ) class TestDetectionOfThreadStatus(RemoteInspectionTestBase): def _run_thread_status_test(self, mode, check_condition): """ Common pattern for thread status detection tests. Args: mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.) check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool """ port = find_unused_port() script = textwrap.dedent( f"""\ import time, sys, socket, threading import os sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) def sleeper(): tid = threading.get_native_id() sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode()) time.sleep(10000) def busy(): tid = threading.get_native_id() sock.sendall(f'ready:busy:{{tid}}\\n'.encode()) x = 0 while True: x = x + 1 time.sleep(0.5) t1 = threading.Thread(target=sleeper) t2 = threading.Thread(target=busy) t1.start() t2.start() sock.sendall(b'ready:main\\n') t1.join() t2.join() sock.close() """ ) with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script( script_dir, "thread_status_script", script ) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None # Wait for all ready signals and parse TIDs response = _wait_for_signal( client_socket, [b"ready:main", b"ready:sleeper", b"ready:busy"], ) sleeper_tid = None busy_tid = None for line in response.split(b"\n"): if line.startswith(b"ready:sleeper:"): try: sleeper_tid = int(line.split(b":")[-1]) except (ValueError, IndexError): pass elif line.startswith(b"ready:busy:"): try: busy_tid = int(line.split(b":")[-1]) except (ValueError, IndexError): pass self.assertIsNotNone( sleeper_tid, "Sleeper thread id not received" ) self.assertIsNotNone( busy_tid, "Busy thread id not received" ) # Sample until we see expected thread states statuses = {} try: unwinder = RemoteUnwinder( p.pid, all_threads=True, mode=mode, skip_non_matching_threads=False, ) for _ in range(MAX_TRIES): traces = unwinder.get_stack_trace() statuses = self._get_thread_statuses(traces) if check_condition( statuses, sleeper_tid, busy_tid ): break time.sleep(0.5) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) return statuses, sleeper_tid, busy_tid finally: _cleanup_sockets(client_socket, server_socket) @unittest.skipIf( sys.platform not in ("linux", "darwin", "win32"), "Test only runs on supported platforms (Linux, macOS, or Windows)", ) @unittest.skipIf( sys.platform == "android", "Android raises Linux-specific exception" ) def test_thread_status_detection(self): def check_cpu_status(statuses, sleeper_tid, busy_tid): return ( sleeper_tid in statuses and busy_tid in statuses and not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU) and (statuses[busy_tid] & THREAD_STATUS_ON_CPU) ) statuses, sleeper_tid, busy_tid = self._run_thread_status_test( PROFILING_MODE_CPU, check_cpu_status ) self.assertIn(sleeper_tid, statuses) self.assertIn(busy_tid, statuses) self.assertFalse( statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, "Sleeper thread should be off CPU", ) self.assertTrue( statuses[busy_tid] & THREAD_STATUS_ON_CPU, "Busy thread should be on CPU", ) @unittest.skipIf( sys.platform not in ("linux", "darwin", "win32"), "Test only runs on supported platforms (Linux, macOS, or Windows)", ) @unittest.skipIf( sys.platform == "android", "Android raises Linux-specific exception" ) def test_thread_status_gil_detection(self): def check_gil_status(statuses, sleeper_tid, busy_tid): return ( sleeper_tid in statuses and busy_tid in statuses and not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL) and (statuses[busy_tid] & THREAD_STATUS_HAS_GIL) ) statuses, sleeper_tid, busy_tid = self._run_thread_status_test( PROFILING_MODE_GIL, check_gil_status ) self.assertIn(sleeper_tid, statuses) self.assertIn(busy_tid, statuses) self.assertFalse( statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, "Sleeper thread should not have GIL", ) self.assertTrue( statuses[busy_tid] & THREAD_STATUS_HAS_GIL, "Busy thread should have GIL", ) @unittest.skipIf( sys.platform not in ("linux", "darwin", "win32"), "Test only runs on supported platforms (Linux, macOS, or Windows)", ) @unittest.skipIf( sys.platform == "android", "Android raises Linux-specific exception" ) def test_thread_status_all_mode_detection(self): port = find_unused_port() script = textwrap.dedent( f"""\ import socket import threading import time import sys def sleeper_thread(): conn = socket.create_connection(("localhost", {port})) conn.sendall(b"sleeper:" + str(threading.get_native_id()).encode()) while True: time.sleep(1) def busy_thread(): conn = socket.create_connection(("localhost", {port})) conn.sendall(b"busy:" + str(threading.get_native_id()).encode()) while True: sum(range(100000)) t1 = threading.Thread(target=sleeper_thread) t2 = threading.Thread(target=busy_thread) t1.start() t2.start() t1.join() t2.join() """ ) with os_helper.temp_dir() as tmp_dir: script_file = make_script(tmp_dir, "script", script) server_socket = _create_server_socket(port, backlog=2) client_sockets = [] try: with _managed_subprocess( [sys.executable, script_file], ) as p: sleeper_tid = None busy_tid = None # Receive thread IDs from the child process for _ in range(2): client_socket, _ = server_socket.accept() client_sockets.append(client_socket) line = client_socket.recv(1024) if line: if line.startswith(b"sleeper:"): try: sleeper_tid = int(line.split(b":")[-1]) except (ValueError, IndexError): pass elif line.startswith(b"busy:"): try: busy_tid = int(line.split(b":")[-1]) except (ValueError, IndexError): pass server_socket.close() server_socket = None statuses = {} try: unwinder = RemoteUnwinder( p.pid, all_threads=True, mode=PROFILING_MODE_ALL, skip_non_matching_threads=False, ) for _ in range(MAX_TRIES): traces = unwinder.get_stack_trace() statuses = self._get_thread_statuses(traces) # Check ALL mode provides both GIL and CPU info if ( sleeper_tid in statuses and busy_tid in statuses and not ( statuses[sleeper_tid] & THREAD_STATUS_ON_CPU ) and not ( statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL ) and (statuses[busy_tid] & THREAD_STATUS_ON_CPU) and ( statuses[busy_tid] & THREAD_STATUS_HAS_GIL ) ): break time.sleep(0.5) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) self.assertIsNotNone( sleeper_tid, "Sleeper thread id not received" ) self.assertIsNotNone( busy_tid, "Busy thread id not received" ) self.assertIn(sleeper_tid, statuses) self.assertIn(busy_tid, statuses) # Sleeper: off CPU, no GIL self.assertFalse( statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, "Sleeper should be off CPU", ) self.assertFalse( statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, "Sleeper should not have GIL", ) # Busy: on CPU, has GIL self.assertTrue( statuses[busy_tid] & THREAD_STATUS_ON_CPU, "Busy should be on CPU", ) self.assertTrue( statuses[busy_tid] & THREAD_STATUS_HAS_GIL, "Busy should have GIL", ) finally: _cleanup_sockets(*client_sockets, server_socket) class TestFrameCaching(RemoteInspectionTestBase): """Test that frame caching produces correct results. Uses socket-based synchronization for deterministic testing. All tests verify cache reuse via object identity checks (assertIs). """ @contextmanager def _target_process(self, script_body): """Context manager for running a target process with socket sync.""" port = find_unused_port() script = f"""\ import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) {textwrap.dedent(script_body)} """ with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) server_socket = _create_server_socket(port) script_name = _make_test_script(script_dir, "script", script) client_socket = None try: with _managed_subprocess([sys.executable, script_name]) as p: client_socket, _ = server_socket.accept() server_socket.close() server_socket = None def make_unwinder(cache_frames=True): return RemoteUnwinder( p.pid, all_threads=True, cache_frames=cache_frames ) yield p, client_socket, make_unwinder except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) finally: _cleanup_sockets(client_socket, server_socket) def _get_frames_with_retry(self, unwinder, required_funcs): """Get frames containing required_funcs, with retry for transient errors.""" for _ in range(MAX_TRIES): try: traces = unwinder.get_stack_trace() for interp in traces: for thread in interp.threads: funcs = {f.funcname for f in thread.frame_info} if required_funcs.issubset(funcs): return thread.frame_info except RuntimeError as e: if _is_retriable_error(e): pass else: raise time.sleep(0.1) return None def _sample_frames( self, client_socket, unwinder, wait_signal, send_ack, required_funcs, expected_frames=1, ): """Wait for signal, sample frames with retry until required funcs present, send ack.""" _wait_for_signal(client_socket, wait_signal) frames = None for _ in range(MAX_TRIES): frames = self._get_frames_with_retry(unwinder, required_funcs) if frames and len(frames) >= expected_frames: break time.sleep(0.1) client_socket.sendall(send_ack) return frames @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_cache_hit_same_stack(self): """Test that consecutive samples reuse cached parent frame objects. The current frame (index 0) is always re-read from memory to get updated line numbers, so it may be a different object. Parent frames (index 1+) should be identical objects from cache. """ script_body = """\ def level3(): sock.sendall(b"sync1") sock.recv(16) sock.sendall(b"sync2") sock.recv(16) sock.sendall(b"sync3") sock.recv(16) def level2(): level3() def level1(): level2() level1() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) expected = {"level1", "level2", "level3"} frames1 = self._sample_frames( client_socket, unwinder, b"sync1", b"ack", expected ) frames2 = self._sample_frames( client_socket, unwinder, b"sync2", b"ack", expected ) frames3 = self._sample_frames( client_socket, unwinder, b"sync3", b"done", expected ) self.assertIsNotNone(frames1) self.assertIsNotNone(frames2) self.assertIsNotNone(frames3) self.assertEqual(len(frames1), len(frames2)) self.assertEqual(len(frames2), len(frames3)) # Current frame (index 0) is always re-read, so check value equality self.assertEqual(frames1[0].funcname, frames2[0].funcname) self.assertEqual(frames2[0].funcname, frames3[0].funcname) # Parent frames (index 1+) must be identical objects (cache reuse) for i in range(1, len(frames1)): f1, f2, f3 = frames1[i], frames2[i], frames3[i] self.assertIs( f1, f2, f"Frame {i}: samples 1-2 must be same object" ) self.assertIs( f2, f3, f"Frame {i}: samples 2-3 must be same object" ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_line_number_updates_in_same_frame(self): """Test that line numbers are correctly updated when execution moves within a function. When the profiler samples at different points within the same function, it must report the correct line number for each sample, not stale cached values. """ script_body = """\ def outer(): inner() def inner(): sock.sendall(b"line_a"); sock.recv(16) sock.sendall(b"line_b"); sock.recv(16) sock.sendall(b"line_c"); sock.recv(16) sock.sendall(b"line_d"); sock.recv(16) outer() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) frames_a = self._sample_frames( client_socket, unwinder, b"line_a", b"ack", {"inner"} ) frames_b = self._sample_frames( client_socket, unwinder, b"line_b", b"ack", {"inner"} ) frames_c = self._sample_frames( client_socket, unwinder, b"line_c", b"ack", {"inner"} ) frames_d = self._sample_frames( client_socket, unwinder, b"line_d", b"done", {"inner"} ) self.assertIsNotNone(frames_a) self.assertIsNotNone(frames_b) self.assertIsNotNone(frames_c) self.assertIsNotNone(frames_d) # Get the 'inner' frame from each sample (should be index 0) inner_a = frames_a[0] inner_b = frames_b[0] inner_c = frames_c[0] inner_d = frames_d[0] self.assertEqual(inner_a.funcname, "inner") self.assertEqual(inner_b.funcname, "inner") self.assertEqual(inner_c.funcname, "inner") self.assertEqual(inner_d.funcname, "inner") # Line numbers must be different and increasing (execution moves forward) self.assertLess( inner_a.lineno, inner_b.lineno, "Line B should be after line A" ) self.assertLess( inner_b.lineno, inner_c.lineno, "Line C should be after line B" ) self.assertLess( inner_c.lineno, inner_d.lineno, "Line D should be after line C" ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_cache_invalidation_on_return(self): """Test cache invalidation when stack shrinks (function returns).""" script_body = """\ def inner(): sock.sendall(b"at_inner") sock.recv(16) def outer(): inner() sock.sendall(b"at_outer") sock.recv(16) outer() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) frames_deep = self._sample_frames( client_socket, unwinder, b"at_inner", b"ack", {"inner", "outer"}, ) frames_shallow = self._sample_frames( client_socket, unwinder, b"at_outer", b"done", {"outer"} ) self.assertIsNotNone(frames_deep) self.assertIsNotNone(frames_shallow) funcs_deep = [f.funcname for f in frames_deep] funcs_shallow = [f.funcname for f in frames_shallow] self.assertIn("inner", funcs_deep) self.assertIn("outer", funcs_deep) self.assertNotIn("inner", funcs_shallow) self.assertIn("outer", funcs_shallow) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_cache_invalidation_on_call(self): """Test cache invalidation when stack grows (new function called).""" script_body = """\ def deeper(): sock.sendall(b"at_deeper") sock.recv(16) def middle(): sock.sendall(b"at_middle") sock.recv(16) deeper() def top(): middle() top() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) frames_before = self._sample_frames( client_socket, unwinder, b"at_middle", b"ack", {"middle", "top"}, ) frames_after = self._sample_frames( client_socket, unwinder, b"at_deeper", b"done", {"deeper", "middle", "top"}, ) self.assertIsNotNone(frames_before) self.assertIsNotNone(frames_after) funcs_before = [f.funcname for f in frames_before] funcs_after = [f.funcname for f in frames_after] self.assertIn("middle", funcs_before) self.assertIn("top", funcs_before) self.assertNotIn("deeper", funcs_before) self.assertIn("deeper", funcs_after) self.assertIn("middle", funcs_after) self.assertIn("top", funcs_after) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_partial_stack_reuse(self): """Test that unchanged bottom frames are reused when top changes (A→B→C to A→B→D).""" script_body = """\ def func_c(): sock.sendall(b"at_c") sock.recv(16) def func_d(): sock.sendall(b"at_d") sock.recv(16) def func_b(): func_c() func_d() def func_a(): func_b() func_a() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) # Sample at C: stack is A→B→C frames_c = self._sample_frames( client_socket, unwinder, b"at_c", b"ack", {"func_a", "func_b", "func_c"}, ) # Sample at D: stack is A→B→D (C returned, D called) frames_d = self._sample_frames( client_socket, unwinder, b"at_d", b"done", {"func_a", "func_b", "func_d"}, ) self.assertIsNotNone(frames_c) self.assertIsNotNone(frames_d) # Find func_a and func_b frames in both samples def find_frame(frames, funcname): for f in frames: if f.funcname == funcname: return f return None frame_a_in_c = find_frame(frames_c, "func_a") frame_b_in_c = find_frame(frames_c, "func_b") frame_a_in_d = find_frame(frames_d, "func_a") frame_b_in_d = find_frame(frames_d, "func_b") self.assertIsNotNone(frame_a_in_c) self.assertIsNotNone(frame_b_in_c) self.assertIsNotNone(frame_a_in_d) self.assertIsNotNone(frame_b_in_d) # The bottom frames (A, B) should be the SAME objects (cache reuse) self.assertIs( frame_a_in_c, frame_a_in_d, "func_a frame should be reused from cache", ) self.assertIs( frame_b_in_c, frame_b_in_d, "func_b frame should be reused from cache", ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_recursive_frames(self): """Test caching with same function appearing multiple times (recursion).""" script_body = """\ def recurse(n): if n <= 0: sock.sendall(b"sync1") sock.recv(16) sock.sendall(b"sync2") sock.recv(16) else: recurse(n - 1) recurse(5) """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) frames1 = self._sample_frames( client_socket, unwinder, b"sync1", b"ack", {"recurse"} ) frames2 = self._sample_frames( client_socket, unwinder, b"sync2", b"done", {"recurse"} ) self.assertIsNotNone(frames1) self.assertIsNotNone(frames2) # Should have multiple "recurse" frames (6 total: recurse(5) down to recurse(0)) recurse_count = sum(1 for f in frames1 if f.funcname == "recurse") self.assertEqual(recurse_count, 6, "Should have 6 recursive frames") self.assertEqual(len(frames1), len(frames2)) # Current frame (index 0) is re-read, check value equality self.assertEqual(frames1[0].funcname, frames2[0].funcname) # Parent frames (index 1+) should be identical objects (cache reuse) for i in range(1, len(frames1)): self.assertIs( frames1[i], frames2[i], f"Frame {i}: recursive frames must be same object", ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_cache_vs_no_cache_equivalence(self): """Test that cache_frames=True and cache_frames=False produce equivalent results.""" script_body = """\ def level3(): sock.sendall(b"ready"); sock.recv(16) def level2(): level3() def level1(): level2() level1() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): _wait_for_signal(client_socket, b"ready") # Sample with cache unwinder_cache = make_unwinder(cache_frames=True) frames_cached = self._get_frames_with_retry( unwinder_cache, {"level1", "level2", "level3"} ) # Sample without cache unwinder_no_cache = make_unwinder(cache_frames=False) frames_no_cache = self._get_frames_with_retry( unwinder_no_cache, {"level1", "level2", "level3"} ) client_socket.sendall(b"done") self.assertIsNotNone(frames_cached) self.assertIsNotNone(frames_no_cache) # Same number of frames self.assertEqual(len(frames_cached), len(frames_no_cache)) # Same function names in same order funcs_cached = [f.funcname for f in frames_cached] funcs_no_cache = [f.funcname for f in frames_no_cache] self.assertEqual(funcs_cached, funcs_no_cache) # Same line numbers lines_cached = [f.lineno for f in frames_cached] lines_no_cache = [f.lineno for f in frames_no_cache] self.assertEqual(lines_cached, lines_no_cache) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_cache_per_thread_isolation(self): """Test that frame cache is per-thread and cache invalidation works independently.""" script_body = """\ import threading lock = threading.Lock() def sync(msg): with lock: sock.sendall(msg + b"\\n") sock.recv(1) # Thread 1 functions def baz1(): sync(b"t1:baz1") def bar1(): baz1() def blech1(): sync(b"t1:blech1") def foo1(): bar1() # Goes down to baz1, syncs blech1() # Returns up, goes down to blech1, syncs # Thread 2 functions def baz2(): sync(b"t2:baz2") def bar2(): baz2() def blech2(): sync(b"t2:blech2") def foo2(): bar2() # Goes down to baz2, syncs blech2() # Returns up, goes down to blech2, syncs t1 = threading.Thread(target=foo1) t2 = threading.Thread(target=foo2) t1.start() t2.start() t1.join() t2.join() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder = make_unwinder(cache_frames=True) buffer = b"" def recv_msg(): """Receive a single message from socket.""" nonlocal buffer while b"\n" not in buffer: chunk = client_socket.recv(256) if not chunk: return None buffer += chunk msg, buffer = buffer.split(b"\n", 1) return msg def get_thread_frames(target_funcs): """Get frames for thread matching target functions.""" retries = 0 for _ in busy_retry(SHORT_TIMEOUT): if retries >= 5: break retries += 1 # On Windows, ReadProcessMemory can fail with OSError # (WinError 299) when frame pointers are in flux with contextlib.suppress(RuntimeError, OSError): traces = unwinder.get_stack_trace() for interp in traces: for thread in interp.threads: funcs = [f.funcname for f in thread.frame_info] if any(f in funcs for f in target_funcs): return funcs return None # Track results for each sync point results = {} # Process 4 sync points: baz1, baz2, blech1, blech2 # With the lock, threads are serialized - handle one at a time for _ in range(4): msg = recv_msg() self.assertIsNotNone(msg, "Expected message from subprocess") # Determine which thread/function and take snapshot if msg == b"t1:baz1": funcs = get_thread_frames(["baz1", "bar1", "foo1"]) self.assertIsNotNone(funcs, "Thread 1 not found at baz1") results["t1:baz1"] = funcs elif msg == b"t2:baz2": funcs = get_thread_frames(["baz2", "bar2", "foo2"]) self.assertIsNotNone(funcs, "Thread 2 not found at baz2") results["t2:baz2"] = funcs elif msg == b"t1:blech1": funcs = get_thread_frames(["blech1", "foo1"]) self.assertIsNotNone(funcs, "Thread 1 not found at blech1") results["t1:blech1"] = funcs elif msg == b"t2:blech2": funcs = get_thread_frames(["blech2", "foo2"]) self.assertIsNotNone(funcs, "Thread 2 not found at blech2") results["t2:blech2"] = funcs # Release thread to continue client_socket.sendall(b"k") # Validate Phase 1: baz snapshots t1_baz = results.get("t1:baz1") t2_baz = results.get("t2:baz2") self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot") self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot") # Thread 1 at baz1: should have foo1->bar1->baz1 self.assertIn("baz1", t1_baz) self.assertIn("bar1", t1_baz) self.assertIn("foo1", t1_baz) self.assertNotIn("blech1", t1_baz) # No cross-contamination self.assertNotIn("baz2", t1_baz) self.assertNotIn("bar2", t1_baz) self.assertNotIn("foo2", t1_baz) # Thread 2 at baz2: should have foo2->bar2->baz2 self.assertIn("baz2", t2_baz) self.assertIn("bar2", t2_baz) self.assertIn("foo2", t2_baz) self.assertNotIn("blech2", t2_baz) # No cross-contamination self.assertNotIn("baz1", t2_baz) self.assertNotIn("bar1", t2_baz) self.assertNotIn("foo1", t2_baz) # Validate Phase 2: blech snapshots (cache invalidation test) t1_blech = results.get("t1:blech1") t2_blech = results.get("t2:blech2") self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot") self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot") # Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated) self.assertIn("blech1", t1_blech) self.assertIn("foo1", t1_blech) self.assertNotIn( "bar1", t1_blech, "Cache not invalidated: bar1 still present" ) self.assertNotIn( "baz1", t1_blech, "Cache not invalidated: baz1 still present" ) # No cross-contamination self.assertNotIn("blech2", t1_blech) # Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated) self.assertIn("blech2", t2_blech) self.assertIn("foo2", t2_blech) self.assertNotIn( "bar2", t2_blech, "Cache not invalidated: bar2 still present" ) self.assertNotIn( "baz2", t2_blech, "Cache not invalidated: baz2 still present" ) # No cross-contamination self.assertNotIn("blech1", t2_blech) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_new_unwinder_with_stale_last_profiled_frame(self): """Test that a new unwinder returns complete stack when cache lookup misses.""" script_body = """\ def level4(): sock.sendall(b"sync1") sock.recv(16) sock.sendall(b"sync2") sock.recv(16) def level3(): level4() def level2(): level3() def level1(): level2() level1() """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): expected = {"level1", "level2", "level3", "level4"} # First unwinder samples - this sets last_profiled_frame in target unwinder1 = make_unwinder(cache_frames=True) frames1 = self._sample_frames( client_socket, unwinder1, b"sync1", b"ack", expected ) # Create NEW unwinder (empty cache) and sample # The target still has last_profiled_frame set from unwinder1 unwinder2 = make_unwinder(cache_frames=True) frames2 = self._sample_frames( client_socket, unwinder2, b"sync2", b"done", expected ) self.assertIsNotNone(frames1) self.assertIsNotNone(frames2) funcs1 = [f.funcname for f in frames1] funcs2 = [f.funcname for f in frames2] # Both should have all levels for level in ["level1", "level2", "level3", "level4"]: self.assertIn(level, funcs1, f"{level} missing from first sample") self.assertIn(level, funcs2, f"{level} missing from second sample") # Should have same stack depth self.assertEqual( len(frames1), len(frames2), "New unwinder should return complete stack despite stale last_profiled_frame", ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_cache_exhaustion(self): """Test cache works when frame limit (1024) is exceeded. FRAME_CACHE_MAX_FRAMES=1024. With 1100 recursive frames, the cache can't store all of them but should still work. """ # Use 1100 to exceed FRAME_CACHE_MAX_FRAMES=1024 depth = 1100 script_body = f"""\ import sys sys.setrecursionlimit(2000) def recurse(n): if n <= 0: sock.sendall(b"ready") sock.recv(16) # wait for ack sock.sendall(b"ready2") sock.recv(16) # wait for done return recurse(n - 1) recurse({depth}) """ with self._target_process(script_body) as ( p, client_socket, make_unwinder, ): unwinder_cache = make_unwinder(cache_frames=True) unwinder_no_cache = make_unwinder(cache_frames=False) frames_cached = self._sample_frames( client_socket, unwinder_cache, b"ready", b"ack", {"recurse"}, expected_frames=1102, ) # Sample again with no cache for comparison frames_no_cache = self._sample_frames( client_socket, unwinder_no_cache, b"ready2", b"done", {"recurse"}, expected_frames=1102, ) self.assertIsNotNone(frames_cached) self.assertIsNotNone(frames_no_cache) # Both should have many recurse frames (> 1024 limit) cached_count = [f.funcname for f in frames_cached].count("recurse") no_cache_count = [f.funcname for f in frames_no_cache].count("recurse") self.assertGreater( cached_count, 1000, "Should have >1000 recurse frames" ) self.assertGreater( no_cache_count, 1000, "Should have >1000 recurse frames" ) # Both modes should produce same frame count self.assertEqual( len(frames_cached), len(frames_no_cache), "Cache exhaustion should not affect stack completeness", ) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_get_stats(self): """Test that get_stats() returns statistics when stats=True.""" script_body = """\ sock.sendall(b"ready") sock.recv(16) """ with self._target_process(script_body) as (p, client_socket, _): unwinder = RemoteUnwinder(p.pid, all_threads=True, stats=True) _wait_for_signal(client_socket, b"ready") # Take a sample unwinder.get_stack_trace() stats = unwinder.get_stats() client_socket.sendall(b"done") # Verify expected keys exist expected_keys = [ "total_samples", "frame_cache_hits", "frame_cache_misses", "frame_cache_partial_hits", "frames_read_from_cache", "frames_read_from_memory", "frame_cache_hit_rate", ] for key in expected_keys: self.assertIn(key, stats) self.assertEqual(stats["total_samples"], 1) @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) def test_get_stats_disabled_raises(self): """Test that get_stats() raises RuntimeError when stats=False.""" script_body = """\ sock.sendall(b"ready") sock.recv(16) """ with self._target_process(script_body) as (p, client_socket, _): unwinder = RemoteUnwinder( p.pid, all_threads=True ) # stats=False by default _wait_for_signal(client_socket, b"ready") with self.assertRaises(RuntimeError): unwinder.get_stats() client_socket.sendall(b"done") if __name__ == "__main__": unittest.main()