# cpython/Lib/test/test_external_inspection.py
# (3124 lines, 113 KiB, Python)

import contextlib
import unittest
import os
import textwrap
import importlib
import sys
import socket
import threading
import time
from contextlib import contextmanager
from asyncio import staggered, taskgroups, base_events, tasks
from unittest.mock import ANY
from test.support import (
os_helper,
SHORT_TIMEOUT,
busy_retry,
requires_gil_enabled,
)
from test.support.script_helper import make_script
from test.support.socket_helper import find_unused_port
import subprocess
# Profiling mode constants (consecutive integers 0..3).
(
    PROFILING_MODE_WALL,
    PROFILING_MODE_CPU,
    PROFILING_MODE_GIL,
    PROFILING_MODE_ALL,
) = range(4)

# Thread status flags. Each flag occupies its own bit, so several can be
# combined with ``|`` and tested with ``&``.
THREAD_STATUS_HAS_GIL = 0b001
THREAD_STATUS_ON_CPU = 0b010
THREAD_STATUS_UNKNOWN = 0b100

# Maximum number of retry attempts for operations that may fail transiently
MAX_TRIES = 10
# Subinterpreter support is optional: when ``concurrent.interpreters`` is
# unavailable, ``interpreters`` is None and ``requires_subinterpreters``
# skips the tests that need it.
try:
    from concurrent import interpreters
except ImportError:
    interpreters = None

# The _remote_debugging extension is mandatory for every test in this module.
try:
    from _remote_debugging import RemoteUnwinder
    from _remote_debugging import FrameInfo, CoroInfo, TaskInfo
except ImportError:
    raise unittest.SkipTest(
        "Test only runs when _remote_debugging is available"
    )

# Imported separately so that a build which lacks this constant falls back
# to False instead of skipping the whole module (previously the fallback
# value was dead: a missing name raised ImportError in the block above).
try:
    from _remote_debugging import PROCESS_VM_READV_SUPPORTED
except ImportError:
    PROCESS_VM_READV_SUPPORTED = False
# ============================================================================
# Module-level helper functions
# ============================================================================
def _make_test_script(script_dir, script_basename, source):
    """Write *source* via ``make_script`` and return the resulting path.

    Import caches are invalidated afterwards so that the freshly written
    file is picked up by the import system.
    """
    script_path = make_script(script_dir, script_basename, source)
    importlib.invalidate_caches()
    return script_path
def _create_server_socket(port, backlog=1):
    """Create and configure a server socket for test communication.

    Args:
        port: TCP port to bind on ``localhost``.
        backlog: Value passed to ``listen()``.

    Returns:
        socket.socket: A bound, listening socket whose blocking operations
        (such as ``accept()``) time out after SHORT_TIMEOUT seconds.

    If any setup step fails (for example ``bind()`` because the port is
    already in use), the socket is closed before the exception propagates,
    so no file descriptor is leaked.
    """
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind(("localhost", port))
        server_socket.settimeout(SHORT_TIMEOUT)
        server_socket.listen(backlog)
    except BaseException:
        # Cleanup-and-reraise: never leak the descriptor on a failed setup.
        server_socket.close()
        raise
    return server_socket
def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
    """
    Wait until every expected signal has been received on *sock*.

    Args:
        sock: Connected socket to read from.
        expected_signals: A single bytes object or a list of bytes objects;
            each must appear somewhere in the accumulated data.
        timeout: Timeout in seconds for each individual ``recv()`` call. It
            is set on *sock* and left there. The total wait may exceed it
            while data keeps arriving.

    Returns:
        bytes: The complete accumulated buffer (may contain data beyond the
        signals).

    Raises:
        RuntimeError: If the peer closes the connection, or a ``recv()``
            times out, before all signals were seen. On timeout the original
            ``socket.timeout`` is chained as ``__cause__``.
    """
    if isinstance(expected_signals, bytes):
        expected_signals = [expected_signals]
    sock.settimeout(timeout)
    buffer = b""
    while not all(sig in buffer for sig in expected_signals):
        # Keep the try body minimal: only recv() can time out.
        try:
            chunk = sock.recv(4096)
        except socket.timeout as exc:
            raise RuntimeError(
                f"Timeout waiting for signals. "
                f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
            ) from exc
        if not chunk:
            # EOF - connection closed
            raise RuntimeError(
                f"Connection closed before receiving expected signals. "
                f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
            )
        buffer += chunk
    return buffer
def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT):
    """
    Wait until *signal_pattern* has been received at least *count* times.

    Args:
        sock: Connected socket to read from.
        signal_pattern: bytes pattern to count (e.g., b"ready"). Occurrences
            are counted with ``bytes.count`` (non-overlapping).
        count: Number of occurrences expected. If it is <= 0, nothing is
            read and b"" is returned.
        timeout: Timeout in seconds for each individual ``recv()`` call. It
            is set on *sock* and left there.

    Returns:
        bytes: Complete accumulated response buffer.

    Raises:
        RuntimeError: If the connection is closed, or a ``recv()`` times out,
            before all signals were received. On timeout the original
            ``socket.timeout`` is chained as ``__cause__``.
    """
    sock.settimeout(timeout)
    buffer = b""
    found_count = 0
    while found_count < count:
        try:
            chunk = sock.recv(4096)
        except socket.timeout as exc:
            raise RuntimeError(
                f"Timeout waiting for {count} signals (found {found_count}). "
                f"Last 200 bytes: {buffer[-200:]!r}"
            ) from exc
        if not chunk:
            raise RuntimeError(
                f"Connection closed after {found_count}/{count} signals. "
                f"Last 200 bytes: {buffer[-200:]!r}"
            )
        buffer += chunk
        # Recount over the whole buffer: one pattern may be split across
        # two recv() chunks.
        found_count = buffer.count(signal_pattern)
    return buffer
@contextmanager
def _managed_subprocess(args, timeout=SHORT_TIMEOUT):
    """
    Context manager for subprocess lifecycle management.

    Starts ``subprocess.Popen(args)`` and yields the Popen object. On exit,
    whether normal or by exception, the process is asked to terminate. If it
    has not exited within *timeout* seconds it is killed and waited on again.

    An OSError from ``terminate()`` or ``kill()`` (e.g. the process is
    already gone) is ignored, but it no longer skips the ``wait()`` that
    reaps the child.
    """
    p = subprocess.Popen(args)
    try:
        yield p
    finally:
        # Graceful termination first.
        with contextlib.suppress(OSError):
            p.terminate()
        try:
            p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Escalate to a forceful kill.
            with contextlib.suppress(OSError):
                p.kill()
            try:
                p.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                pass  # Process refuses to die, nothing more we can do
def _cleanup_sockets(*sockets):
    """Close each non-None socket in *sockets*, ignoring any OSError."""
    for candidate in (s for s in sockets if s is not None):
        with contextlib.suppress(OSError):
            candidate.close()
# ============================================================================
# Decorators and skip conditions
# ============================================================================
# Skip decorator for tests that only run on Linux, macOS and Windows.
skip_if_not_supported = unittest.skipIf(
    sys.platform not in ("darwin", "linux", "win32"),
    "Test only runs on Linux, Windows and MacOS",
)
def requires_subinterpreters(meth):
    """Skip *meth* when ``concurrent.interpreters`` could not be imported."""
    skip_without_interpreters = unittest.skipIf(
        interpreters is None, "subinterpreters required"
    )
    return skip_without_interpreters(meth)
# ============================================================================
# Simple wrapper functions for RemoteUnwinder
# ============================================================================
# Error-message fragments for failures that can occur transiently when
# reading process memory without synchronization. An exception whose
# message contains any of them is worth retrying.
RETRIABLE_ERRORS = (
    "Task list appears corrupted",
    "Invalid linked list structure reading remote memory",
    "Unknown error reading memory",
    "Unhandled frame owner",
    "Failed to parse initial frame",
    "Failed to process frame chain",
    "Failed to unwind stack",
)


def _is_retriable_error(exc):
    """Return True if ``str(exc)`` contains any fragment in RETRIABLE_ERRORS.

    The match is a plain substring test. A fragment at the start of the
    message is already covered by it, so no separate prefix check is needed.
    """
    msg = str(exc)
    return any(err in msg for err in RETRIABLE_ERRORS)
# Return RemoteUnwinder(pid, all_threads=True).get_stack_trace(), retrying
# errors listed in RETRIABLE_ERRORS until SHORT_TIMEOUT elapses. Other
# exceptions (e.g. PermissionError) propagate unchanged. On timeout a
# RuntimeError is raised, chained to the last transient error.
#
# NOTE: TestGetStackTrace.test_self_trace expects this function's frame to be
# reported at ``co_firstlineno + 4``, the line of the unwinder call below.
# Keep that call exactly four lines after ``def``: no docstring, and no extra
# statements or comment lines in between.
def get_stack_trace(pid):
    last_error = None
    for _ in busy_retry(SHORT_TIMEOUT, error=False):
        try:
            return RemoteUnwinder(pid, all_threads=True, debug=True).get_stack_trace()
        except RuntimeError as e:
            if not _is_retriable_error(e):
                raise
            last_error = e
    # With error=False, busy_retry simply stops iterating on timeout, so this
    # fallback is reachable (with the default error=True it was dead code).
    raise RuntimeError(
        "Failed to get stack trace after retries"
    ) from last_error
