gh-138122: Don't sample partial frame chains (#141912)

This commit is contained in:
Pablo Galindo Salgado 2025-12-07 15:53:48 +00:00 committed by GitHub
parent c5b37228af
commit d6d850df89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 135 additions and 105 deletions

View file

@ -1,7 +1,7 @@
import contextlib
import unittest
import os
import textwrap
import contextlib
import importlib
import sys
import socket
@ -216,33 +216,13 @@ def requires_subinterpreters(meth):
# Simple wrapper functions for RemoteUnwinder
# ============================================================================
# Errors that can occur transiently when reading process memory without synchronization
RETRIABLE_ERRORS = (
"Task list appears corrupted",
"Invalid linked list structure reading remote memory",
"Unknown error reading memory",
"Unhandled frame owner",
"Failed to parse initial frame",
"Failed to process frame chain",
"Failed to unwind stack",
)
def _is_retriable_error(exc):
"""Check if an exception is a transient error that should be retried."""
msg = str(exc)
return any(msg.startswith(err) or err in msg for err in RETRIABLE_ERRORS)
def get_stack_trace(pid):
for _ in busy_retry(SHORT_TIMEOUT):
try:
unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
return unwinder.get_stack_trace()
except RuntimeError as e:
if _is_retriable_error(e):
continue
raise
continue
raise RuntimeError("Failed to get stack trace after retries")
@ -252,9 +232,7 @@ def get_async_stack_trace(pid):
unwinder = RemoteUnwinder(pid, debug=True)
return unwinder.get_async_stack_trace()
except RuntimeError as e:
if _is_retriable_error(e):
continue
raise
continue
raise RuntimeError("Failed to get async stack trace after retries")
@ -264,9 +242,7 @@ def get_all_awaited_by(pid):
unwinder = RemoteUnwinder(pid, debug=True)
return unwinder.get_all_awaited_by()
except RuntimeError as e:
if _is_retriable_error(e):
continue
raise
continue
raise RuntimeError("Failed to get all awaited_by after retries")
@ -2268,18 +2244,13 @@ def make_unwinder(cache_frames=True):
def _get_frames_with_retry(self, unwinder, required_funcs):
"""Get frames containing required_funcs, with retry for transient errors."""
for _ in range(MAX_TRIES):
try:
with contextlib.suppress(OSError, RuntimeError):
traces = unwinder.get_stack_trace()
for interp in traces:
for thread in interp.threads:
funcs = {f.funcname for f in thread.frame_info}
if required_funcs.issubset(funcs):
return thread.frame_info
except RuntimeError as e:
if _is_retriable_error(e):
pass
else:
raise
time.sleep(0.1)
return None
@ -2802,70 +2773,39 @@ def foo2():
make_unwinder,
):
unwinder = make_unwinder(cache_frames=True)
buffer = b""
def recv_msg():
"""Receive a single message from socket."""
nonlocal buffer
while b"\n" not in buffer:
chunk = client_socket.recv(256)
if not chunk:
return None
buffer += chunk
msg, buffer = buffer.split(b"\n", 1)
return msg
def get_thread_frames(target_funcs):
"""Get frames for thread matching target functions."""
retries = 0
for _ in busy_retry(SHORT_TIMEOUT):
if retries >= 5:
break
retries += 1
# On Windows, ReadProcessMemory can fail with OSError
# (WinError 299) when frame pointers are in flux
with contextlib.suppress(RuntimeError, OSError):
traces = unwinder.get_stack_trace()
for interp in traces:
for thread in interp.threads:
funcs = [f.funcname for f in thread.frame_info]
if any(f in funcs for f in target_funcs):
return funcs
return None
# Message dispatch table: signal -> required functions for that thread
dispatch = {
b"t1:baz1": {"baz1", "bar1", "foo1"},
b"t2:baz2": {"baz2", "bar2", "foo2"},
b"t1:blech1": {"blech1", "foo1"},
b"t2:blech2": {"blech2", "foo2"},
}
# Track results for each sync point
results = {}
# Process 4 sync points: baz1, baz2, blech1, blech2
# With the lock, threads are serialized - handle one at a time
for _ in range(4):
msg = recv_msg()
self.assertIsNotNone(msg, "Expected message from subprocess")
# Process 4 sync points (order depends on thread scheduling)
buffer = _wait_for_signal(client_socket, b"\n")
for i in range(4):
# Extract first message from buffer
msg, sep, buffer = buffer.partition(b"\n")
self.assertIn(msg, dispatch, f"Unexpected message: {msg!r}")
# Determine which thread/function and take snapshot
if msg == b"t1:baz1":
funcs = get_thread_frames(["baz1", "bar1", "foo1"])
self.assertIsNotNone(funcs, "Thread 1 not found at baz1")
results["t1:baz1"] = funcs
elif msg == b"t2:baz2":
funcs = get_thread_frames(["baz2", "bar2", "foo2"])
self.assertIsNotNone(funcs, "Thread 2 not found at baz2")
results["t2:baz2"] = funcs
elif msg == b"t1:blech1":
funcs = get_thread_frames(["blech1", "foo1"])
self.assertIsNotNone(funcs, "Thread 1 not found at blech1")
results["t1:blech1"] = funcs
elif msg == b"t2:blech2":
funcs = get_thread_frames(["blech2", "foo2"])
self.assertIsNotNone(funcs, "Thread 2 not found at blech2")
results["t2:blech2"] = funcs
# Sample frames for the thread at this sync point
required_funcs = dispatch[msg]
frames = self._get_frames_with_retry(unwinder, required_funcs)
self.assertIsNotNone(frames, f"Thread not found for {msg!r}")
results[msg] = [f.funcname for f in frames]
# Release thread to continue
# Release thread and wait for next message (if not last)
client_socket.sendall(b"k")
if i < 3:
buffer += _wait_for_signal(client_socket, b"\n")
# Validate Phase 1: baz snapshots
t1_baz = results.get("t1:baz1")
t2_baz = results.get("t2:baz2")
t1_baz = results.get(b"t1:baz1")
t2_baz = results.get(b"t2:baz2")
self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot")
self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot")
@ -2890,8 +2830,8 @@ def get_thread_frames(target_funcs):
self.assertNotIn("foo1", t2_baz)
# Validate Phase 2: blech snapshots (cache invalidation test)
t1_blech = results.get("t1:blech1")
t2_blech = results.get("t2:blech2")
t1_blech = results.get(b"t1:blech1")
t2_blech = results.get(b"t2:blech2")
self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot")
self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot")