gh-148225: Validate profiling.sampling replay input (#148243)

This commit is contained in:
Pablo Galindo Salgado 2026-04-09 00:34:46 +01:00 committed by GitHub
parent 09968dd2a9
commit efde4333bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 130 additions and 40 deletions

View file

@ -86,6 +86,8 @@ def __call__(self, parser, namespace, values, option_string=None):
_PROCESS_KILL_TIMEOUT_SEC = 2.0
_READY_MESSAGE = b"ready"
_RECV_BUFFER_SIZE = 1024
_BINARY_PROFILE_HEADER_SIZE = 64
_BINARY_PROFILE_MAGICS = (b"HCAT", b"TACH")
# Format configuration
FORMAT_EXTENSIONS = {
@ -650,6 +652,88 @@ def _open_in_browser(path):
print(f"Warning: Could not open browser: {e}", file=sys.stderr)
def _validate_replay_input_file(filename):
"""Validate that the replay input looks like a sampling binary profile."""
try:
with open(filename, "rb") as file:
header = file.read(_BINARY_PROFILE_HEADER_SIZE)
except OSError as exc:
sys.exit(f"Error: Could not read input file {filename}: {exc}")
if (
len(header) < _BINARY_PROFILE_HEADER_SIZE
or header[:4] not in _BINARY_PROFILE_MAGICS
):
sys.exit(
"Error: Input file is not a binary sampling profile. "
"The replay command only accepts files created with --binary"
)
def _replay_with_reader(args, reader):
"""Replay samples from an open binary reader."""
info = reader.get_info()
interval = info['sample_interval_us']
print(f"Replaying {info['sample_count']} samples from {args.input_file}")
print(f" Sample interval: {interval} us")
print(
" Compression: "
f"{'zstd' if info.get('compression_type', 0) == 1 else 'none'}"
)
collector = _create_collector(
args.format, interval, skip_idle=False,
diff_baseline=args.diff_baseline
)
def progress_callback(current, total):
if total > 0:
pct = current / total
bar_width = 40
filled = int(bar_width * pct)
bar = '' * filled + '' * (bar_width - filled)
print(
f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})",
end="",
flush=True,
)
count = reader.replay_samples(collector, progress_callback)
print()
if args.format == "pstats":
if args.outfile:
collector.export(args.outfile)
else:
sort_choice = (
args.sort if args.sort is not None else "nsamples"
)
limit = args.limit if args.limit is not None else 15
sort_mode = _sort_to_mode(sort_choice)
collector.print_stats(
sort_mode, limit, not args.no_summary,
PROFILING_MODE_WALL
)
else:
filename = (
args.outfile
or _generate_output_filename(args.format, os.getpid())
)
collector.export(filename)
# Auto-open browser for HTML output if --browser flag is set
if (
args.format in (
'flamegraph', 'diff_flamegraph', 'heatmap'
)
and getattr(args, 'browser', False)
):
_open_in_browser(filename)
print(f"Replayed {count} samples")
def _handle_output(collector, args, pid, mode):
"""Handle output for the collector based on format and arguments.
@ -1201,47 +1285,13 @@ def _handle_replay(args):
if not os.path.exists(args.input_file):
sys.exit(f"Error: Input file not found: {args.input_file}")
with BinaryReader(args.input_file) as reader:
info = reader.get_info()
interval = info['sample_interval_us']
_validate_replay_input_file(args.input_file)
print(f"Replaying {info['sample_count']} samples from {args.input_file}")
print(f" Sample interval: {interval} us")
print(f" Compression: {'zstd' if info.get('compression_type', 0) == 1 else 'none'}")
collector = _create_collector(
args.format, interval, skip_idle=False,
diff_baseline=args.diff_baseline
)
def progress_callback(current, total):
if total > 0:
pct = current / total
bar_width = 40
filled = int(bar_width * pct)
bar = '' * filled + '' * (bar_width - filled)
print(f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})", end="", flush=True)
count = reader.replay_samples(collector, progress_callback)
print()
if args.format == "pstats":
if args.outfile:
collector.export(args.outfile)
else:
sort_choice = args.sort if args.sort is not None else "nsamples"
limit = args.limit if args.limit is not None else 15
sort_mode = _sort_to_mode(sort_choice)
collector.print_stats(sort_mode, limit, not args.no_summary, PROFILING_MODE_WALL)
else:
filename = args.outfile or _generate_output_filename(args.format, os.getpid())
collector.export(filename)
# Auto-open browser for HTML output if --browser flag is set
if args.format in ('flamegraph', 'diff_flamegraph', 'heatmap') and getattr(args, 'browser', False):
_open_in_browser(filename)
print(f"Replayed {count} samples")
try:
with BinaryReader(args.input_file) as reader:
_replay_with_reader(args, reader)
except (OSError, ValueError) as exc:
sys.exit(f"Error: {exc}")
if __name__ == "__main__":

View file

@ -1,8 +1,10 @@
"""Tests for sampling profiler CLI argument parsing and functionality."""
import io
import os
import subprocess
import sys
import tempfile
import unittest
from unittest import mock
@ -722,3 +724,38 @@ def test_cli_attach_nonexistent_pid(self):
main()
self.assertIn(fake_pid, str(cm.exception))
def test_cli_replay_rejects_non_binary_profile(self):
with tempfile.TemporaryDirectory() as tempdir:
profile = os.path.join(tempdir, "output.prof")
with open(profile, "wb") as file:
file.write(b"not a binary sampling profile")
with mock.patch("sys.argv", ["profiling.sampling.cli", "replay", profile]):
with self.assertRaises(SystemExit) as cm:
main()
error = str(cm.exception)
self.assertIn("not a binary sampling profile", error)
self.assertIn("--binary", error)
def test_cli_replay_reader_errors_exit_cleanly(self):
with tempfile.TemporaryDirectory() as tempdir:
profile = os.path.join(tempdir, "output.bin")
with open(profile, "wb") as file:
file.write(b"HCAT" + (b"\0" * 60))
with (
mock.patch("sys.argv", ["profiling.sampling.cli", "replay", profile]),
mock.patch(
"profiling.sampling.cli.BinaryReader",
side_effect=ValueError("Unsupported format version 2"),
),
):
with self.assertRaises(SystemExit) as cm:
main()
self.assertEqual(
str(cm.exception),
"Error: Unsupported format version 2",
)

View file

@ -0,0 +1,3 @@
The :mod:`profiling.sampling` ``replay`` command now rejects non-binary
profile files with a clear error explaining that replay only accepts files
created with ``--binary``.