def get_async_stack_trace(pid):
    """Return ``RemoteUnwinder(pid).get_async_stack_trace()`` with retries.

    Errors listed in RETRIABLE_ERRORS are retried until SHORT_TIMEOUT
    elapses. Any other exception (e.g. PermissionError, which callers turn
    into a skip) propagates unchanged.

    Raises:
        RuntimeError: A non-retriable unwinder error, or, on timeout,
            "Failed to get async stack trace after retries" chained to the
            last transient error.
    """
    last_error = None
    # error=False: stop iterating on timeout instead of raising
    # AssertionError, so the fallback below is reachable.
    for _ in busy_retry(SHORT_TIMEOUT, error=False):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_async_stack_trace()
        except RuntimeError as e:
            if not _is_retriable_error(e):
                raise
            last_error = e
    raise RuntimeError(
        "Failed to get async stack trace after retries"
    ) from last_error
def get_all_awaited_by(pid):
    """Return ``RemoteUnwinder(pid).get_all_awaited_by()`` with retries.

    Errors listed in RETRIABLE_ERRORS are retried until SHORT_TIMEOUT
    elapses. Any other exception (e.g. PermissionError, which callers turn
    into a skip) propagates unchanged.

    Raises:
        RuntimeError: A non-retriable unwinder error, or, on timeout,
            "Failed to get all awaited_by after retries" chained to the last
            transient error.
    """
    last_error = None
    # error=False: stop iterating on timeout instead of raising
    # AssertionError, so the fallback below is reachable.
    for _ in busy_retry(SHORT_TIMEOUT, error=False):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_all_awaited_by()
        except RuntimeError as e:
            if not _is_retriable_error(e):
                raise
            last_error = e
    raise RuntimeError(
        "Failed to get all awaited_by after retries"
    ) from last_error
# ============================================================================
# Base test class with shared infrastructure
# ============================================================================
class RemoteInspectionTestBase(unittest.TestCase):
    """Base class for remote inspection tests with common helpers."""

    maxDiff = None

    def _run_script_and_get_trace(
        self,
        script,
        trace_func,
        wait_for_signals=None,
        port=None,
        backlog=1,
    ):
        """
        Common pattern: run a script, wait for signals, get trace.

        The script is expected to connect back to ``localhost:port`` and to
        send *wait_for_signals* once it has reached the state to inspect.

        Args:
            script: Script content. If it contains ``{port}`` (or the escaped
                ``{{port}}``), it is passed through ``str.format(port=port)``.
                That also turns any other doubled braces into single ones.
            trace_func: Function to call with pid to get trace (e.g.,
                get_stack_trace).
            wait_for_signals: Signal(s) to wait for before getting trace;
                falsy to trace right after the script connects.
            port: Port to use (auto-selected if None).
            backlog: Socket listen backlog.

        Returns:
            tuple: (trace_result, script_name)

        Skips the test if *trace_func* raises PermissionError.
        """
        if port is None:
            port = find_unused_port()
        # "{{port}}" contains "{port}", so one membership test covers both.
        if "{port}" in script:
            script = script.replace("{{port}}", "{port}").format(port=port)

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            # Write the script *before* opening the server socket, so a
            # failure while writing cannot leak the socket (it is only
            # protected by the try/finally below).
            script_name = _make_test_script(script_dir, "script", script)
            server_socket = _create_server_socket(port, backlog)
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None
                    if wait_for_signals:
                        _wait_for_signal(client_socket, wait_for_signals)
                    try:
                        trace = trace_func(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )
                    return trace, script_name
            finally:
                _cleanup_sockets(client_socket, server_socket)

    def _iter_threads(self, stack_trace):
        """Yield every thread_info of every interpreter in *stack_trace*."""
        for interpreter_info in stack_trace:
            yield from interpreter_info.threads

    def _find_frame_in_trace(self, stack_trace, predicate):
        """
        Find the first frame matching *predicate* in *stack_trace*.

        Args:
            stack_trace: List of InterpreterInfo objects
            predicate: Function(frame) -> bool

        Returns:
            FrameInfo or None
        """
        for thread_info in self._iter_threads(stack_trace):
            for frame in thread_info.frame_info:
                if predicate(frame):
                    return frame
        return None

    def _find_thread_by_id(self, stack_trace, thread_id):
        """Return the first thread whose ``thread_id`` equals *thread_id*,
        or None."""
        return next(
            (
                thread_info
                for thread_info in self._iter_threads(stack_trace)
                if thread_info.thread_id == thread_id
            ),
            None,
        )

    def _find_thread_with_frame(self, stack_trace, frame_predicate):
        """Return the first thread having a frame that matches
        *frame_predicate*, or None."""
        return next(
            (
                thread_info
                for thread_info in self._iter_threads(stack_trace)
                if any(frame_predicate(frame) for frame in thread_info.frame_info)
            ),
            None,
        )

    def _get_thread_statuses(self, stack_trace):
        """Extract a thread_id -> status mapping from *stack_trace*.

        If a thread id repeats, the last occurrence wins.
        """
        return {
            thread_info.thread_id: thread_info.status
            for thread_info in self._iter_threads(stack_trace)
        }

    def _get_task_id_map(self, stack_trace):
        """Create a task_id -> task mapping from an async stack trace."""
        return {task.task_id: task for task in stack_trace[0].awaited_by}

    def _get_awaited_by_relationships(self, stack_trace):
        """Map each task name to the set of names of the tasks awaiting it.

        Entries of ``task.awaited_by`` carry the awaiting task's *id* in
        their ``task_name`` field, hence the id -> task lookup.
        """
        id_to_task = self._get_task_id_map(stack_trace)
        return {
            task.task_name: set(
                id_to_task[awaited.task_name].task_name
                for awaited in task.awaited_by
            )
            for task in stack_trace[0].awaited_by
        }

    def _extract_coroutine_stacks(self, stack_trace):
        """Map each task name to its sorted coroutine call stacks, with
        frames converted to plain tuples."""
        return {
            task.task_name: sorted(
                tuple(tuple(frame) for frame in coro.call_stack)
                for coro in task.coroutine_stack
            )
            for task in stack_trace[0].awaited_by
        }
