gh-143040: Exit taychon live mode gracefully and display profiled script errors (#143101)

This commit is contained in:
Marta Gómez Macías 2025-12-27 01:36:15 +01:00 committed by GitHub
parent a1c6308346
commit 9d92ac1225
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 133 additions and 38 deletions

View file

@ -135,7 +135,7 @@ def _execute_module(module_name: str, module_args: List[str]) -> None:
module_args: Arguments to pass to the module
Raises:
TargetError: If module execution fails
TargetError: If module cannot be found
"""
# Replace sys.argv to match how Python normally runs modules
# When running 'python -m module args', sys.argv is ["__main__.py", "args"]
@ -145,11 +145,8 @@ def _execute_module(module_name: str, module_args: List[str]) -> None:
runpy.run_module(module_name, run_name="__main__", alter_sys=True)
except ImportError as e:
raise TargetError(f"Module '{module_name}' not found: {e}") from e
except SystemExit:
# SystemExit is normal for modules
pass
except Exception as e:
raise TargetError(f"Error executing module '{module_name}': {e}") from e
# Let other exceptions (including SystemExit) propagate naturally
# so Python prints the full traceback to stderr
def _execute_script(script_path: str, script_args: List[str], cwd: str) -> None:
@ -183,22 +180,20 @@ def _execute_script(script_path: str, script_args: List[str], cwd: str) -> None:
except PermissionError as e:
raise TargetError(f"Permission denied reading script: {script_path}") from e
try:
main_module = types.ModuleType("__main__")
main_module.__file__ = script_path
main_module.__builtins__ = __builtins__
# gh-140729: Create a __mp_main__ module to allow pickling
sys.modules['__main__'] = sys.modules['__mp_main__'] = main_module
main_module = types.ModuleType("__main__")
main_module.__file__ = script_path
main_module.__builtins__ = __builtins__
# gh-140729: Create a __mp_main__ module to allow pickling
sys.modules['__main__'] = sys.modules['__mp_main__'] = main_module
try:
code = compile(source_code, script_path, 'exec', module='__main__')
exec(code, main_module.__dict__)
except SyntaxError as e:
raise TargetError(f"Syntax error in script {script_path}: {e}") from e
except SystemExit:
# SystemExit is normal for scripts
pass
except Exception as e:
raise TargetError(f"Error executing script '{script_path}': {e}") from e
# Execute the script - let exceptions propagate naturally so Python
# prints the full traceback to stderr
exec(code, main_module.__dict__)
def main() -> NoReturn:
@ -209,6 +204,8 @@ def main() -> NoReturn:
with the sample profiler by signaling when the process is ready
to be profiled.
"""
# Phase 1: Parse arguments and set up environment
# Errors here are coordinator errors, not script errors
try:
# Parse and validate arguments
sync_port, cwd, target_args = _validate_arguments(sys.argv)
@ -237,21 +234,19 @@ def main() -> NoReturn:
# Signal readiness to profiler
_signal_readiness(sync_port)
# Execute the target
if is_module:
_execute_module(module_name, module_args)
else:
_execute_script(script_path, script_args, cwd)
except CoordinatorError as e:
print(f"Profiler coordinator error: {e}", file=sys.stderr)
sys.exit(1)
except KeyboardInterrupt:
print("Interrupted", file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f"Unexpected error in profiler coordinator: {e}", file=sys.stderr)
sys.exit(1)
# Phase 2: Execute the target script/module
# Let exceptions propagate naturally so Python prints full tracebacks
if is_module:
_execute_module(module_name, module_args)
else:
_execute_script(script_path, script_args, cwd)
# Normal exit
sys.exit(0)

View file

@ -272,11 +272,6 @@ def _run_with_sync(original_cmd, suppress_output=False):
try:
_wait_for_ready_signal(sync_sock, process, _SYNC_TIMEOUT_SEC)
# Close stderr pipe if we were capturing it
if process.stderr:
process.stderr.close()
except socket.timeout:
# If we timeout, kill the process and raise an error
if process.poll() is None:
@ -1103,14 +1098,27 @@ def _handle_live_run(args):
blocking=args.blocking,
)
finally:
# Clean up the subprocess
if process.poll() is None:
# Clean up the subprocess and get any error output
returncode = process.poll()
if returncode is None:
# Process still running - terminate it
process.terminate()
try:
process.wait(timeout=_PROCESS_KILL_TIMEOUT_SEC)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
# Ensure process is fully terminated
process.wait()
# Read any stderr output (tracebacks, errors, etc.)
if process.stderr:
with process.stderr:
try:
stderr = process.stderr.read()
if stderr:
print(stderr.decode(), file=sys.stderr)
except (OSError, ValueError):
# Ignore errors if pipe is already closed
pass
def _handle_replay(args):

View file

