gh-138122: Make sampling profiler integration tests more resilient (#142382)

The tests were flaky on slow machines because subprocesses could finish
before enough samples were collected. This adds synchronization similar
to test_external_inspection: test scripts now signal when they start
working, and the profiler waits for this signal before sampling.

Test scripts now run in infinite loops until killed rather than for
fixed iterations, ensuring the profiler always has active work to
sample regardless of machine speed.
Author: Pablo Galindo Salgado, 2025-12-07 22:41:15 +00:00 (committed via GitHub)
parent ff2577f56e
commit ef51a7c8f3
4 changed files with 185 additions and 126 deletions
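
To illustrate the handshake described in the commit message, a test script now has roughly this shape (a sketch of the pattern, not code taken from this commit; _test_sock is the socket the test harness injects into each script, and busy_work is a placeholder workload):

    # The harness connects a socket and exposes it to the script as _test_sock.
    def busy_work():                  # placeholder CPU-bound workload
        n = 0
        while True:                   # run until the harness kills the process
            n += 1

    _test_sock.sendall(b"working")    # tell the profiler it is safe to start sampling
    busy_work()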


@@ -38,12 +38,88 @@
 SubprocessInfo = namedtuple("SubprocessInfo", ["process", "socket"])
 
 
+def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
+    """
+    Wait for expected signal(s) from a socket with proper timeout and EOF handling.
+
+    Args:
+        sock: Connected socket to read from
+        expected_signals: Single bytes object or list of bytes objects to wait for
+        timeout: Socket timeout in seconds
+
+    Returns:
+        bytes: Complete accumulated response buffer
+
+    Raises:
+        RuntimeError: If connection closed before signal received or timeout
+    """
+    if isinstance(expected_signals, bytes):
+        expected_signals = [expected_signals]
+
+    sock.settimeout(timeout)
+    buffer = b""
+
+    while True:
+        # Check if all expected signals are in buffer
+        if all(sig in buffer for sig in expected_signals):
+            return buffer
+
+        try:
+            chunk = sock.recv(4096)
+            if not chunk:
+                raise RuntimeError(
+                    f"Connection closed before receiving expected signals. "
+                    f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
+                )
+            buffer += chunk
+        except socket.timeout:
+            raise RuntimeError(
+                f"Timeout waiting for signals. "
+                f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
+            ) from None
+        except OSError as e:
+            raise RuntimeError(
+                f"Socket error while waiting for signals: {e}. "
+                f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
+            ) from None
+
+
+def _cleanup_sockets(*sockets):
+    """Safely close multiple sockets, ignoring errors."""
+    for sock in sockets:
+        if sock is not None:
+            try:
+                sock.close()
+            except OSError:
+                pass
+
+
+def _cleanup_process(proc, timeout=SHORT_TIMEOUT):
+    """Terminate a process gracefully, escalating to kill if needed."""
+    if proc.poll() is not None:
+        return
+
+    proc.terminate()
+    try:
+        proc.wait(timeout=timeout)
+        return
+    except subprocess.TimeoutExpired:
+        pass
+
+    proc.kill()
+    try:
+        proc.wait(timeout=timeout)
+    except subprocess.TimeoutExpired:
+        pass  # Process refuses to die, nothing more we can do
+
+
 @contextlib.contextmanager
-def test_subprocess(script):
+def test_subprocess(script, wait_for_working=False):
     """Context manager to create a test subprocess with socket synchronization.
 
     Args:
-        script: Python code to execute in the subprocess
+        script: Python code to execute in the subprocess. If wait_for_working
+            is True, script should send b"working" after starting work.
+        wait_for_working: If True, wait for both "ready" and "working" signals.
+            Default False for backward compatibility.
 
     Yields:
         SubprocessInfo: Named tuple with process and socket objects
@@ -80,19 +156,18 @@ def test_subprocess(script):
 
         # Wait for process to connect and send ready signal
         client_socket, _ = server_socket.accept()
         server_socket.close()
-        response = client_socket.recv(1024)
-        if response != b"ready":
-            raise RuntimeError(
-                f"Unexpected response from subprocess: {response!r}"
-            )
+        server_socket = None
+
+        # Wait for ready signal, and optionally working signal
+        if wait_for_working:
+            _wait_for_signal(client_socket, [b"ready", b"working"])
+        else:
+            _wait_for_signal(client_socket, b"ready")
 
         yield SubprocessInfo(proc, client_socket)
     finally:
-        if client_socket is not None:
-            client_socket.close()
-        if proc.poll() is None:
-            proc.kill()
-            proc.wait()
+        _cleanup_sockets(client_socket, server_socket)
+        _cleanup_process(proc)
 
 
 def close_and_unlink(file):
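
On the harness side, a test built on the updated helper would look roughly like this (a sketch assuming a script like the one shown above; PstatsCollector, the sample() entry point, and PROFILING_DURATION_SEC appear in the test file changed further down):

    with test_subprocess(script, wait_for_working=True) as subproc:
        # test_subprocess() only yields once b"ready" and b"working" have arrived,
        # so the subprocess is guaranteed to be mid-workload when sampling begins.
        collector = PstatsCollector(sample_interval_usec=1000, skip_idle=False)
        profiling.sampling.sample.sample(
            subproc.process.pid,
            collector,
            duration_sec=PROFILING_DURATION_SEC,
        )
    # On exit, the context manager closes the sockets and terminates the
    # still-running subprocess via _cleanup_sockets() and _cleanup_process().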


@@ -39,32 +39,26 @@ def setUpClass(cls):
 import gc
 
 class ExpensiveGarbage:
-    """Class that triggers GC with expensive finalizer (callback)."""
     def __init__(self):
         self.cycle = self
 
     def __del__(self):
-        # CPU-intensive work in the finalizer callback
         result = 0
         for i in range(100000):
             result += i * i
             if i % 1000 == 0:
                 result = result % 1000000
 
-def main_loop():
-    """Main loop that triggers GC with expensive callback."""
-    while True:
-        ExpensiveGarbage()
-        gc.collect()
-
-if __name__ == "__main__":
-    main_loop()
+_test_sock.sendall(b"working")
+while True:
+    ExpensiveGarbage()
+    gc.collect()
 '''
 
     def test_gc_frames_enabled(self):
         """Test that GC frames appear when gc tracking is enabled."""
         with (
-            test_subprocess(self.gc_test_script) as subproc,
+            test_subprocess(self.gc_test_script, wait_for_working=True) as subproc,
             io.StringIO() as captured_output,
             mock.patch("sys.stdout", captured_output),
         ):
@@ -94,7 +88,7 @@ def test_gc_frames_enabled(self):
     def test_gc_frames_disabled(self):
         """Test that GC frames do not appear when gc tracking is disabled."""
         with (
-            test_subprocess(self.gc_test_script) as subproc,
+            test_subprocess(self.gc_test_script, wait_for_working=True) as subproc,
             io.StringIO() as captured_output,
             mock.patch("sys.stdout", captured_output),
         ):
@@ -133,18 +127,13 @@ def setUpClass(cls):
         cls.native_test_script = """
 import operator
 
-def main_loop():
-    while True:
-        # Native code in the middle of the stack:
-        operator.call(inner)
-
 def inner():
-    # Python code at the top of the stack:
     for _ in range(1_000_0000):
         pass
 
-if __name__ == "__main__":
-    main_loop()
+_test_sock.sendall(b"working")
+while True:
+    operator.call(inner)
 """
 
     def test_native_frames_enabled(self):
@@ -154,10 +143,7 @@ def test_native_frames_enabled(self):
         )
         self.addCleanup(close_and_unlink, collapsed_file)
 
-        with (
-            test_subprocess(self.native_test_script) as subproc,
-        ):
-            # Suppress profiler output when testing file export
+        with test_subprocess(self.native_test_script, wait_for_working=True) as subproc:
             with (
                 io.StringIO() as captured_output,
                 mock.patch("sys.stdout", captured_output),
@@ -199,7 +185,7 @@ def test_native_frames_enabled(self):
     def test_native_frames_disabled(self):
         """Test that native frames do not appear when native tracking is disabled."""
         with (
-            test_subprocess(self.native_test_script) as subproc,
+            test_subprocess(self.native_test_script, wait_for_working=True) as subproc,
             io.StringIO() as captured_output,
             mock.patch("sys.stdout", captured_output),
         ):


@@ -39,6 +39,9 @@
 # Duration for profiling tests - long enough for process to complete naturally
 PROFILING_TIMEOUT = str(int(SHORT_TIMEOUT))
 
+# Duration for profiling in tests - short enough to complete quickly
+PROFILING_DURATION_SEC = 2
+
 
 @skip_if_not_supported
 @unittest.skipIf(
@@ -359,23 +362,14 @@ def total_occurrences(func):
 
         self.assertEqual(total_occurrences(main_key), 2)
 
 
-@requires_subprocess()
-@skip_if_not_supported
-class TestSampleProfilerIntegration(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.test_script = '''
-import time
-import os
-
+# Shared workload functions for test scripts
+_WORKLOAD_FUNCTIONS = '''
 def slow_fibonacci(n):
-    """Recursive fibonacci - should show up prominently in profiler."""
     if n <= 1:
         return n
     return slow_fibonacci(n-1) + slow_fibonacci(n-2)
 
 def cpu_intensive_work():
-    """CPU intensive work that should show in profiler."""
     result = 0
     for i in range(10000):
         result += i * i
@@ -383,33 +377,48 @@ def cpu_intensive_work():
             result = result % 1000000
     return result
 
-def main_loop():
-    """Main test loop."""
-    max_iterations = 200
-    for iteration in range(max_iterations):
+def do_work():
+    iteration = 0
+    while True:
         if iteration % 2 == 0:
-            result = slow_fibonacci(15)
+            slow_fibonacci(15)
         else:
-            result = cpu_intensive_work()
-
-if __name__ == "__main__":
-    main_loop()
+            cpu_intensive_work()
+        iteration += 1
+'''
+
+
+@requires_subprocess()
+@skip_if_not_supported
+class TestSampleProfilerIntegration(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Test script for use with test_subprocess() - signals when work starts
+        cls.test_script = _WORKLOAD_FUNCTIONS + '''
+_test_sock.sendall(b"working")
+do_work()
+'''
+
+        # CLI test script - runs for fixed duration (no socket sync)
+        cls.cli_test_script = '''
+import time
+''' + _WORKLOAD_FUNCTIONS.replace(
+            'while True:', 'end_time = time.time() + 30\n    while time.time() < end_time:'
+        ) + '''
+do_work()
 '''
 
     def test_sampling_basic_functionality(self):
         with (
-            test_subprocess(self.test_script) as subproc,
+            test_subprocess(self.test_script, wait_for_working=True) as subproc,
             io.StringIO() as captured_output,
             mock.patch("sys.stdout", captured_output),
         ):
             try:
-                # Sample for up to SHORT_TIMEOUT seconds, but process exits after fixed iterations
                 collector = PstatsCollector(sample_interval_usec=1000, skip_idle=False)
                 profiling.sampling.sample.sample(
                     subproc.process.pid,
                     collector,
-                    duration_sec=SHORT_TIMEOUT,
+                    duration_sec=PROFILING_DURATION_SEC,
                 )
                 collector.print_stats(show_summary=False)
             except PermissionError:
@@ -431,7 +440,7 @@ def test_sampling_with_pstats_export(self):
         )
         self.addCleanup(close_and_unlink, pstats_out)
 
-        with test_subprocess(self.test_script) as subproc:
+        with test_subprocess(self.test_script, wait_for_working=True) as subproc:
             # Suppress profiler output when testing file export
             with (
                 io.StringIO() as captured_output,
@@ -442,7 +451,7 @@ def test_sampling_with_pstats_export(self):
                     profiling.sampling.sample.sample(
                         subproc.process.pid,
                         collector,
-                        duration_sec=1,
+                        duration_sec=PROFILING_DURATION_SEC,
                     )
                     collector.export(pstats_out.name)
                 except PermissionError:
@@ -476,7 +485,7 @@ def test_sampling_with_collapsed_export(self):
         self.addCleanup(close_and_unlink, collapsed_file)
 
         with (
-            test_subprocess(self.test_script) as subproc,
+            test_subprocess(self.test_script, wait_for_working=True) as subproc,
         ):
             # Suppress profiler output when testing file export
             with (
@@ -488,7 +497,7 @@ def test_sampling_with_collapsed_export(self):
                     profiling.sampling.sample.sample(
                         subproc.process.pid,
                         collector,
-                        duration_sec=1,
+                        duration_sec=PROFILING_DURATION_SEC,
                     )
                     collector.export(collapsed_file.name)
                 except PermissionError:
@@ -526,7 +535,7 @@ def test_sampling_with_collapsed_export(self):
 
     def test_sampling_all_threads(self):
         with (
-            test_subprocess(self.test_script) as subproc,
+            test_subprocess(self.test_script, wait_for_working=True) as subproc,
             # Suppress profiler output
             io.StringIO() as captured_output,
             mock.patch("sys.stdout", captured_output),
@@ -536,7 +545,7 @@ def test_sampling_all_threads(self):
                 profiling.sampling.sample.sample(
                     subproc.process.pid,
                     collector,
-                    duration_sec=1,
+                    duration_sec=PROFILING_DURATION_SEC,
                     all_threads=True,
                 )
                 collector.print_stats(show_summary=False)
@@ -548,12 +557,16 @@ def test_sampling_all_threads(self):
 
     def test_sample_target_script(self):
         script_file = tempfile.NamedTemporaryFile(delete=False)
-        script_file.write(self.test_script.encode("utf-8"))
+        script_file.write(self.cli_test_script.encode("utf-8"))
         script_file.flush()
         self.addCleanup(close_and_unlink, script_file)
 
-        # Sample for up to SHORT_TIMEOUT seconds, but process exits after fixed iterations
-        test_args = ["profiling.sampling.sample", "run", "-d", PROFILING_TIMEOUT, script_file.name]
+        # Sample for PROFILING_DURATION_SEC seconds
+        test_args = [
+            "profiling.sampling.sample", "run",
+            "-d", str(PROFILING_DURATION_SEC),
+            script_file.name
+        ]
 
         with (
             mock.patch("sys.argv", test_args),
@@ -583,13 +596,13 @@ def test_sample_target_module(self):
         module_path = os.path.join(tempdir.name, "test_module.py")
 
         with open(module_path, "w") as f:
-            f.write(self.test_script)
+            f.write(self.cli_test_script)
 
         test_args = [
             "profiling.sampling.cli",
             "run",
             "-d",
-            PROFILING_TIMEOUT,
+            str(PROFILING_DURATION_SEC),
             "-m",
             "test_module",
         ]
@@ -630,8 +643,10 @@ def test_invalid_pid(self):
             profiling.sampling.sample.sample(-1, collector, duration_sec=1)
 
     def test_process_dies_during_sampling(self):
+        # Use wait_for_working=False since this simple script doesn't send "working"
         with test_subprocess(
-            "import time; time.sleep(0.5); exit()"
+            "import time; time.sleep(0.5); exit()",
+            wait_for_working=False
         ) as subproc:
             with (
                 io.StringIO() as captured_output,
@@ -654,7 +669,11 @@ def test_process_dies_during_sampling(self):
         self.assertIn("Error rate", output)
 
     def test_is_process_running(self):
-        with test_subprocess("import time; time.sleep(1000)") as subproc:
+        # Use wait_for_working=False since this simple script doesn't send "working"
+        with test_subprocess(
+            "import time; time.sleep(1000)",
+            wait_for_working=False
+        ) as subproc:
             try:
                 profiler = SampleProfiler(
                     pid=subproc.process.pid,
@@ -681,7 +700,11 @@ def test_is_process_running(self):
 
     @unittest.skipUnless(sys.platform == "linux", "Only valid on Linux")
     def test_esrch_signal_handling(self):
-        with test_subprocess("import time; time.sleep(1000)") as subproc:
+        # Use wait_for_working=False since this simple script doesn't send "working"
+        with test_subprocess(
+            "import time; time.sleep(1000)",
+            wait_for_working=False
+        ) as subproc:
             try:
                 unwinder = _remote_debugging.RemoteUnwinder(
                     subproc.process.pid
@@ -793,38 +816,34 @@ class TestAsyncAwareProfilingIntegration(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
+        # Async test script that runs indefinitely until killed.
+        # Sends "working" signal AFTER tasks are created and scheduled.
         cls.async_script = '''
 import asyncio
 
 async def sleeping_leaf():
-    """Leaf task that just sleeps - visible in 'all' mode."""
-    for _ in range(50):
+    while True:
         await asyncio.sleep(0.02)
 
 async def cpu_leaf():
-    """Leaf task that does CPU work - visible in both modes."""
     total = 0
-    for _ in range(200):
+    while True:
         for i in range(10000):
             total += i * i
         await asyncio.sleep(0)
-    return total
 
 async def supervisor():
-    """Middle layer that spawns leaf tasks."""
     tasks = [
         asyncio.create_task(sleeping_leaf(), name="Sleeper-0"),
         asyncio.create_task(sleeping_leaf(), name="Sleeper-1"),
        asyncio.create_task(sleeping_leaf(), name="Sleeper-2"),
        asyncio.create_task(cpu_leaf(), name="Worker"),
     ]
+    await asyncio.sleep(0)  # Let tasks get scheduled
+    _test_sock.sendall(b"working")
     await asyncio.gather(*tasks)
 
-async def main():
-    await supervisor()
-
-if __name__ == "__main__":
-    asyncio.run(main())
+asyncio.run(supervisor())
 '''
 
     def _collect_async_samples(self, async_aware_mode):
@@ -832,13 +851,13 @@ def _collect_async_samples(self, async_aware_mode):
 
         Returns a dict mapping function names to their sample counts.
         """
-        with test_subprocess(self.async_script) as subproc:
+        with test_subprocess(self.async_script, wait_for_working=True) as subproc:
             try:
                 collector = CollapsedStackCollector(1000, skip_idle=False)
                 profiling.sampling.sample.sample(
                     subproc.process.pid,
                     collector,
-                    duration_sec=SHORT_TIMEOUT,
+                    duration_sec=PROFILING_DURATION_SEC,
                     async_aware=async_aware_mode,
                 )
             except PermissionError:


@@ -143,27 +143,16 @@ def cpu_active_worker():
     while True:
         x += 1
 
-def main():
-    # Start both threads
-    idle_thread = threading.Thread(target=idle_worker)
-    cpu_thread = threading.Thread(target=cpu_active_worker)
-    idle_thread.start()
-    cpu_thread.start()
-
-    # Wait for CPU thread to be running, then signal test
-    cpu_ready.wait()
-    _test_sock.sendall(b"threads_ready")
-
-    idle_thread.join()
-    cpu_thread.join()
-
-main()
+idle_thread = threading.Thread(target=idle_worker)
+cpu_thread = threading.Thread(target=cpu_active_worker)
+idle_thread.start()
+cpu_thread.start()
+
+cpu_ready.wait()
+_test_sock.sendall(b"working")
+
+idle_thread.join()
+cpu_thread.join()
 """
 
-        with test_subprocess(cpu_vs_idle_script) as subproc:
-            # Wait for signal that threads are running
-            response = subproc.socket.recv(1024)
-            self.assertEqual(response, b"threads_ready")
-
+        with test_subprocess(cpu_vs_idle_script, wait_for_working=True) as subproc:
             with (
                 io.StringIO() as captured_output,
@@ -365,26 +354,16 @@ def gil_holding_work():
     while True:
         x += 1
 
-def main():
-    # Start both threads
-    idle_thread = threading.Thread(target=gil_releasing_work)
-    cpu_thread = threading.Thread(target=gil_holding_work)
-    idle_thread.start()
-    cpu_thread.start()
-
-    # Wait for GIL-holding thread to be running, then signal test
-    gil_ready.wait()
-    _test_sock.sendall(b"threads_ready")
-
-    idle_thread.join()
-    cpu_thread.join()
-
-main()
+idle_thread = threading.Thread(target=gil_releasing_work)
+cpu_thread = threading.Thread(target=gil_holding_work)
+idle_thread.start()
+cpu_thread.start()
+
+gil_ready.wait()
+_test_sock.sendall(b"working")
+
+idle_thread.join()
+cpu_thread.join()
 """
 
-        with test_subprocess(gil_test_script) as subproc:
-            # Wait for signal that threads are running
-            response = subproc.socket.recv(1024)
-            self.assertEqual(response, b"threads_ready")
-
+        with test_subprocess(gil_test_script, wait_for_working=True) as subproc:
             with (
                 io.StringIO() as captured_output,