# ============================================================================
# Test classes
# ============================================================================
class TestGetStackTrace(RemoteInspectionTestBase):
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_remote_stack_trace(self):
    # Run a child whose worker thread blocks in foo() (reached through
    # bar() -> baz()) while its main thread blocks in t.join(), then check
    # both stacks as reported by get_stack_trace().
    port = find_unused_port()
    # The expected line numbers used below (15, 12, 9 and 19) are lines of
    # this generated script; keep its layout, blank lines included, in sync.
    script = textwrap.dedent(
        f"""\
        import time, sys, socket, threading

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        def bar():
            for x in range(100):
                if x == 50:
                    baz()

        def baz():
            foo()

        def foo():
            sock.sendall(b"ready:thread\\n"); time.sleep(10_000)

        t = threading.Thread(target=bar)
        t.start()
        sock.sendall(b"ready:main\\n"); t.join()
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                client_socket, _ = server_socket.accept()
                # The script opens a single connection; stop listening.
                server_socket.close()
                server_socket = None
                # Each thread sends its signal right before it blocks.
                _wait_for_signal(
                    client_socket, [b"ready:main", b"ready:thread"]
                )
                try:
                    stack_trace = get_stack_trace(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )
                # Worker thread stack, innermost frame first.
                thread_expected_stack_trace = [
                    FrameInfo([script_name, 15, "foo"]),
                    FrameInfo([script_name, 12, "baz"]),
                    FrameInfo([script_name, 9, "bar"]),
                    FrameInfo([threading.__file__, ANY, "Thread.run"]),
                    FrameInfo(
                        [
                            threading.__file__,
                            ANY,
                            "Thread._bootstrap_inner",
                        ]
                    ),
                    FrameInfo(
                        [threading.__file__, ANY, "Thread._bootstrap"]
                    ),
                ]
                # Find expected thread stack
                found_thread = self._find_thread_with_frame(
                    stack_trace,
                    lambda f: f.funcname == "foo" and f.lineno == 15,
                )
                self.assertIsNotNone(
                    found_thread, "Expected thread stack trace not found"
                )
                self.assertEqual(
                    found_thread.frame_info, thread_expected_stack_trace
                )
                # Check main thread (blocked in t.join() on script line 19)
                main_frame = FrameInfo([script_name, 19, "<module>"])
                found_main = self._find_frame_in_trace(
                    stack_trace, lambda f: f == main_frame
                )
                self.assertIsNotNone(
                    found_main, "Main thread stack trace not found"
                )
        finally:
            _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_async_remote_stack_trace(self):
    # A TaskGroup in main() (Task-1) runs c2_root, which blocks synchronously
    # in c5(), plus two tasks (sub_main_1/2) awaiting c2_root. The task graph
    # and coroutine stacks are checked under two loop factories: the default
    # event loop and one using an eager task factory.
    port = find_unused_port()
    # Expected line numbers below (10, 14, 17, 20, 23, 26) are lines of this
    # generated script; keep its layout, blank lines included, in sync.
    # "{{TASK_FACTORY}}" survives the f-string as "{TASK_FACTORY}" and is
    # filled in per sub-test with str.format().
    script = textwrap.dedent(
        f"""\
        import asyncio
        import time
        import sys
        import socket

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        def c5():
            sock.sendall(b"ready"); time.sleep(10_000)

        async def c4():
            await asyncio.sleep(0)
            c5()

        async def c3():
            await c4()

        async def c2():
            await c3()

        async def c1(task):
            await task

        async def main():
            async with asyncio.TaskGroup() as tg:
                task = tg.create_task(c2(), name="c2_root")
                tg.create_task(c1(task), name="sub_main_1")
                tg.create_task(c1(task), name="sub_main_2")

        def new_eager_loop():
            loop = asyncio.new_event_loop()
            eager_task_factory = asyncio.create_eager_task_factory(
                asyncio.Task)
            loop.set_task_factory(eager_task_factory)
            return loop

        asyncio.run(main(), loop_factory={{TASK_FACTORY}})
        """
    )
    for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
        with (
            self.subTest(task_factory_variant=task_factory_variant),
            os_helper.temp_dir() as work_dir,
        ):
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(
                script_dir,
                "script",
                script.format(TASK_FACTORY=task_factory_variant),
            )
            client_socket = None
            try:
                with _managed_subprocess(
                    [sys.executable, script_name]
                ) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None
                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)
                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )
                    # Check all tasks are present
                    tasks_names = [
                        task.task_name
                        for task in stack_trace[0].awaited_by
                    ]
                    for task_name in [
                        "c2_root",
                        "sub_main_1",
                        "sub_main_2",
                    ]:
                        self.assertIn(task_name, tasks_names)
                    # Check awaited_by relationships
                    # (task name -> names of the tasks awaiting it)
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "c2_root": {
                                "Task-1",
                                "sub_main_1",
                                "sub_main_2",
                            },
                            "Task-1": set(),
                            "sub_main_1": {"Task-1"},
                            "sub_main_2": {"Task-1"},
                        },
                    )
                    # Check coroutine stacks (innermost frame first)
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [
                                (
                                    tuple(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup._aexit",
                                        ]
                                    ),
                                    tuple(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup.__aexit__",
                                        ]
                                    ),
                                    tuple([script_name, 26, "main"]),
                                )
                            ],
                            "c2_root": [
                                (
                                    tuple([script_name, 10, "c5"]),
                                    tuple([script_name, 14, "c4"]),
                                    tuple([script_name, 17, "c3"]),
                                    tuple([script_name, 20, "c2"]),
                                )
                            ],
                            "sub_main_1": [
                                (tuple([script_name, 23, "c1"]),)
                            ],
                            "sub_main_2": [
                                (tuple([script_name, 23, "c1"]),)
                            ],
                        },
                    )
                    # Check awaited_by coroutine stacks: for each task,
                    # sorted (awaiting task name, awaiting call stack)
                    # pairs. coro.task_name holds the awaiting task's id.
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame)
                                    for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "c2_root": [
                                (
                                    "Task-1",
                                    (
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup._aexit",
                                            ]
                                        ),
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup.__aexit__",
                                            ]
                                        ),
                                        tuple([script_name, 26, "main"]),
                                    ),
                                ),
                                (
                                    "sub_main_1",
                                    (tuple([script_name, 23, "c1"]),),
                                ),
                                (
                                    "sub_main_2",
                                    (tuple([script_name, 23, "c1"]),),
                                ),
                            ],
                            "sub_main_1": [
                                (
                                    "Task-1",
                                    (
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup._aexit",
                                            ]
                                        ),
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup.__aexit__",
                                            ]
                                        ),
                                        tuple([script_name, 26, "main"]),
                                    ),
                                )
                            ],
                            "sub_main_2": [
                                (
                                    "Task-1",
                                    (
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup._aexit",
                                            ]
                                        ),
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup.__aexit__",
                                            ]
                                        ),
                                        tuple([script_name, 26, "main"]),
                                    ),
                                )
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_asyncgen_remote_stack_trace(self):
    # The only task (Task-1) iterates an async generator that, on its second
    # step, awaits a coroutine blocking synchronously in time.sleep(). The
    # generator frame must appear in the task's coroutine stack.
    port = find_unused_port()
    # Expected line numbers below (10, 16, 19) are lines of this generated
    # script; keep its layout, blank lines included, in sync.
    script = textwrap.dedent(
        f"""\
        import asyncio
        import time
        import sys
        import socket

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        async def gen_nested_call():
            sock.sendall(b"ready"); time.sleep(10_000)

        async def gen():
            for num in range(2):
                yield num
                if num == 1:
                    await gen_nested_call()

        async def main():
            async for el in gen():
                pass

        asyncio.run(main())
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                client_socket, _ = server_socket.accept()
                server_socket.close()
                server_socket = None
                response = _wait_for_signal(client_socket, b"ready")
                self.assertIn(b"ready", response)
                try:
                    stack_trace = get_async_stack_trace(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )
                # For this simple asyncgen test, we only expect one task
                self.assertEqual(len(stack_trace[0].awaited_by), 1)
                task = stack_trace[0].awaited_by[0]
                self.assertEqual(task.task_name, "Task-1")
                # Check the coroutine stack (innermost frame first)
                coroutine_stack = sorted(
                    tuple(tuple(frame) for frame in coro.call_stack)
                    for coro in task.coroutine_stack
                )
                self.assertEqual(
                    coroutine_stack,
                    [
                        (
                            tuple([script_name, 10, "gen_nested_call"]),
                            tuple([script_name, 16, "gen"]),
                            tuple([script_name, 19, "main"]),
                        )
                    ],
                )
                # No awaited_by relationships expected
                self.assertEqual(task.awaited_by, [])
        finally:
            _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_async_gather_remote_stack_trace(self):
    # main() (Task-1) gathers c1() and c2(). The test expects Task-2, whose
    # stack is c1 -> deep blocked in time.sleep(), awaited by Task-1.
    port = find_unused_port()
    # Expected line numbers below (11, 15, 21) are lines of this generated
    # script; keep its layout, blank lines included, in sync.
    script = textwrap.dedent(
        f"""\
        import asyncio
        import time
        import sys
        import socket

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        async def deep():
            await asyncio.sleep(0)
            sock.sendall(b"ready"); time.sleep(10_000)

        async def c1():
            await asyncio.sleep(0)
            await deep()

        async def c2():
            await asyncio.sleep(0)

        async def main():
            await asyncio.gather(c1(), c2())

        asyncio.run(main())
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                client_socket, _ = server_socket.accept()
                server_socket.close()
                server_socket = None
                response = _wait_for_signal(client_socket, b"ready")
                self.assertIn(b"ready", response)
                try:
                    stack_trace = get_async_stack_trace(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )
                # Check all tasks are present
                tasks_names = [
                    task.task_name for task in stack_trace[0].awaited_by
                ]
                for task_name in ["Task-1", "Task-2"]:
                    self.assertIn(task_name, tasks_names)
                # Check awaited_by relationships
                # (task name -> names of the tasks awaiting it)
                relationships = self._get_awaited_by_relationships(
                    stack_trace
                )
                self.assertEqual(
                    relationships,
                    {
                        "Task-1": set(),
                        "Task-2": {"Task-1"},
                    },
                )
                # Check coroutine stacks (innermost frame first)
                coroutine_stacks = self._extract_coroutine_stacks(
                    stack_trace
                )
                self.assertEqual(
                    coroutine_stacks,
                    {
                        "Task-1": [(tuple([script_name, 21, "main"]),)],
                        "Task-2": [
                            (
                                tuple([script_name, 11, "deep"]),
                                tuple([script_name, 15, "c1"]),
                            )
                        ],
                    },
                )
                # Check awaited_by coroutine stacks: for each task, sorted
                # (awaiting task name, awaiting call stack) pairs.
                # coro.task_name holds the awaiting task's id.
                id_to_task = self._get_task_id_map(stack_trace)
                awaited_by_coroutine_stacks = {
                    task.task_name: sorted(
                        (
                            id_to_task[coro.task_name].task_name,
                            tuple(
                                tuple(frame) for frame in coro.call_stack
                            ),
                        )
                        for coro in task.awaited_by
                    )
                    for task in stack_trace[0].awaited_by
                }
                self.assertEqual(
                    awaited_by_coroutine_stacks,
                    {
                        "Task-1": [],
                        "Task-2": [
                            ("Task-1", (tuple([script_name, 21, "main"]),))
                        ],
                    },
                )
        finally:
            _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_async_staggered_race_remote_stack_trace(self):
    # main() (Task-1) runs staggered_race([c1, c2]). The test expects Task-2
    # to run c1 -> deep, which blocks in time.sleep(), with its stack rooted
    # in staggered_race's run_one_coro helper.
    port = find_unused_port()
    # Expected line numbers below (11, 15, 21) are lines of this generated
    # script; keep its layout, blank lines included, in sync.
    script = textwrap.dedent(
        f"""\
        import asyncio.staggered
        import time
        import sys
        import socket

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        async def deep():
            await asyncio.sleep(0)
            sock.sendall(b"ready"); time.sleep(10_000)

        async def c1():
            await asyncio.sleep(0)
            await deep()

        async def c2():
            await asyncio.sleep(10_000)

        async def main():
            await asyncio.staggered.staggered_race(
                [c1, c2],
                delay=None,
            )

        asyncio.run(main())
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                client_socket, _ = server_socket.accept()
                server_socket.close()
                server_socket = None
                response = _wait_for_signal(client_socket, b"ready")
                self.assertIn(b"ready", response)
                try:
                    stack_trace = get_async_stack_trace(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )
                # Check all tasks are present
                tasks_names = [
                    task.task_name for task in stack_trace[0].awaited_by
                ]
                for task_name in ["Task-1", "Task-2"]:
                    self.assertIn(task_name, tasks_names)
                # Check awaited_by relationships
                # (task name -> names of the tasks awaiting it)
                relationships = self._get_awaited_by_relationships(
                    stack_trace
                )
                self.assertEqual(
                    relationships,
                    {
                        "Task-1": set(),
                        "Task-2": {"Task-1"},
                    },
                )
                # Check coroutine stacks (innermost frame first)
                coroutine_stacks = self._extract_coroutine_stacks(
                    stack_trace
                )
                self.assertEqual(
                    coroutine_stacks,
                    {
                        "Task-1": [
                            (
                                tuple(
                                    [
                                        staggered.__file__,
                                        ANY,
                                        "staggered_race",
                                    ]
                                ),
                                tuple([script_name, 21, "main"]),
                            )
                        ],
                        "Task-2": [
                            (
                                tuple([script_name, 11, "deep"]),
                                tuple([script_name, 15, "c1"]),
                                tuple(
                                    [
                                        staggered.__file__,
                                        ANY,
                                        "staggered_race.<locals>.run_one_coro",
                                    ]
                                ),
                            )
                        ],
                    },
                )
                # Check awaited_by coroutine stacks: for each task, sorted
                # (awaiting task name, awaiting call stack) pairs.
                # coro.task_name holds the awaiting task's id.
                id_to_task = self._get_task_id_map(stack_trace)
                awaited_by_coroutine_stacks = {
                    task.task_name: sorted(
                        (
                            id_to_task[coro.task_name].task_name,
                            tuple(
                                tuple(frame) for frame in coro.call_stack
                            ),
                        )
                        for coro in task.awaited_by
                    )
                    for task in stack_trace[0].awaited_by
                }
                self.assertEqual(
                    awaited_by_coroutine_stacks,
                    {
                        "Task-1": [],
                        "Task-2": [
                            (
                                "Task-1",
                                (
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race",
                                        ]
                                    ),
                                    tuple([script_name, 21, "main"]),
                                ),
                            )
                        ],
                    },
                )
        finally:
            _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_async_global_awaited_by(self):
    # The child runs an echo server plus a TaskGroup that keeps spawning
    # echo_client tasks until NUM_TASKS connections were made. Each client,
    # after a successful round trip, sends "ready" and sleeps. The test then
    # checks the task entries returned by get_all_awaited_by().
    # Reduced from 1000 to 100 to avoid file descriptor exhaustion
    # when running tests in parallel (e.g., -j 20)
    NUM_TASKS = 100
    port = find_unused_port()
    # Expected line numbers below (36, 39, 52) are lines of this generated
    # script; keep its layout, blank lines included, in sync.
    script = textwrap.dedent(
        f"""\
        import asyncio
        import os
        import random
        import sys
        import socket
        from string import ascii_lowercase, digits
        from test.support import socket_helper, SHORT_TIMEOUT

        HOST = '127.0.0.1'
        PORT = socket_helper.find_unused_port()
        connections = 0

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        class EchoServerProtocol(asyncio.Protocol):
            def connection_made(self, transport):
                global connections
                connections += 1
                self.transport = transport

            def data_received(self, data):
                self.transport.write(data)
                self.transport.close()

        async def echo_client(message):
            reader, writer = await asyncio.open_connection(HOST, PORT)
            writer.write(message.encode())
            await writer.drain()

            data = await reader.read(100)
            assert message == data.decode()
            writer.close()
            await writer.wait_closed()
            sock.sendall(b"ready")
            await asyncio.sleep(SHORT_TIMEOUT)

        async def echo_client_spam(server):
            async with asyncio.TaskGroup() as tg:
                while connections < {NUM_TASKS}:
                    msg = list(ascii_lowercase + digits)
                    random.shuffle(msg)
                    tg.create_task(echo_client("".join(msg)))
                    await asyncio.sleep(0)
            server.close()
            await server.wait_closed()

        async def main():
            loop = asyncio.get_running_loop()
            server = await loop.create_server(EchoServerProtocol, HOST, PORT)
            async with server:
                async with asyncio.TaskGroup() as tg:
                    tg.create_task(server.serve_forever(), name="server task")
                    tg.create_task(echo_client_spam(server), name="echo client spam")

        asyncio.run(main())
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                client_socket, _ = server_socket.accept()
                server_socket.close()
                server_socket = None
                # Wait for NUM_TASKS "ready" signals
                try:
                    _wait_for_n_signals(client_socket, b"ready", NUM_TASKS)
                except RuntimeError as e:
                    self.fail(str(e))
                try:
                    all_awaited_by = get_all_awaited_by(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )
                # Expected: a list of two elements: 1 thread, 1 interp
                self.assertEqual(len(all_awaited_by), 2)
                # Expected: a tuple with the thread ID and the awaited_by list
                self.assertEqual(len(all_awaited_by[0]), 2)
                # Expected: no tasks in the fallback per-interp task list
                self.assertEqual(all_awaited_by[1], (0, []))
                entries = all_awaited_by[0][1]
                # Expected: at least NUM_TASKS pending tasks
                self.assertGreaterEqual(len(entries), NUM_TASKS)
                # Check the main task structure
                main_stack = [
                    FrameInfo(
                        [taskgroups.__file__, ANY, "TaskGroup._aexit"]
                    ),
                    FrameInfo(
                        [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
                    ),
                    FrameInfo([script_name, 52, "main"]),
                ]
                self.assertIn(
                    TaskInfo(
                        [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
                    ),
                    entries,
                )
                # The server task runs serve_forever() and is awaited by
                # main()'s TaskGroup.
                self.assertIn(
                    TaskInfo(
                        [
                            ANY,
                            "server task",
                            [
                                CoroInfo(
                                    [
                                        [
                                            FrameInfo(
                                                [
                                                    base_events.__file__,
                                                    ANY,
                                                    "Server.serve_forever",
                                                ]
                                            )
                                        ],
                                        ANY,
                                    ]
                                )
                            ],
                            [
                                CoroInfo(
                                    [
                                        [
                                            FrameInfo(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            FrameInfo(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            FrameInfo(
                                                [script_name, ANY, "main"]
                                            ),
                                        ],
                                        ANY,
                                    ]
                                )
                            ],
                        ]
                    ),
                    entries,
                )
                # Task-4 is expected to be an echo_client sleeping on script
                # line 36, awaited by echo_client_spam's TaskGroup (line 39).
                self.assertIn(
                    TaskInfo(
                        [
                            ANY,
                            "Task-4",
                            [
                                CoroInfo(
                                    [
                                        [
                                            FrameInfo(
                                                [
                                                    tasks.__file__,
                                                    ANY,
                                                    "sleep",
                                                ]
                                            ),
                                            FrameInfo(
                                                [
                                                    script_name,
                                                    36,
                                                    "echo_client",
                                                ]
                                            ),
                                        ],
                                        ANY,
                                    ]
                                )
                            ],
                            [
                                CoroInfo(
                                    [
                                        [
                                            FrameInfo(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            FrameInfo(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            FrameInfo(
                                                [
                                                    script_name,
                                                    39,
                                                    "echo_client_spam",
                                                ]
                                            ),
                                        ],
                                        ANY,
                                    ]
                                )
                            ],
                        ]
                    ),
                    entries,
                )
                # Every echo_client task should be awaited by exactly this
                # echo_client_spam TaskGroup stack.
                expected_awaited_by = [
                    CoroInfo(
                        [
                            [
                                FrameInfo(
                                    [
                                        taskgroups.__file__,
                                        ANY,
                                        "TaskGroup._aexit",
                                    ]
                                ),
                                FrameInfo(
                                    [
                                        taskgroups.__file__,
                                        ANY,
                                        "TaskGroup.__aexit__",
                                    ]
                                ),
                                FrameInfo(
                                    [script_name, 39, "echo_client_spam"]
                                ),
                            ],
                            ANY,
                        ]
                    )
                ]
                tasks_with_awaited = [
                    task
                    for task in entries
                    if task.awaited_by == expected_awaited_by
                ]
                self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS)
                # Final task should be from echo client spam (not on Windows)
                if sys.platform != "win32":
                    self.assertEqual(
                        tasks_with_awaited[-1].awaited_by,
                        entries[-1].awaited_by,
                    )
        finally:
            _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_self_trace(self):
    stack_trace = get_stack_trace(os.getpid())
    # Trace the current process and compare the two innermost frames of this
    # thread. The expected line numbers are derived from co_firstlineno. For
    # this decorated method that is the first decorator line, so "+ 6" is the
    # get_stack_trace() call just above. "+ 4" is the line inside
    # get_stack_trace() that calls the unwinder. Adding or removing lines
    # before either call site breaks this test.
    this_thread_stack = None
    for interpreter_info in stack_trace:
        for thread_info in interpreter_info.threads:
            if thread_info.thread_id == threading.get_native_id():
                this_thread_stack = thread_info.frame_info
                break
        if this_thread_stack:
            break
    self.assertIsNotNone(this_thread_stack)
    self.assertEqual(
        this_thread_stack[:2],
        [
            FrameInfo(
                [
                    __file__,
                    get_stack_trace.__code__.co_firstlineno + 4,
                    "get_stack_trace",
                ]
            ),
            FrameInfo(
                [
                    __file__,
                    self.test_self_trace.__code__.co_firstlineno + 6,
                    "TestGetStackTrace.test_self_trace",
                ]
            ),
        ],
    )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
@requires_subinterpreters
def test_subinterpreter_stack_trace(self):
    # The child blocks one thread of the main interpreter in main_worker().
    # A second thread runs code in a fresh subinterpreter that blocks in
    # nested_func(). Each side opens its own connection and sends a ready
    # signal. The main interpreter's stack is required; the
    # subinterpreter's stack is only checked if it is reported.
    port = find_unused_port()
    import pickle
    # Code executed inside the subinterpreter. It is pickled and embedded
    # in the main script as a bytes literal (see run_subinterp below).
    subinterp_code = textwrap.dedent(f"""
        import socket
        import time

        def sub_worker():
            def nested_func():
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.connect(('localhost', {port}))
                sock.sendall(b"ready:sub\\n")
                time.sleep(10_000)
            nested_func()

        sub_worker()
        """).strip()
    pickled_code = pickle.dumps(subinterp_code)
    script = textwrap.dedent(
        f"""
        from concurrent import interpreters
        import time
        import sys
        import socket
        import threading

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        def main_worker():
            sock.sendall(b"ready:main\\n")
            time.sleep(10_000)

        def run_subinterp():
            subinterp = interpreters.create()
            import pickle
            pickled_code = {pickled_code!r}
            subinterp_code = pickle.loads(pickled_code)
            subinterp.exec(subinterp_code)

        sub_thread = threading.Thread(target=run_subinterp)
        sub_thread.start()

        main_thread = threading.Thread(target=main_worker)
        main_thread.start()

        main_thread.join()
        sub_thread.join()
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_sockets = []
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                # Accept connections from both main and subinterpreter.
                # Each connection is read once (a single recv); an accept()
                # timeout ends the loop early instead of failing the test.
                responses = set()
                while len(responses) < 2:
                    try:
                        client_socket, _ = server_socket.accept()
                        client_sockets.append(client_socket)
                        response = client_socket.recv(1024)
                        if b"ready:main" in response:
                            responses.add("main")
                        if b"ready:sub" in response:
                            responses.add("sub")
                    except socket.timeout:
                        break
                server_socket.close()
                server_socket = None
                try:
                    stack_trace = get_stack_trace(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )
                # Verify we have at least one interpreter
                self.assertGreaterEqual(len(stack_trace), 1)
                # Look for main interpreter (ID 0) and subinterpreter (ID > 0)
                main_interp = None
                sub_interp = None
                for interpreter_info in stack_trace:
                    if interpreter_info.interpreter_id == 0:
                        main_interp = interpreter_info
                    elif interpreter_info.interpreter_id > 0:
                        sub_interp = interpreter_info
                self.assertIsNotNone(
                    main_interp, "Main interpreter should be present"
                )
                # Check main interpreter has expected stack trace
                main_found = self._find_frame_in_trace(
                    [main_interp], lambda f: f.funcname == "main_worker"
                )
                self.assertIsNotNone(
                    main_found,
                    "Main interpreter should have main_worker in stack",
                )
                # If subinterpreter is present, check its stack trace
                if sub_interp:
                    sub_found = self._find_frame_in_trace(
                        [sub_interp],
                        lambda f: f.funcname
                        in ("sub_worker", "nested_func"),
                    )
                    self.assertIsNotNone(
                        sub_found,
                        "Subinterpreter should have sub_worker or nested_func in stack",
                    )
        finally:
            _cleanup_sockets(*client_sockets, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
@requires_subinterpreters
def test_multiple_subinterpreters_with_threads(self):
    """Sample a process with two subinterpreters that each run two threads.

    The target runs ``main_worker`` in the main interpreter and starts two
    subinterpreters (labelled ``sub1`` and ``sub2``).  Each subinterpreter
    starts ``worker1`` and ``worker2`` threads.  Every one of these threads
    connects to the test server and sends ``ready:<name>`` before sleeping.
    The unwinder must then report the main interpreter, at least one
    interpreter per subinterpreter, and at least one thread per sender.
    """
    import pickle

    port = find_unused_port()
    subinterp_labels = ("sub1", "sub2")

    def make_subinterp_code(label):
        # Both subinterpreters run the same program; only the label used in
        # their "ready" messages differs.
        return textwrap.dedent(f"""
            import socket
            import time
            import threading

            def worker1():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:{label}-t1\\n")
                    time.sleep(10_000)
                nested_func()

            def worker2():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:{label}-t2\\n")
                    time.sleep(10_000)
                nested_func()

            t1 = threading.Thread(target=worker1)
            t2 = threading.Thread(target=worker2)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """).strip()

    # Pickle the subinterpreter sources so each one can be embedded in the
    # parent script as a single bytes literal without quoting problems.
    pickled_code1 = pickle.dumps(make_subinterp_code(subinterp_labels[0]))
    pickled_code2 = pickle.dumps(make_subinterp_code(subinterp_labels[1]))

    # One "ready" message per sending thread: main_worker, plus worker1 and
    # worker2 in every subinterpreter.
    expected_responses = {"ready:main"} | {
        f"ready:{label}-t{n}" for label in subinterp_labels for n in (1, 2)
    }

    script = textwrap.dedent(
        f"""
        from concurrent import interpreters
        import time
        import sys
        import socket
        import threading

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        def main_worker():
            sock.sendall(b"ready:main\\n")
            time.sleep(10_000)

        def run_subinterp1():
            subinterp = interpreters.create()
            import pickle
            pickled_code = {pickled_code1!r}
            subinterp_code = pickle.loads(pickled_code)
            subinterp.exec(subinterp_code)

        def run_subinterp2():
            subinterp = interpreters.create()
            import pickle
            pickled_code = {pickled_code2!r}
            subinterp_code = pickle.loads(pickled_code)
            subinterp.exec(subinterp_code)

        sub1_thread = threading.Thread(target=run_subinterp1)
        sub2_thread = threading.Thread(target=run_subinterp2)
        sub1_thread.start()
        sub2_thread.start()
        main_thread = threading.Thread(target=main_worker)
        main_thread.start()
        main_thread.join()
        sub1_thread.join()
        sub2_thread.join()
        """
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(
            port, backlog=len(expected_responses)
        )
        script_name = _make_test_script(script_dir, "script", script)
        client_sockets = []
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                # Accept one connection per expected sender.  Stop waiting
                # if the server times out; the checks below then run
                # against whatever has started so far.
                responses = set()
                while len(responses) < len(expected_responses):
                    try:
                        client_socket, _ = server_socket.accept()
                        client_sockets.append(client_socket)
                        response = client_socket.recv(1024)
                        response_str = response.decode().strip()
                        if response_str in expected_responses:
                            responses.add(response_str)
                    except socket.timeout:
                        break
                server_socket.close()
                server_socket = None

                try:
                    stack_trace = get_stack_trace(p.pid)
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )

                # Verify we have multiple interpreters
                self.assertGreaterEqual(len(stack_trace), 2)

                interpreter_ids = {
                    interp.interpreter_id for interp in stack_trace
                }
                self.assertIn(
                    0,
                    interpreter_ids,
                    "Main interpreter should be present",
                )
                # The main interpreter plus one per subinterpreter.
                self.assertGreaterEqual(
                    len(interpreter_ids),
                    1 + len(subinterp_labels),
                    f"Expected main + {len(subinterp_labels)} "
                    f"subinterpreters, got ids {sorted(interpreter_ids)}",
                )

                # Each expected "ready" message comes from a distinct thread.
                total_threads = sum(
                    len(interp.threads) for interp in stack_trace
                )
                self.assertGreaterEqual(
                    total_threads,
                    len(expected_responses),
                    "Fewer threads than ready-signal senders",
                )

                all_funcnames = {
                    frame.funcname
                    for interpreter_info in stack_trace
                    for thread_info in interpreter_info.threads
                    for frame in thread_info.frame_info
                }
                expected_funcs = {
                    "main_worker",
                    "worker1",
                    "worker2",
                    "nested_func",
                }
                found_funcs = expected_funcs.intersection(all_funcnames)
                self.assertGreater(
                    len(found_funcs),
                    0,
                    f"None of {sorted(expected_funcs)} found in any stack",
                )
        finally:
            _cleanup_sockets(*client_sockets, server_socket)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
@requires_gil_enabled("Free threaded builds don't have an 'active thread'")
def test_only_active_thread(self):
    """``only_active_thread=True`` reports exactly one thread.

    The child parks three worker threads in ``time.sleep`` and then runs a
    long pure-Python loop in ``main_work``.  Once the main thread is seen
    past its "working" signal, a sample taken with ``only_active_thread``
    must contain a single thread.  That thread must also appear in the
    all-threads sample.
    """
    port = find_unused_port()
    script = textwrap.dedent(
        f"""\
        import time, sys, socket, threading

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', {port}))

        def worker_thread(name, barrier, ready_event):
            barrier.wait()
            ready_event.wait()
            time.sleep(10_000)

        def main_work():
            sock.sendall(b"working\\n")
            count = 0
            while count < 100000000:
                count += 1
                if count % 10000000 == 0:
                    pass
            sock.sendall(b"done\\n")

        num_threads = 3
        barrier = threading.Barrier(num_threads + 1)
        ready_event = threading.Event()
        threads = []
        for i in range(num_threads):
            t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
            t.start()
            threads.append(t)

        barrier.wait()
        sock.sendall(b"ready\\n")
        ready_event.set()
        main_work()
        """
    )
    # main_work counts as "busy" only once it is past its "working" signal.
    # Derive that line number from the script instead of hard-coding it, so
    # that editing the script layout cannot silently break the check.
    working_lineno = next(
        lineno
        for lineno, line in enumerate(script.splitlines(), start=1)
        if 'sock.sendall(b"working' in line
    )
    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as p:
                client_socket, _ = server_socket.accept()
                server_socket.close()
                server_socket = None

                # Wait for ready and working signals
                _wait_for_signal(client_socket, [b"ready", b"working"])

                try:
                    # Get stack trace with all threads
                    unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
                    for _ in range(MAX_TRIES):
                        all_traces = unwinder_all.get_stack_trace()
                        found = self._find_frame_in_trace(
                            all_traces,
                            lambda f: f.funcname == "main_work"
                            and f.lineno > working_lineno,
                        )
                        if found:
                            break
                        time.sleep(0.1)
                    else:
                        self.fail(
                            "Main thread did not start its busy work on time"
                        )

                    # Get stack trace with only GIL holder
                    unwinder_gil = RemoteUnwinder(
                        p.pid, only_active_thread=True
                    )
                    gil_traces = unwinder_gil.get_stack_trace()
                except PermissionError:
                    self.skipTest(
                        "Insufficient permissions to read the stack trace"
                    )

                total_threads = sum(
                    len(interp.threads) for interp in all_traces
                )
                self.assertGreater(total_threads, 1)

                total_gil_threads = sum(
                    len(interp.threads) for interp in gil_traces
                )
                self.assertEqual(total_gil_threads, 1)

                # The single reported thread is the first thread of the
                # first interpreter that has any.
                gil_thread_id = next(
                    (
                        interp.threads[0].thread_id
                        for interp in gil_traces
                        if interp.threads
                    ),
                    None,
                )
                all_thread_ids = [
                    thread_info.thread_id
                    for interp in all_traces
                    for thread_info in interp.threads
                ]
                self.assertIn(gil_thread_id, all_thread_ids)
        finally:
            _cleanup_sockets(client_socket, server_socket)
class TestUnsupportedPlatformHandling(unittest.TestCase):
    """RemoteUnwinder must fail cleanly on platforms it cannot inspect."""

    @unittest.skipIf(
        sys.platform in ("linux", "darwin", "win32"),
        "Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_unsupported_platform_error(self):
        # Attaching to ourselves is enough: the platform check comes first.
        with self.assertRaises(RuntimeError) as raised:
            RemoteUnwinder(os.getpid())
        message = str(raised.exception)
        self.assertIn(
            "Reading the PyRuntime section is not supported on this platform",
            message,
        )
class TestDetectionOfThreadStatus(RemoteInspectionTestBase):
    """Check the per-thread status flags reported by RemoteUnwinder.

    Every test runs the same child process (see ``_run_thread_status_test``):
    a "sleeper" thread blocked in ``time.sleep`` and a "busy" thread spinning
    in pure Python.  The status bit(s) under test must be clear for the
    sleeper and set for the busy thread.
    """

    @staticmethod
    def _expect_busy_flags(flags):
        """Build a ``check_condition`` predicate for ``_run_thread_status_test``.

        The predicate is true once both threads have a status, the sleeper
        has none of the bits in *flags* set, and the busy thread has all of
        them set.
        """

        def check(statuses, sleeper_tid, busy_tid):
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & flags)
                and (statuses[busy_tid] & flags) == flags
            )

        return check

    @staticmethod
    def _parse_tid(response, prefix):
        """Return the integer following *prefix* on the first line of
        *response* that starts with it and parses, or None."""
        for line in response.split(b"\n"):
            if not line.startswith(prefix):
                continue
            try:
                return int(line[len(prefix):])
            except ValueError:
                continue
        return None

    def _run_thread_status_test(self, mode, check_condition):
        """
        Common pattern for thread status detection tests.

        Args:
            mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.)
            check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool.
                Sampling stops early as soon as it returns true.

        Returns:
            ``(statuses, sleeper_tid, busy_tid)``, where *statuses* comes
            from the last sample taken.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading
            import os

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def sleeper():
                tid = threading.get_native_id()
                sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode())
                time.sleep(10000)

            def busy():
                tid = threading.get_native_id()
                sock.sendall(f'ready:busy:{{tid}}\\n'.encode())
                # Spin forever in pure Python; this thread is expected to
                # be on CPU and holding the GIL whenever it is sampled.
                x = 0
                while True:
                    x = x + 1

            t1 = threading.Thread(target=sleeper)
            t2 = threading.Thread(target=busy)
            t1.start()
            t2.start()
            sock.sendall(b'ready:main\\n')
            t1.join()
            t2.join()
            sock.close()
            """
        )
        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)
            server_socket = _create_server_socket(port)
            script_name = _make_test_script(
                script_dir, "thread_status_script", script
            )
            client_socket = None
            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for all ready signals and parse TIDs
                    response = _wait_for_signal(
                        client_socket,
                        [b"ready:main", b"ready:sleeper", b"ready:busy"],
                    )
                    sleeper_tid = self._parse_tid(response, b"ready:sleeper:")
                    busy_tid = self._parse_tid(response, b"ready:busy:")
                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )

                    # Sample until we see expected thread states
                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=mode,
                            skip_non_matching_threads=False,
                        )
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)
                            if check_condition(
                                statuses, sleeper_tid, busy_tid
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    return statuses, sleeper_tid, busy_tid
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_detection(self):
        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_CPU, self._expect_busy_flags(THREAD_STATUS_ON_CPU)
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
            "Sleeper thread should be off CPU",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_ON_CPU,
            "Busy thread should be on CPU",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_gil_detection(self):
        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_GIL, self._expect_busy_flags(THREAD_STATUS_HAS_GIL)
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
            "Sleeper thread should not have GIL",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
            "Busy thread should have GIL",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_all_mode_detection(self):
        # ALL mode must report both the GIL and the CPU bit at once.
        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_ALL,
            self._expect_busy_flags(
                THREAD_STATUS_ON_CPU | THREAD_STATUS_HAS_GIL
            ),
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)

        # Sleeper: off CPU, no GIL
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
            "Sleeper should be off CPU",
        )
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
            "Sleeper should not have GIL",
        )

        # Busy: on CPU, has GIL
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_ON_CPU,
            "Busy should be on CPU",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
            "Busy should have GIL",
        )