@ -216,6 +216,9 @@ def __init__(
def elapsed_time(self):
"""Get the elapsed time, frozen when finished."""
if self.finished and self.finish_timestamp is not None:
# Handle case where process exited before any samples were collected
if self.start_time is None:
return 0
return self.finish_timestamp - self.start_time
return time.perf_counter() - self.start_time if self.start_time else 0

View file

@ -42,7 +42,9 @@ def _pause_threads(unwinder, blocking):
LiveStatsCollector = None
_FREE_THREADED_BUILD = sysconfig.get_config_var("Py_GIL_DISABLED") is not None
# Minimum number of samples required before showing the TUI
# If fewer samples are collected, we skip the TUI and just print a message
MIN_SAMPLES_FOR_TUI = 200
class SampleProfiler:
def __init__(self, pid, sample_interval_usec, all_threads, *, mode=PROFILING_MODE_WALL, native=False, gc=True, opcodes=False, skip_non_matching_threads=True, collect_stats=False, blocking=False):
@ -459,6 +461,11 @@ def sample_live(
"""
import curses
# Check if process is alive before doing any heavy initialization
if not _is_process_running(pid):
print(f"No samples collected - process {pid} exited before profiling could begin.", file=sys.stderr)
return collector
# Get sample interval from collector
sample_interval_usec = collector.sample_interval_usec
@ -486,6 +493,12 @@ def curses_wrapper_func(stdscr):
collector.init_curses(stdscr)
try:
profiler.sample(collector, duration_sec, async_aware=async_aware)
# If too few samples were collected, exit cleanly without showing TUI
if collector.successful_samples < MIN_SAMPLES_FOR_TUI:
# Clear screen before exiting to avoid visual artifacts
stdscr.clear()
stdscr.refresh()
return
# Mark as finished and keep the TUI running until user presses 'q'
collector.mark_finished()
# Keep processing input until user quits
@ -500,4 +513,11 @@ def curses_wrapper_func(stdscr):
except KeyboardInterrupt:
pass
# If too few samples were collected, print a message
if collector.successful_samples < MIN_SAMPLES_FOR_TUI:
if collector.successful_samples == 0:
print(f"No samples collected - process {pid} exited before profiling could begin.", file=sys.stderr)
else:
print(f"Only {collector.successful_samples} sample(s) collected (minimum {MIN_SAMPLES_FOR_TUI} required for TUI) - process {pid} exited too quickly.", file=sys.stderr)
return collector

View file

@ -18,7 +18,6 @@
from profiling.sampling.cli import main
from profiling.sampling.errors import SamplingScriptNotFoundError, SamplingModuleNotFoundError, SamplingUnknownProcessError
class TestSampleProfilerCLI(unittest.TestCase):
def _setup_sync_mocks(self, mock_socket, mock_popen):
"""Helper to set up socket and process mocks for coordinator tests."""

View file

@ -4,11 +4,14 @@
edge cases, update display, and display helpers.
"""
import functools
import io
import sys
import tempfile
import time
import unittest
from unittest import mock
from test.support import requires
from test.support import requires, requires_remote_subprocess_debugging
from test.support.import_helper import import_module
# Only run these tests if curses is available
@ -16,10 +19,12 @@
curses = import_module("curses")
from profiling.sampling.live_collector import LiveStatsCollector, MockDisplay
from profiling.sampling.cli import main
from ._live_collector_helpers import (
MockThreadInfo,
MockInterpreterInfo,
)
from .helpers import close_and_unlink
class TestLiveStatsCollectorWithMockDisplay(unittest.TestCase):
@ -816,5 +821,70 @@ def test_get_all_lines_full_display(self):
self.assertTrue(any("PID" in line for line in lines))
@requires_remote_subprocess_debugging()
class TestLiveModeErrors(unittest.TestCase):
"""Tests running error commands in the live mode fails gracefully."""
def mock_curses_wrapper(self, func):
func(mock.MagicMock())
def mock_init_curses_side_effect(self, n_times, mock_self, stdscr):
mock_self.display = MockDisplay()
# Allow the loop to run for a bit (approx 0.5s) before quitting
# This ensures we don't exit too early while the subprocess is
# still failing
for _ in range(n_times):
mock_self.display.simulate_input(-1)
if n_times >= 500:
mock_self.display.simulate_input(ord('q'))
def test_run_failed_module_live(self):
"""Test that running a existing module that fails exits with clean error."""
args = [
"profiling.sampling.cli", "run", "--live", "-m", "test",
"test_asdasd"
]
with (
mock.patch(
'profiling.sampling.live_collector.collector.LiveStatsCollector.init_curses',
autospec=True,
side_effect=functools.partial(self.mock_init_curses_side_effect, 1000)
),
mock.patch('curses.wrapper', side_effect=self.mock_curses_wrapper),
mock.patch("sys.argv", args),
mock.patch('sys.stderr', new=io.StringIO()) as fake_stderr
):
main()
self.assertIn(
'test test_asdasd crashed -- Traceback (most recent call last):',
fake_stderr.getvalue()
)
def test_run_failed_script_live(self):
"""Test that running a failing script exits with clean error."""
script = tempfile.NamedTemporaryFile(suffix=".py")
self.addCleanup(close_and_unlink, script)
script.write(b'1/0\n')
script.seek(0)
args = ["profiling.sampling.cli", "run", "--live", script.name]
with (
mock.patch(
'profiling.sampling.live_collector.collector.LiveStatsCollector.init_curses',
autospec=True,
side_effect=functools.partial(self.mock_init_curses_side_effect, 200)
),
mock.patch('curses.wrapper', side_effect=self.mock_curses_wrapper),
mock.patch("sys.argv", args),
mock.patch('sys.stderr', new=io.StringIO()) as fake_stderr
):
main()
stderr = fake_stderr.getvalue()
self.assertIn('ZeroDivisionError', stderr)
if __name__ == "__main__":
unittest.main()