gh-142654: show the clear error message when sampling on an unknown PID in tachyon (#142655)

Co-authored-by: Pablo Galindo Salgado <pablogsal@gmail.com>
This commit is contained in:
Keming 2025-12-17 22:15:22 +08:00 committed by GitHub
parent 1fc3039d71
commit d4095f25e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 104 additions and 43 deletions

View file

@ -46,6 +46,7 @@
"""
from .cli import main
from .errors import SamplingUnknownProcessError, SamplingModuleNotFoundError, SamplingScriptNotFoundError
def handle_permission_error():
"""Handle PermissionError by displaying appropriate error message."""
@ -64,3 +65,9 @@ def handle_permission_error():
main()
except PermissionError:
handle_permission_error()
except SamplingUnknownProcessError as err:
print(f"Tachyon cannot find the process: {err}", file=sys.stderr)
sys.exit(1)
except (SamplingModuleNotFoundError, SamplingScriptNotFoundError) as err:
print(f"Tachyon cannot find the target: {err}", file=sys.stderr)
sys.exit(1)

View file

@ -10,7 +10,8 @@
import time
from contextlib import nullcontext
from .sample import sample, sample_live
from .errors import SamplingUnknownProcessError, SamplingModuleNotFoundError, SamplingScriptNotFoundError
from .sample import sample, sample_live, _is_process_running
from .pstats_collector import PstatsCollector
from .stack_collector import CollapsedStackCollector, FlamegraphCollector
from .heatmap_collector import HeatmapCollector
@ -743,6 +744,8 @@ def main():
def _handle_attach(args):
"""Handle the 'attach' command."""
if not _is_process_running(args.pid):
raise SamplingUnknownProcessError(args.pid)
# Check if live mode is requested
if args.live:
_handle_live_attach(args, args.pid)
@ -792,13 +795,13 @@ def _handle_run(args):
added_cwd = True
try:
if importlib.util.find_spec(args.target) is None:
sys.exit(f"Error: Module not found: {args.target}")
raise SamplingModuleNotFoundError(args.target)
finally:
if added_cwd:
sys.path.remove(cwd)
else:
if not os.path.exists(args.target):
sys.exit(f"Error: Script not found: {args.target}")
raise SamplingScriptNotFoundError(args.target)
# Check if live mode is requested
if args.live:

View file

@ -0,0 +1,19 @@
"""Custom exceptions for the sampling profiler."""
class SamplingProfilerError(Exception):
"""Base exception for sampling profiler errors."""
class SamplingUnknownProcessError(SamplingProfilerError):
def __init__(self, pid):
self.pid = pid
super().__init__(f"Process with PID '{pid}' does not exist.")
class SamplingScriptNotFoundError(SamplingProfilerError):
def __init__(self, script_path):
self.script_path = script_path
super().__init__(f"Script '{script_path}' not found.")
class SamplingModuleNotFoundError(SamplingProfilerError):
def __init__(self, module_name):
self.module_name = module_name
super().__init__(f"Module '{module_name}' not found.")

View file

@ -34,24 +34,30 @@ def __init__(self, pid, sample_interval_usec, all_threads, *, mode=PROFILING_MOD
self.all_threads = all_threads
self.mode = mode # Store mode for later use
self.collect_stats = collect_stats
if _FREE_THREADED_BUILD:
self.unwinder = _remote_debugging.RemoteUnwinder(
self.pid, all_threads=self.all_threads, mode=mode, native=native, gc=gc,
opcodes=opcodes, skip_non_matching_threads=skip_non_matching_threads,
cache_frames=True, stats=collect_stats
)
else:
only_active_threads = bool(self.all_threads)
self.unwinder = _remote_debugging.RemoteUnwinder(
self.pid, only_active_thread=only_active_threads, mode=mode, native=native, gc=gc,
opcodes=opcodes, skip_non_matching_threads=skip_non_matching_threads,
cache_frames=True, stats=collect_stats
)
try:
self.unwinder = self._new_unwinder(native, gc, opcodes, skip_non_matching_threads)
except RuntimeError as err:
raise SystemExit(err) from err
# Track sample intervals and total sample count
self.sample_intervals = deque(maxlen=100)
self.total_samples = 0
self.realtime_stats = False
def _new_unwinder(self, native, gc, opcodes, skip_non_matching_threads):
if _FREE_THREADED_BUILD:
unwinder = _remote_debugging.RemoteUnwinder(
self.pid, all_threads=self.all_threads, mode=self.mode, native=native, gc=gc,
opcodes=opcodes, skip_non_matching_threads=skip_non_matching_threads,
cache_frames=True, stats=self.collect_stats
)
else:
unwinder = _remote_debugging.RemoteUnwinder(
self.pid, only_active_thread=bool(self.all_threads), mode=self.mode, native=native, gc=gc,
opcodes=opcodes, skip_non_matching_threads=skip_non_matching_threads,
cache_frames=True, stats=self.collect_stats
)
return unwinder
def sample(self, collector, duration_sec=10, *, async_aware=False):
sample_interval_sec = self.sample_interval_usec / 1_000_000
running_time = 0
@ -86,7 +92,7 @@ def sample(self, collector, duration_sec=10, *, async_aware=False):
collector.collect_failed_sample()
errors += 1
except Exception as e:
if not self._is_process_running():
if not _is_process_running(self.pid):
break
raise e from None
@ -148,22 +154,6 @@ def sample(self, collector, duration_sec=10, *, async_aware=False):
f"({(expected_samples - num_samples) / expected_samples * 100:.2f}%)"
)
def _is_process_running(self):
if sys.platform == "linux" or sys.platform == "darwin":
try:
os.kill(self.pid, 0)
return True
except ProcessLookupError:
return False
elif sys.platform == "win32":
try:
_remote_debugging.RemoteUnwinder(self.pid)
except Exception:
return False
return True
else:
raise ValueError(f"Unsupported platform: {sys.platform}")
def _print_realtime_stats(self):
"""Print real-time sampling statistics."""
if len(self.sample_intervals) < 2:
@ -279,6 +269,28 @@ def _print_unwinder_stats(self):
print(f" {ANSIColors.YELLOW}Stale cache invalidations: {stale_invalidations}{ANSIColors.RESET}")
def _is_process_running(pid):
if pid <= 0:
return False
if os.name == "posix":
try:
os.kill(pid, 0)
return True
except ProcessLookupError:
return False
except PermissionError:
# EPERM means process exists but we can't signal it
return True
elif sys.platform == "win32":
try:
_remote_debugging.RemoteUnwinder(pid)
except Exception:
return False
return True
else:
raise ValueError(f"Unsupported platform: {sys.platform}")
def sample(
pid,
collector,