class TestFrameCaching(RemoteInspectionTestBase):
"""Test that frame caching produces correct results.
Uses socket-based synchronization for deterministic testing.
All tests verify cache reuse via object identity checks (assertIs).
"""
@contextmanager
def _target_process(self, script_body):
    """Run *script_body* in a child process that is synchronized with the test.

    The child script begins with a preamble that connects a module-level
    ``sock`` to a local server, followed by the dedented *script_body*,
    which uses ``sock`` to pause at sync points.

    Yields ``(process, client_socket, make_unwinder)``.  Calling
    ``make_unwinder(cache_frames=True)`` builds an all-threads
    ``RemoteUnwinder`` for the child.  A ``PermissionError`` raised while
    the context is active makes the test skip.
    """
    port = find_unused_port()
    preamble = "\n".join(
        (
            "import socket",
            "sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)",
            f"sock.connect(('localhost', {port}))",
        )
    )
    script = f"{preamble}\n{textwrap.dedent(script_body)}\n"

    with os_helper.temp_dir() as work_dir:
        script_dir = os.path.join(work_dir, "script_pkg")
        os.mkdir(script_dir)
        server_socket = _create_server_socket(port)
        script_name = _make_test_script(script_dir, "script", script)
        client_socket = None
        try:
            with _managed_subprocess([sys.executable, script_name]) as proc:
                client_socket, _ = server_socket.accept()
                # The child is connected; the listener is no longer needed.
                server_socket.close()
                server_socket = None

                def make_unwinder(cache_frames=True):
                    return RemoteUnwinder(
                        proc.pid, all_threads=True, cache_frames=cache_frames
                    )

                yield proc, client_socket, make_unwinder
        except PermissionError:
            self.skipTest("Insufficient permissions to read the stack trace")
        finally:
            _cleanup_sockets(client_socket, server_socket)
def _get_frames_with_retry(self, unwinder, required_funcs):
    """Return the frames of the first thread whose stack contains all of
    *required_funcs*.

    Makes up to ``MAX_TRIES`` attempts and sleeps 0.1s after each one that
    does not return.  A ``RuntimeError`` accepted by ``_is_retriable_error``
    counts as a failed attempt; any other ``RuntimeError`` propagates.
    Returns None if no matching thread was seen.
    """
    for _attempt in range(MAX_TRIES):
        try:
            match = next(
                (
                    thread.frame_info
                    for interp in unwinder.get_stack_trace()
                    for thread in interp.threads
                    if required_funcs.issubset(
                        {f.funcname for f in thread.frame_info}
                    )
                ),
                None,
            )
        except RuntimeError as exc:
            if not _is_retriable_error(exc):
                raise
        else:
            if match is not None:
                return match
        time.sleep(0.1)
    return None
def _sample_frames(
    self,
    client_socket,
    unwinder,
    wait_signal,
    send_ack,
    required_funcs,
    expected_frames=1,
):
    """Sample the target at a sync point, then let it continue.

    Blocks until *wait_signal* arrives on *client_socket*.  Then calls
    ``_get_frames_with_retry`` up to ``MAX_TRIES`` times, 0.1s apart,
    until it returns at least *expected_frames* frames.  *send_ack* is sent
    afterwards whether or not sampling succeeded, so the target is always
    released.  Returns the last frames obtained, which may be None.
    """
    _wait_for_signal(client_socket, wait_signal)
    frames = None
    attempts_left = MAX_TRIES
    while attempts_left:
        attempts_left -= 1
        frames = self._get_frames_with_retry(unwinder, required_funcs)
        if frames and len(frames) >= expected_frames:
            break
        time.sleep(0.1)
    client_socket.sendall(send_ack)
    return frames
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_cache_hit_same_stack(self):
    """Repeated samples of an unchanged stack reuse cached parent frames.

    The innermost frame (index 0) is re-read from memory on every sample so
    its line number stays current, so only its function name is compared.
    Every frame above it (index 1+) must be the very same object in all
    samples.
    """
    script_body = """\
        def level3():
            sock.sendall(b"sync1")
            sock.recv(16)
            sock.sendall(b"sync2")
            sock.recv(16)
            sock.sendall(b"sync3")
            sock.recv(16)

        def level2():
            level3()

        def level1():
            level2()

        level1()
        """
    sync_points = ((b"sync1", b"ack"), (b"sync2", b"ack"), (b"sync3", b"done"))
    wanted = {"level1", "level2", "level3"}

    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        unwinder = make_unwinder(cache_frames=True)
        samples = [
            self._sample_frames(conn, unwinder, signal, ack, wanted)
            for signal, ack in sync_points
        ]

        for frames in samples:
            self.assertIsNotNone(frames)

        consecutive = list(zip(samples, samples[1:]))
        for earlier, later in consecutive:
            self.assertEqual(len(earlier), len(later))

        # Index 0 is re-read every time: compare by value only.
        for earlier, later in consecutive:
            self.assertEqual(earlier[0].funcname, later[0].funcname)

        # Index 1+ must come straight from the cache.
        for i in range(1, len(samples[0])):
            for pair_no, (earlier, later) in enumerate(consecutive, start=1):
                self.assertIs(
                    earlier[i],
                    later[i],
                    f"Frame {i}: samples {pair_no}-{pair_no + 1} must be same object",
                )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_line_number_updates_in_same_frame(self):
    """Line numbers track execution within one function across samples.

    The target pauses four times inside ``inner``, each time on a later
    line.  Each sample must report the current line, not a stale value
    kept from the cache.
    """
    script_body = """\
        def outer():
            inner()

        def inner():
            sock.sendall(b"line_a"); sock.recv(16)
            sock.sendall(b"line_b"); sock.recv(16)
            sock.sendall(b"line_c"); sock.recv(16)
            sock.sendall(b"line_d"); sock.recv(16)

        outer()
        """
    checkpoints = (
        (b"line_a", b"ack"),
        (b"line_b", b"ack"),
        (b"line_c", b"ack"),
        (b"line_d", b"done"),
    )

    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        unwinder = make_unwinder(cache_frames=True)
        samples = [
            self._sample_frames(conn, unwinder, signal, ack, {"inner"})
            for signal, ack in checkpoints
        ]

        for frames in samples:
            self.assertIsNotNone(frames)

        # 'inner' is the innermost frame, so it is at index 0.
        innermost = [frames[0] for frames in samples]
        for frame in innermost:
            self.assertEqual(frame.funcname, "inner")

        # Execution only moves forward, so line numbers strictly increase.
        labels = "ABCD"
        for k in range(1, len(innermost)):
            self.assertLess(
                innermost[k - 1].lineno,
                innermost[k].lineno,
                f"Line {labels[k]} should be after line {labels[k - 1]}",
            )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_cache_invalidation_on_return(self):
    """A returned function must vanish from later samples (stack shrinks)."""
    script_body = """\
        def inner():
            sock.sendall(b"at_inner")
            sock.recv(16)

        def outer():
            inner()
            sock.sendall(b"at_outer")
            sock.recv(16)

        outer()
        """
    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        unwinder = make_unwinder(cache_frames=True)
        deep = self._sample_frames(
            conn, unwinder, b"at_inner", b"ack", {"inner", "outer"}
        )
        shallow = self._sample_frames(
            conn, unwinder, b"at_outer", b"done", {"outer"}
        )

        self.assertIsNotNone(deep)
        self.assertIsNotNone(shallow)

        deep_names = [frame.funcname for frame in deep]
        shallow_names = [frame.funcname for frame in shallow]

        # While paused inside inner(), both functions are on the stack.
        self.assertIn("inner", deep_names)
        self.assertIn("outer", deep_names)
        # After inner() returns, the cache must not bring it back.
        self.assertNotIn("inner", shallow_names)
        self.assertIn("outer", shallow_names)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_cache_invalidation_on_call(self):
    """A newly called function must appear in later samples (stack grows)."""
    script_body = """\
        def deeper():
            sock.sendall(b"at_deeper")
            sock.recv(16)

        def middle():
            sock.sendall(b"at_middle")
            sock.recv(16)
            deeper()

        def top():
            middle()

        top()
        """
    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        unwinder = make_unwinder(cache_frames=True)
        before = self._sample_frames(
            conn, unwinder, b"at_middle", b"ack", {"middle", "top"}
        )
        after = self._sample_frames(
            conn,
            unwinder,
            b"at_deeper",
            b"done",
            {"deeper", "middle", "top"},
        )

        self.assertIsNotNone(before)
        self.assertIsNotNone(after)

        names_before = [frame.funcname for frame in before]
        names_after = [frame.funcname for frame in after]

        # Before the call: top -> middle only.
        self.assertIn("middle", names_before)
        self.assertIn("top", names_before)
        self.assertNotIn("deeper", names_before)

        # After the call: deeper has been pushed on top.
        for name in ("deeper", "middle", "top"):
            self.assertIn(name, names_after)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_partial_stack_reuse(self):
    """Test that unchanged bottom frames are reused when top changes (A→B→C to A→B→D)."""
    script_body = """\
        def func_c():
            sock.sendall(b"at_c")
            sock.recv(16)

        def func_d():
            sock.sendall(b"at_d")
            sock.recv(16)

        def func_b():
            func_c()
            func_d()

        def func_a():
            func_b()

        func_a()
        """
    shared = ("func_a", "func_b")

    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        unwinder = make_unwinder(cache_frames=True)
        # Stack A -> B -> C
        at_c = self._sample_frames(
            conn, unwinder, b"at_c", b"ack", {"func_a", "func_b", "func_c"}
        )
        # Stack A -> B -> D (C has returned, D has been called)
        at_d = self._sample_frames(
            conn, unwinder, b"at_d", b"done", {"func_a", "func_b", "func_d"}
        )

        self.assertIsNotNone(at_c)
        self.assertIsNotNone(at_d)

        def first_named(frames, funcname):
            return next((f for f in frames if f.funcname == funcname), None)

        in_c = {name: first_named(at_c, name) for name in shared}
        in_d = {name: first_named(at_d, name) for name in shared}
        for found in (*in_c.values(), *in_d.values()):
            self.assertIsNotNone(found)

        # The untouched bottom frames must be the very same cached objects.
        for name in shared:
            self.assertIs(
                in_c[name],
                in_d[name],
                f"{name} frame should be reused from cache",
            )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_recursive_frames(self):
    """Caching works when one function appears many times (recursion)."""
    script_body = """\
        def recurse(n):
            if n <= 0:
                sock.sendall(b"sync1")
                sock.recv(16)
                sock.sendall(b"sync2")
                sock.recv(16)
            else:
                recurse(n - 1)

        recurse(5)
        """
    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        unwinder = make_unwinder(cache_frames=True)
        first = self._sample_frames(
            conn, unwinder, b"sync1", b"ack", {"recurse"}
        )
        second = self._sample_frames(
            conn, unwinder, b"sync2", b"done", {"recurse"}
        )

        self.assertIsNotNone(first)
        self.assertIsNotNone(second)

        # recurse(5) down to recurse(0): six frames of the same function.
        depth = len([f for f in first if f.funcname == "recurse"])
        self.assertEqual(depth, 6, "Should have 6 recursive frames")

        self.assertEqual(len(first), len(second))

        # The innermost frame is re-read each time, so compare by value.
        self.assertEqual(first[0].funcname, second[0].funcname)

        # Every outer frame must be the identical cached object.
        for i, (old, new) in enumerate(zip(first[1:], second[1:]), start=1):
            self.assertIs(
                old,
                new,
                f"Frame {i}: recursive frames must be same object",
            )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_cache_vs_no_cache_equivalence(self):
    """Test that cache_frames=True and cache_frames=False produce equivalent results."""
    script_body = """\
        def level3():
            sock.sendall(b"ready"); sock.recv(16)

        def level2():
            level3()

        def level1():
            level2()

        level1()
        """
    wanted = {"level1", "level2", "level3"}

    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        _wait_for_signal(conn, b"ready")

        # Sample once with the cache enabled, then once with it disabled.
        # Keep both unwinders alive until the comparison is done.
        samples = []
        for use_cache in (True, False):
            unwinder = make_unwinder(cache_frames=use_cache)
            samples.append(
                (unwinder, self._get_frames_with_retry(unwinder, wanted))
            )

        conn.sendall(b"done")

        cached = samples[0][1]
        uncached = samples[1][1]
        self.assertIsNotNone(cached)
        self.assertIsNotNone(uncached)

        self.assertEqual(len(cached), len(uncached))

        # Same function names in the same order, then same line numbers.
        for attr in ("funcname", "lineno"):
            self.assertEqual(
                [getattr(frame, attr) for frame in cached],
                [getattr(frame, attr) for frame in uncached],
            )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_cache_per_thread_isolation(self):
    """Test that frame cache is per-thread and cache invalidation works independently.

    Each of two threads runs foo -> bar -> baz, returns, then calls blech.
    At each leaf it pauses behind a shared lock, so only one thread is
    paused at a time.  Each pause is announced as ``t<N>:<leaf>``; the test
    samples that thread and then releases it with ``b"k"``.  Snapshots must
    contain only the paused thread's own functions, and the blech
    snapshots must no longer show the bar/baz frames that have returned.
    """
    script_body = """\
        import threading

        lock = threading.Lock()

        def sync(msg):
            with lock:
                sock.sendall(msg + b"\\n")
                sock.recv(1)

        # Thread 1 functions
        def baz1():
            sync(b"t1:baz1")

        def bar1():
            baz1()

        def blech1():
            sync(b"t1:blech1")

        def foo1():
            bar1()  # Goes down to baz1, syncs
            blech1()  # Returns up, goes down to blech1, syncs

        # Thread 2 functions
        def baz2():
            sync(b"t2:baz2")

        def bar2():
            baz2()

        def blech2():
            sync(b"t2:blech2")

        def foo2():
            bar2()  # Goes down to baz2, syncs
            blech2()  # Returns up, goes down to blech2, syncs

        t1 = threading.Thread(target=foo1)
        t2 = threading.Thread(target=foo2)
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        """
    # Sync message -> (functions identifying the paused thread, error text).
    sync_points = {
        b"t1:baz1": (("baz1", "bar1", "foo1"), "Thread 1 not found at baz1"),
        b"t2:baz2": (("baz2", "bar2", "foo2"), "Thread 2 not found at baz2"),
        b"t1:blech1": (("blech1", "foo1"), "Thread 1 not found at blech1"),
        b"t2:blech2": (("blech2", "foo2"), "Thread 2 not found at blech2"),
    }

    with self._target_process(script_body) as (
        p,
        client_socket,
        make_unwinder,
    ):
        unwinder = make_unwinder(cache_frames=True)
        buffer = b""

        def recv_msg():
            """Receive one newline-terminated message, or None on EOF."""
            nonlocal buffer
            while b"\n" not in buffer:
                chunk = client_socket.recv(256)
                if not chunk:
                    return None
                buffer += chunk
            msg, buffer = buffer.split(b"\n", 1)
            return msg

        def get_thread_frames(target_funcs, attempts=5, delay=0.1):
            """Return the function names of the first thread whose stack
            contains any of *target_funcs*, or None after *attempts* tries."""
            for attempt in range(attempts):
                if attempt:
                    # Give the target a moment before re-reading its memory;
                    # retrying immediately would likely hit the same
                    # transient state.
                    time.sleep(delay)
                # On Windows, ReadProcessMemory can fail with OSError
                # (WinError 299) when frame pointers are in flux
                with contextlib.suppress(RuntimeError, OSError):
                    traces = unwinder.get_stack_trace()
                    for interp in traces:
                        for thread in interp.threads:
                            funcs = [f.funcname for f in thread.frame_info]
                            if any(f in funcs for f in target_funcs):
                                return funcs
            return None

        # The lock serializes the threads, so exactly one thread is paused
        # per message.  The order in which the threads arrive is not fixed.
        results = {}
        for _ in range(len(sync_points)):
            msg = recv_msg()
            self.assertIsNotNone(msg, "Expected message from subprocess")
            self.assertIn(
                msg, sync_points, f"Unexpected sync message {msg!r}"
            )
            target_funcs, not_found = sync_points[msg]
            funcs = get_thread_frames(target_funcs)
            self.assertIsNotNone(funcs, not_found)
            results[msg.decode()] = funcs
            # Release thread to continue
            client_socket.sendall(b"k")

        # Validate Phase 1: baz snapshots
        t1_baz = results.get("t1:baz1")
        t2_baz = results.get("t2:baz2")
        self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot")
        self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot")

        # Thread 1 at baz1: should have foo1->bar1->baz1
        self.assertIn("baz1", t1_baz)
        self.assertIn("bar1", t1_baz)
        self.assertIn("foo1", t1_baz)
        self.assertNotIn("blech1", t1_baz)
        # No cross-contamination
        self.assertNotIn("baz2", t1_baz)
        self.assertNotIn("bar2", t1_baz)
        self.assertNotIn("foo2", t1_baz)

        # Thread 2 at baz2: should have foo2->bar2->baz2
        self.assertIn("baz2", t2_baz)
        self.assertIn("bar2", t2_baz)
        self.assertIn("foo2", t2_baz)
        self.assertNotIn("blech2", t2_baz)
        # No cross-contamination
        self.assertNotIn("baz1", t2_baz)
        self.assertNotIn("bar1", t2_baz)
        self.assertNotIn("foo1", t2_baz)

        # Validate Phase 2: blech snapshots (cache invalidation test)
        t1_blech = results.get("t1:blech1")
        t2_blech = results.get("t2:blech2")
        self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot")
        self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot")

        # Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated)
        self.assertIn("blech1", t1_blech)
        self.assertIn("foo1", t1_blech)
        self.assertNotIn(
            "bar1", t1_blech, "Cache not invalidated: bar1 still present"
        )
        self.assertNotIn(
            "baz1", t1_blech, "Cache not invalidated: baz1 still present"
        )
        # No cross-contamination
        self.assertNotIn("blech2", t1_blech)

        # Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated)
        self.assertIn("blech2", t2_blech)
        self.assertIn("foo2", t2_blech)
        self.assertNotIn(
            "bar2", t2_blech, "Cache not invalidated: bar2 still present"
        )
        self.assertNotIn(
            "baz2", t2_blech, "Cache not invalidated: baz2 still present"
        )
        # No cross-contamination
        self.assertNotIn("blech1", t2_blech)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_new_unwinder_with_stale_last_profiled_frame(self):
    """Test that a new unwinder returns complete stack when cache lookup misses."""
    script_body = """\
        def level4():
            sock.sendall(b"sync1")
            sock.recv(16)
            sock.sendall(b"sync2")
            sock.recv(16)

        def level3():
            level4()

        def level2():
            level3()

        def level1():
            level2()

        level1()
        """
    levels = ["level1", "level2", "level3", "level4"]
    wanted = set(levels)

    with self._target_process(script_body) as (_proc, conn, make_unwinder):
        # Sampling with the first unwinder sets last_profiled_frame in the
        # target.
        first_unwinder = make_unwinder(cache_frames=True)
        first = self._sample_frames(
            conn, first_unwinder, b"sync1", b"ack", wanted
        )

        # A brand-new unwinder starts with an empty cache, while the target
        # still carries the last_profiled_frame left by the first one.
        second_unwinder = make_unwinder(cache_frames=True)
        second = self._sample_frames(
            conn, second_unwinder, b"sync2", b"done", wanted
        )

        self.assertIsNotNone(first)
        self.assertIsNotNone(second)

        names_first = [frame.funcname for frame in first]
        names_second = [frame.funcname for frame in second]

        for level in levels:
            self.assertIn(level, names_first, f"{level} missing from first sample")
            self.assertIn(level, names_second, f"{level} missing from second sample")

        self.assertEqual(
            len(first),
            len(second),
            "New unwinder should return complete stack despite stale last_profiled_frame",
        )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_cache_exhaustion(self):
    """Test cache works when frame limit (1024) is exceeded.

    FRAME_CACHE_MAX_FRAMES=1024. With 1100 recursive frames,
    the cache can't store all of them but should still work.
    """
    # Keep in sync with FRAME_CACHE_MAX_FRAMES in the unwinder.
    cache_limit = 1024
    # Use 1100 to exceed FRAME_CACHE_MAX_FRAMES=1024.
    depth = 1100
    # recurse(depth) .. recurse(0) is depth + 1 frames, plus the
    # module-level frame that makes the first call. Derive the count from
    # ``depth`` so changing the depth cannot silently desync it.
    expected_frames = depth + 2
    script_body = f"""\
import sys
sys.setrecursionlimit(2000)
def recurse(n):
    if n <= 0:
        sock.sendall(b"ready")
        sock.recv(16)  # wait for ack
        sock.sendall(b"ready2")
        sock.recv(16)  # wait for done
        return
    recurse(n - 1)
recurse({depth})
"""
    with self._target_process(script_body) as (
        p,
        client_socket,
        make_unwinder,
    ):
        unwinder_cache = make_unwinder(cache_frames=True)
        unwinder_no_cache = make_unwinder(cache_frames=False)
        frames_cached = self._sample_frames(
            client_socket,
            unwinder_cache,
            b"ready",
            b"ack",
            {"recurse"},
            expected_frames=expected_frames,
        )
        # Sample again with no cache for comparison.
        frames_no_cache = self._sample_frames(
            client_socket,
            unwinder_no_cache,
            b"ready2",
            b"done",
            {"recurse"},
            expected_frames=expected_frames,
        )
        self.assertIsNotNone(frames_cached)
        self.assertIsNotNone(frames_no_cache)

        # Both stacks must be deeper than the cache can hold; a weaker
        # threshold would let the test pass without exercising exhaustion.
        cached_count = sum(f.funcname == "recurse" for f in frames_cached)
        no_cache_count = sum(f.funcname == "recurse" for f in frames_no_cache)
        self.assertGreater(
            cached_count,
            cache_limit,
            f"Should have >{cache_limit} recurse frames",
        )
        self.assertGreater(
            no_cache_count,
            cache_limit,
            f"Should have >{cache_limit} recurse frames",
        )

        # Both modes should produce the same frame count.
        self.assertEqual(
            len(frames_cached),
            len(frames_no_cache),
            "Cache exhaustion should not affect stack completeness",
        )
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_get_stats(self):
    """An unwinder created with stats=True exposes counters via get_stats().

    One stack sample is taken, after which the stats mapping must carry
    every documented counter and report exactly one sample.
    """
    script_body = """\
sock.sendall(b"ready")
sock.recv(16)
"""
    required_keys = (
        "total_samples",
        "frame_cache_hits",
        "frame_cache_misses",
        "frame_cache_partial_hits",
        "frames_read_from_cache",
        "frames_read_from_memory",
        "frame_cache_hit_rate",
    )
    with self._target_process(script_body) as (proc, sock, _):
        stats_unwinder = RemoteUnwinder(proc.pid, all_threads=True, stats=True)
        _wait_for_signal(sock, b"ready")

        # Exactly one sample, then snapshot the counters before releasing
        # the target.
        stats_unwinder.get_stack_trace()
        collected = stats_unwinder.get_stats()
        sock.sendall(b"done")

        absent = [name for name in required_keys if name not in collected]
        self.assertEqual(absent, [], f"stats missing keys: {absent}")
        self.assertEqual(collected["total_samples"], 1)
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
def test_get_stats_disabled_raises(self):
    """get_stats() must raise RuntimeError when statistics were not enabled."""
    script_body = """\
sock.sendall(b"ready")
sock.recv(16)
"""
    with self._target_process(script_body) as (proc, sock, _):
        # No ``stats`` argument is passed, so the constructor default
        # applies (stats=False by default).
        plain_unwinder = RemoteUnwinder(proc.pid, all_threads=True)
        _wait_for_signal(sock, b"ready")
        self.assertRaises(RuntimeError, plain_unwinder.get_stats)
        sock.sendall(b"done")
# Run this module's test cases via unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()