GH-141565: Add async code awareness to Tachyon (#141533)

Co-authored-by: Pablo Galindo Salgado <pablogsal@gmail.com>
This commit is contained in:
Savannah Ostrowski 2025-12-06 11:31:40 -08:00 committed by GitHub
parent 35142b18ae
commit 56a442d0d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1355 additions and 83 deletions

View file

@ -195,6 +195,11 @@ def _add_sampling_options(parser):
dest="gc",
help='Don\'t include artificial "<GC>" frames to denote active garbage collection',
)
sampling_group.add_argument(
"--async-aware",
action="store_true",
help="Enable async-aware profiling (uses task-based stack reconstruction)",
)
def _add_mode_options(parser):
@ -205,7 +210,14 @@ def _add_mode_options(parser):
choices=["wall", "cpu", "gil"],
default="wall",
help="Sampling mode: wall (all samples), cpu (only samples when thread is on CPU), "
"gil (only samples when thread holds the GIL)",
"gil (only samples when thread holds the GIL). Incompatible with --async-aware",
)
mode_group.add_argument(
"--async-mode",
choices=["running", "all"],
default="running",
help='Async profiling mode: "running" (only running task) '
'or "all" (all tasks including waiting). Requires --async-aware',
)
@ -382,6 +394,27 @@ def _validate_args(args, parser):
"Live mode requires the curses module, which is not available."
)
# Async-aware mode is incompatible with --native, --no-gc, --mode, and --all-threads
if args.async_aware:
issues = []
if args.native:
issues.append("--native")
if not args.gc:
issues.append("--no-gc")
if hasattr(args, 'mode') and args.mode != "wall":
issues.append(f"--mode={args.mode}")
if hasattr(args, 'all_threads') and args.all_threads:
issues.append("--all-threads")
if issues:
parser.error(
f"Options {', '.join(issues)} are incompatible with --async-aware. "
"Async-aware profiling uses task-based stack reconstruction."
)
# --async-mode requires --async-aware
if hasattr(args, 'async_mode') and args.async_mode != "running" and not args.async_aware:
parser.error("--async-mode requires --async-aware to be enabled.")
# Live mode is incompatible with format options
if hasattr(args, 'live') and args.live:
if args.format != "pstats":
@ -570,6 +603,7 @@ def _handle_attach(args):
all_threads=args.all_threads,
realtime_stats=args.realtime_stats,
mode=mode,
async_aware=args.async_mode if args.async_aware else None,
native=args.native,
gc=args.gc,
)
@ -618,6 +652,7 @@ def _handle_run(args):
all_threads=args.all_threads,
realtime_stats=args.realtime_stats,
mode=mode,
async_aware=args.async_mode if args.async_aware else None,
native=args.native,
gc=args.gc,
)
@ -650,6 +685,7 @@ def _handle_live_attach(args, pid):
limit=20, # Default limit
pid=pid,
mode=mode,
async_aware=args.async_mode if args.async_aware else None,
)
# Sample in live mode
@ -660,6 +696,7 @@ def _handle_live_attach(args, pid):
all_threads=args.all_threads,
realtime_stats=args.realtime_stats,
mode=mode,
async_aware=args.async_mode if args.async_aware else None,
native=args.native,
gc=args.gc,
)
@ -689,6 +726,7 @@ def _handle_live_run(args):
limit=20, # Default limit
pid=process.pid,
mode=mode,
async_aware=args.async_mode if args.async_aware else None,
)
# Profile the subprocess in live mode
@ -700,6 +738,7 @@ def _handle_live_run(args):
all_threads=args.all_threads,
realtime_stats=args.realtime_stats,
mode=mode,
async_aware=args.async_mode if args.async_aware else None,
native=args.native,
gc=args.gc,
)

View file

@ -2,10 +2,16 @@
from .constants import (
THREAD_STATUS_HAS_GIL,
THREAD_STATUS_ON_CPU,
THREAD_STATUS_UNKNOWN,
THREAD_STATUS_GIL_REQUESTED,
THREAD_STATUS_UNKNOWN,
)
# FrameInfo is the concrete frame type used to synthesize "<task>" boundary
# markers when reconstructing async stacks.
try:
    from _remote_debugging import FrameInfo
except ImportError:
    # Fallback definition if _remote_debugging is not available
    # NOTE(review): _build_linear_stacks() calls FrameInfo(...) unconditionally,
    # so the async path would raise TypeError under this None fallback —
    # presumably that path is only reachable when _remote_debugging imported
    # successfully; confirm.
    FrameInfo = None
class Collector(ABC):
@abstractmethod
def collect(self, stack_frames):
@ -33,6 +39,95 @@ def _iter_all_frames(self, stack_frames, skip_idle=False):
if frames:
yield frames, thread_info.thread_id
    def _iter_async_frames(self, awaited_info_list):
        """Yield (frames, thread_id, leaf_task_id) linear stacks reconstructed
        from the awaited-by task graph of the sampled process.

        Each yielded stack runs bottom-up from a leaf task (one nobody awaits)
        to its root, with synthetic "<task>" frames marking task boundaries.
        """
        # Phase 1: Index tasks and build parent relationships with pre-computed selection
        task_map, child_to_parent, all_task_ids, all_parent_ids = self._build_task_graph(awaited_info_list)
        # Phase 2: Find leaf tasks (tasks not awaited by anyone)
        leaf_task_ids = self._find_leaf_tasks(all_task_ids, all_parent_ids)
        # Phase 3: Build linear stacks from each leaf to root (optimized - no sorting!)
        yield from self._build_linear_stacks(leaf_task_ids, task_map, child_to_parent)
def _build_task_graph(self, awaited_info_list):
task_map = {}
child_to_parent = {} # Maps child_id -> (selected_parent_id, parent_count)
all_task_ids = set()
all_parent_ids = set() # Track ALL parent IDs for leaf detection
for awaited_info in awaited_info_list:
thread_id = awaited_info.thread_id
for task_info in awaited_info.awaited_by:
task_id = task_info.task_id
task_map[task_id] = (task_info, thread_id)
all_task_ids.add(task_id)
# Pre-compute selected parent and count for optimization
if task_info.awaited_by:
parent_ids = [p.task_name for p in task_info.awaited_by]
parent_count = len(parent_ids)
# Track ALL parents for leaf detection
all_parent_ids.update(parent_ids)
# Use min() for O(n) instead of sorted()[0] which is O(n log n)
selected_parent = min(parent_ids) if parent_count > 1 else parent_ids[0]
child_to_parent[task_id] = (selected_parent, parent_count)
return task_map, child_to_parent, all_task_ids, all_parent_ids
def _find_leaf_tasks(self, all_task_ids, all_parent_ids):
# Leaves are tasks that are not parents of any other task
return all_task_ids - all_parent_ids
    def _build_linear_stacks(self, leaf_task_ids, task_map, child_to_parent):
        """Walk each leaf's pre-selected parent chain up to its root and yield
        (frames, thread_id, leaf_id) tuples.

        Frames are appended leaf-first (bottom-up). A synthetic "<task>"
        FrameInfo marks each task boundary; tasks with multiple parents are
        annotated with the parent count since only one parent is followed.
        """
        for leaf_id in leaf_task_ids:
            frames = []
            visited = set()
            current_id = leaf_id
            thread_id = None
            # Follow the single parent chain from leaf to root
            while current_id is not None:
                # Cycle detection
                if current_id in visited:
                    break
                visited.add(current_id)
                # Check if task exists in task_map (parent may be unknown)
                if current_id not in task_map:
                    break
                task_info, tid = task_map[current_id]
                # Set thread_id from first task (the leaf's thread)
                if thread_id is None:
                    thread_id = tid
                # Add all frames from all coroutines in this task
                if task_info.coroutine_stack:
                    for coro_info in task_info.coroutine_stack:
                        for frame in coro_info.call_stack:
                            frames.append(frame)
                # Get pre-computed parent info (no sorting needed!)
                parent_info = child_to_parent.get(current_id)
                # Add task boundary marker with parent count annotation if multiple parents
                task_name = task_info.task_name or "Task-" + str(task_info.task_id)
                if parent_info:
                    selected_parent, parent_count = parent_info
                    if parent_count > 1:
                        task_name = f"{task_name} ({parent_count} parents)"
                    frames.append(FrameInfo(("<task>", 0, task_name)))
                    current_id = selected_parent
                else:
                    # Root task - no parent
                    frames.append(FrameInfo(("<task>", 0, task_name)))
                    current_id = None
            # Yield the complete stack if we collected any frames
            if frames and thread_id is not None:
                yield frames, thread_id, leaf_id
def _is_gc_frame(self, frame):
if isinstance(frame, tuple):
funcname = frame[2] if len(frame) >= 3 else ""

View file

@ -103,6 +103,7 @@ def __init__(
pid=None,
display=None,
mode=None,
async_aware=None,
):
"""
Initialize the live stats collector.
@ -115,6 +116,7 @@ def __init__(
pid: Process ID being profiled
display: DisplayInterface implementation (None means curses will be used)
mode: Profiling mode ('cpu', 'gil', etc.) - affects what stats are shown
async_aware: Async tracing mode - None (sync only), "all" or "running"
"""
self.result = collections.defaultdict(
lambda: dict(total_rec_calls=0, direct_calls=0, cumulative_calls=0)
@ -133,6 +135,9 @@ def __init__(
self.running = True
self.pid = pid
self.mode = mode # Profiling mode
self.async_aware = async_aware # Async tracing mode
# Pre-select frame iterator method to avoid per-call dispatch overhead
self._get_frame_iterator = self._get_async_frame_iterator if async_aware else self._get_sync_frame_iterator
self._saved_stdout = None
self._saved_stderr = None
self._devnull = None
@ -294,6 +299,15 @@ def process_frames(self, frames, thread_id=None):
if thread_data:
thread_data.result[top_location]["direct_calls"] += 1
def _get_sync_frame_iterator(self, stack_frames):
"""Iterator for sync frames."""
return self._iter_all_frames(stack_frames, skip_idle=self.skip_idle)
def _get_async_frame_iterator(self, stack_frames):
"""Iterator for async frames, yielding (frames, thread_id) tuples."""
for frames, thread_id, task_id in self._iter_async_frames(stack_frames):
yield frames, thread_id
def collect_failed_sample(self):
self.failed_samples += 1
self.total_samples += 1
@ -304,78 +318,40 @@ def collect(self, stack_frames):
self.start_time = time.perf_counter()
self._last_display_update = self.start_time
# Thread status counts for this sample
temp_status_counts = {
"has_gil": 0,
"on_cpu": 0,
"gil_requested": 0,
"unknown": 0,
"total": 0,
}
has_gc_frame = False
# Always collect data, even when paused
# Track thread status flags and GC frames
for interpreter_info in stack_frames:
threads = getattr(interpreter_info, "threads", [])
for thread_info in threads:
temp_status_counts["total"] += 1
# Collect thread status stats (only available in sync mode)
if not self.async_aware:
status_counts, sample_has_gc, per_thread_stats = self._collect_thread_status_stats(stack_frames)
for key, count in status_counts.items():
self.thread_status_counts[key] += count
if sample_has_gc:
has_gc_frame = True
# Track thread status using bit flags
status_flags = getattr(thread_info, "status", 0)
thread_id = getattr(thread_info, "thread_id", None)
for thread_id, stats in per_thread_stats.items():
thread_data = self._get_or_create_thread_data(thread_id)
thread_data.has_gil += stats.get("has_gil", 0)
thread_data.on_cpu += stats.get("on_cpu", 0)
thread_data.gil_requested += stats.get("gil_requested", 0)
thread_data.unknown += stats.get("unknown", 0)
thread_data.total += stats.get("total", 0)
if stats.get("gc_samples", 0):
thread_data.gc_frame_samples += stats["gc_samples"]
# Update aggregated counts
if status_flags & THREAD_STATUS_HAS_GIL:
temp_status_counts["has_gil"] += 1
if status_flags & THREAD_STATUS_ON_CPU:
temp_status_counts["on_cpu"] += 1
if status_flags & THREAD_STATUS_GIL_REQUESTED:
temp_status_counts["gil_requested"] += 1
if status_flags & THREAD_STATUS_UNKNOWN:
temp_status_counts["unknown"] += 1
# Process frames using pre-selected iterator
for frames, thread_id in self._get_frame_iterator(stack_frames):
if not frames:
continue
# Update per-thread status counts
if thread_id is not None:
thread_data = self._get_or_create_thread_data(thread_id)
thread_data.increment_status_flag(status_flags)
self.process_frames(frames, thread_id=thread_id)
# Process frames (respecting skip_idle)
if self.skip_idle:
has_gil = bool(status_flags & THREAD_STATUS_HAS_GIL)
on_cpu = bool(status_flags & THREAD_STATUS_ON_CPU)
if not (has_gil or on_cpu):
continue
# Track thread IDs
if thread_id is not None and thread_id not in self.thread_ids:
self.thread_ids.append(thread_id)
frames = getattr(thread_info, "frame_info", None)
if frames:
self.process_frames(frames, thread_id=thread_id)
# Track thread IDs only for threads that actually have samples
if (
thread_id is not None
and thread_id not in self.thread_ids
):
self.thread_ids.append(thread_id)
# Increment per-thread sample count and check for GC frames
thread_has_gc_frame = False
for frame in frames:
funcname = getattr(frame, "funcname", "")
if "<GC>" in funcname or "gc_collect" in funcname:
has_gc_frame = True
thread_has_gc_frame = True
break
if thread_id is not None:
thread_data = self._get_or_create_thread_data(thread_id)
thread_data.sample_count += 1
if thread_has_gc_frame:
thread_data.gc_frame_samples += 1
# Update cumulative thread status counts
for key, count in temp_status_counts.items():
self.thread_status_counts[key] += count
if thread_id is not None:
thread_data = self._get_or_create_thread_data(thread_id)
thread_data.sample_count += 1
if has_gc_frame:
self.gc_frame_samples += 1

View file

@ -42,8 +42,14 @@ def _process_frames(self, frames):
self.callers[callee][caller] += 1
def collect(self, stack_frames):
for frames, thread_id in self._iter_all_frames(stack_frames, skip_idle=self.skip_idle):
self._process_frames(frames)
if stack_frames and hasattr(stack_frames[0], "awaited_by"):
# Async frame processing
for frames, thread_id, task_id in self._iter_async_frames(stack_frames):
self._process_frames(frames)
else:
# Regular frame processing
for frames, thread_id in self._iter_all_frames(stack_frames, skip_idle=self.skip_idle):
self._process_frames(frames)
def export(self, filename):
self.create_stats()

View file

@ -48,7 +48,7 @@ def __init__(self, pid, sample_interval_usec, all_threads, *, mode=PROFILING_MOD
self.total_samples = 0
self.realtime_stats = False
def sample(self, collector, duration_sec=10):
def sample(self, collector, duration_sec=10, *, async_aware=False):
sample_interval_sec = self.sample_interval_usec / 1_000_000
running_time = 0
num_samples = 0
@ -68,7 +68,12 @@ def sample(self, collector, duration_sec=10):
current_time = time.perf_counter()
if next_time < current_time:
try:
stack_frames = self.unwinder.get_stack_trace()
if async_aware == "all":
stack_frames = self.unwinder.get_all_awaited_by()
elif async_aware == "running":
stack_frames = self.unwinder.get_async_stack_trace()
else:
stack_frames = self.unwinder.get_stack_trace()
collector.collect(stack_frames)
except ProcessLookupError:
duration_sec = current_time - start_time
@ -191,6 +196,7 @@ def sample(
all_threads=False,
realtime_stats=False,
mode=PROFILING_MODE_WALL,
async_aware=None,
native=False,
gc=True,
):
@ -233,7 +239,7 @@ def sample(
profiler.realtime_stats = realtime_stats
# Run the sampling
profiler.sample(collector, duration_sec)
profiler.sample(collector, duration_sec, async_aware=async_aware)
return collector
@ -246,6 +252,7 @@ def sample_live(
all_threads=False,
realtime_stats=False,
mode=PROFILING_MODE_WALL,
async_aware=None,
native=False,
gc=True,
):
@ -290,7 +297,7 @@ def sample_live(
def curses_wrapper_func(stdscr):
collector.init_curses(stdscr)
try:
profiler.sample(collector, duration_sec)
profiler.sample(collector, duration_sec, async_aware=async_aware)
# Mark as finished and keep the TUI running until user presses 'q'
collector.mark_finished()
# Keep processing input until user quits

View file

@ -17,10 +17,18 @@ def __init__(self, sample_interval_usec, *, skip_idle=False):
self.skip_idle = skip_idle
def collect(self, stack_frames, skip_idle=False):
for frames, thread_id in self._iter_all_frames(stack_frames, skip_idle=skip_idle):
if not frames:
continue
self.process_frames(frames, thread_id)
if stack_frames and hasattr(stack_frames[0], "awaited_by"):
# Async-aware mode: process async task frames
for frames, thread_id, task_id in self._iter_async_frames(stack_frames):
if not frames:
continue
self.process_frames(frames, thread_id)
else:
# Sync-only mode
for frames, thread_id in self._iter_all_frames(stack_frames, skip_idle=skip_idle):
if not frames:
continue
self.process_frames(frames, thread_id)
    def process_frames(self, frames, thread_id):
        """Hook for subclasses: consume one (frames, thread_id) sample.
        The base implementation is a deliberate no-op."""
        pass

View file

@ -36,3 +36,38 @@ def __init__(self, interpreter_id, threads):
def __repr__(self):
return f"MockInterpreterInfo(interpreter_id={self.interpreter_id}, threads={self.threads})"
class MockCoroInfo:
    """Mock CoroInfo for testing async tasks.

    When used in an ``awaited_by`` list, ``task_name`` actually carries the
    parent task id, mirroring the real structure.
    """

    def __init__(self, task_name, call_stack):
        # In reality, this is the parent task ID when used as an awaiter.
        self.task_name = task_name
        self.call_stack = call_stack

    def __repr__(self):
        return (
            f"MockCoroInfo(task_name={self.task_name}, "
            f"call_stack={self.call_stack})"
        )
class MockTaskInfo:
    """Mock TaskInfo for testing async tasks."""

    def __init__(self, task_id, task_name, coroutine_stack, awaited_by=None):
        self.task_id = task_id
        self.task_name = task_name
        # List of CoroInfo objects forming this task's coroutine stack.
        self.coroutine_stack = coroutine_stack
        # Parents as CoroInfo objects; any falsy value becomes a fresh list.
        self.awaited_by = awaited_by if awaited_by else []

    def __repr__(self):
        return f"MockTaskInfo(task_id={self.task_id}, task_name={self.task_name})"
class MockAwaitedInfo:
    """Mock AwaitedInfo for testing async tasks."""

    def __init__(self, thread_id, awaited_by):
        self.thread_id = thread_id
        # TaskInfo objects captured for this thread.
        self.awaited_by = awaited_by

    def __repr__(self):
        task_count = len(self.awaited_by)
        return f"MockAwaitedInfo(thread_id={self.thread_id}, awaited_by={task_count} tasks)"

View file

@ -0,0 +1,799 @@
"""Tests for async stack reconstruction in the sampling profiler.
Each test covers a distinct algorithm path or edge case:
1. Graph building: _build_task_graph()
2. Leaf identification: _find_leaf_tasks()
3. Stack traversal: _build_linear_stacks() via an iterative leaf-to-root
   parent-chain walk with cycle detection
"""
import unittest
try:
import _remote_debugging # noqa: F401
from profiling.sampling.pstats_collector import PstatsCollector
except ImportError:
raise unittest.SkipTest(
"Test only runs when _remote_debugging is available"
)
from .mocks import MockFrameInfo, MockCoroInfo, MockTaskInfo, MockAwaitedInfo
class TestAsyncStackReconstruction(unittest.TestCase):
    """Test async task tree linear stack reconstruction algorithm."""

    def test_empty_input(self):
        """Test _build_task_graph with empty awaited_info_list."""
        collector = PstatsCollector(sample_interval_usec=1000)
        stacks = list(collector._iter_async_frames([]))
        self.assertEqual(len(stacks), 0)

    def test_single_root_task(self):
        """Test _find_leaf_tasks: root task with no parents is its own leaf."""
        collector = PstatsCollector(sample_interval_usec=1000)
        root = MockTaskInfo(
            task_id=123,
            task_name="Task-1",
            coroutine_stack=[
                MockCoroInfo(
                    task_name="Task-1",
                    call_stack=[MockFrameInfo("main.py", 10, "main")]
                )
            ],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=100, awaited_by=[root])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Single root is both leaf and root
        self.assertEqual(len(stacks), 1)
        frames, thread_id, leaf_id = stacks[0]
        self.assertEqual(leaf_id, 123)
        self.assertEqual(thread_id, 100)

    def test_parent_child_chain(self):
        """Test _build_linear_stacks: walk follows parent links from leaf to root.

        Task graph:
            Parent (id=1)
              |
            Child (id=2)
        """
        collector = PstatsCollector(sample_interval_usec=1000)
        child = MockTaskInfo(
            task_id=2,
            task_name="Child",
            coroutine_stack=[
                MockCoroInfo(task_name="Child", call_stack=[MockFrameInfo("c.py", 5, "child_fn")])
            ],
            awaited_by=[
                MockCoroInfo(task_name=1, call_stack=[MockFrameInfo("p.py", 10, "parent_await")])
            ]
        )
        parent = MockTaskInfo(
            task_id=1,
            task_name="Parent",
            coroutine_stack=[
                MockCoroInfo(task_name="Parent", call_stack=[MockFrameInfo("p.py", 15, "parent_fn")])
            ],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=200, awaited_by=[child, parent])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Leaf is child, traverses to parent
        self.assertEqual(len(stacks), 1)
        frames, thread_id, leaf_id = stacks[0]
        self.assertEqual(leaf_id, 2)
        # Verify both child and parent frames present
        func_names = [f.funcname for f in frames]
        self.assertIn("child_fn", func_names)
        self.assertIn("parent_fn", func_names)

    def test_multiple_leaf_tasks(self):
        """Test _find_leaf_tasks: identifies multiple leaves correctly.

        Task graph (fan-out from root):
                  Root (id=1)
                  /         \\
            Leaf1 (id=10)  Leaf2 (id=20)

        Expected: 2 stacks (one for each leaf).
        """
        collector = PstatsCollector(sample_interval_usec=1000)
        leaf1 = MockTaskInfo(
            task_id=10,
            task_name="Leaf1",
            coroutine_stack=[MockCoroInfo(task_name="Leaf1", call_stack=[MockFrameInfo("l1.py", 1, "f1")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[MockFrameInfo("r.py", 5, "root")])]
        )
        leaf2 = MockTaskInfo(
            task_id=20,
            task_name="Leaf2",
            coroutine_stack=[MockCoroInfo(task_name="Leaf2", call_stack=[MockFrameInfo("l2.py", 2, "f2")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[MockFrameInfo("r.py", 5, "root")])]
        )
        root = MockTaskInfo(
            task_id=1,
            task_name="Root",
            coroutine_stack=[MockCoroInfo(task_name="Root", call_stack=[MockFrameInfo("r.py", 10, "main")])],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=300, awaited_by=[leaf1, leaf2, root])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Two leaves = two stacks
        self.assertEqual(len(stacks), 2)
        leaf_ids = {leaf_id for _, _, leaf_id in stacks}
        self.assertEqual(leaf_ids, {10, 20})

    def test_cycle_detection(self):
        """Test _build_linear_stacks: cycle detection prevents infinite loops.

        Task graph (cyclic dependency):
            A (id=1) <---> B (id=2)

        Neither task is a leaf (both have parents), so no stacks are produced.
        """
        collector = PstatsCollector(sample_interval_usec=1000)
        task_a = MockTaskInfo(
            task_id=1,
            task_name="A",
            coroutine_stack=[MockCoroInfo(task_name="A", call_stack=[MockFrameInfo("a.py", 1, "a")])],
            awaited_by=[MockCoroInfo(task_name=2, call_stack=[MockFrameInfo("b.py", 5, "b")])]
        )
        task_b = MockTaskInfo(
            task_id=2,
            task_name="B",
            coroutine_stack=[MockCoroInfo(task_name="B", call_stack=[MockFrameInfo("b.py", 10, "b")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[MockFrameInfo("a.py", 15, "a")])]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=400, awaited_by=[task_a, task_b])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # No leaves (both have parents), should return empty
        self.assertEqual(len(stacks), 0)

    def test_orphaned_parent_reference(self):
        """Test _build_linear_stacks: handles parent ID not in task_map."""
        collector = PstatsCollector(sample_interval_usec=1000)
        # Task references non-existent parent
        orphan = MockTaskInfo(
            task_id=5,
            task_name="Orphan",
            coroutine_stack=[MockCoroInfo(task_name="Orphan", call_stack=[MockFrameInfo("o.py", 1, "orphan")])],
            awaited_by=[MockCoroInfo(task_name=999, call_stack=[])]  # 999 doesn't exist
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=500, awaited_by=[orphan])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Stops at missing parent, yields what it has
        self.assertEqual(len(stacks), 1)
        frames, _, leaf_id = stacks[0]
        self.assertEqual(leaf_id, 5)

    def test_multiple_coroutines_per_task(self):
        """Test _build_linear_stacks: collects frames from all coroutines in task."""
        collector = PstatsCollector(sample_interval_usec=1000)
        # Task with multiple coroutines (e.g., nested async generators)
        task = MockTaskInfo(
            task_id=7,
            task_name="Multi",
            coroutine_stack=[
                MockCoroInfo(task_name="Multi", call_stack=[MockFrameInfo("g.py", 5, "gen1")]),
                MockCoroInfo(task_name="Multi", call_stack=[MockFrameInfo("g.py", 10, "gen2")]),
            ],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=600, awaited_by=[task])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        self.assertEqual(len(stacks), 1)
        frames, _, _ = stacks[0]
        # Both coroutine frames should be present
        func_names = [f.funcname for f in frames]
        self.assertIn("gen1", func_names)
        self.assertIn("gen2", func_names)

    def test_multiple_threads(self):
        """Test _build_task_graph: handles multiple AwaitedInfo (different threads)."""
        collector = PstatsCollector(sample_interval_usec=1000)
        # Two threads with separate task trees
        thread1_task = MockTaskInfo(
            task_id=100,
            task_name="T1",
            coroutine_stack=[MockCoroInfo(task_name="T1", call_stack=[MockFrameInfo("t1.py", 1, "t1")])],
            awaited_by=[]
        )
        thread2_task = MockTaskInfo(
            task_id=200,
            task_name="T2",
            coroutine_stack=[MockCoroInfo(task_name="T2", call_stack=[MockFrameInfo("t2.py", 1, "t2")])],
            awaited_by=[]
        )
        awaited_info_list = [
            MockAwaitedInfo(thread_id=1, awaited_by=[thread1_task]),
            MockAwaitedInfo(thread_id=2, awaited_by=[thread2_task]),
        ]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Two threads = two stacks
        self.assertEqual(len(stacks), 2)
        # Verify thread IDs preserved
        thread_ids = {thread_id for _, thread_id, _ in stacks}
        self.assertEqual(thread_ids, {1, 2})

    def test_collect_public_interface(self):
        """Test collect() method correctly routes to async frame processing."""
        collector = PstatsCollector(sample_interval_usec=1000)
        child = MockTaskInfo(
            task_id=50,
            task_name="Child",
            coroutine_stack=[MockCoroInfo(task_name="Child", call_stack=[MockFrameInfo("c.py", 1, "child")])],
            awaited_by=[MockCoroInfo(task_name=51, call_stack=[])]
        )
        parent = MockTaskInfo(
            task_id=51,
            task_name="Parent",
            coroutine_stack=[MockCoroInfo(task_name="Parent", call_stack=[MockFrameInfo("p.py", 1, "parent")])],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=999, awaited_by=[child, parent])]
        # Public interface: collect()
        collector.collect(awaited_info_list)
        # Verify stats collected
        self.assertGreater(len(collector.result), 0)
        func_names = [loc[2] for loc in collector.result.keys()]
        self.assertIn("child", func_names)
        self.assertIn("parent", func_names)

    def test_diamond_pattern_multiple_parents(self):
        """Test _build_linear_stacks: task with 2+ parents picks one deterministically.

        CRITICAL: Tests that when a task has multiple parents, we pick one parent
        deterministically (smallest id) and annotate the task name with parent count.
        """
        collector = PstatsCollector(sample_interval_usec=1000)
        # Diamond pattern: Root spawns A and B, both await Child
        #
        #        Root (id=1)
        #        /         \
        #   A (id=2)    B (id=3)
        #        \         /
        #        Child (id=4)
        #
        child = MockTaskInfo(
            task_id=4,
            task_name="Child",
            coroutine_stack=[MockCoroInfo(task_name="Child", call_stack=[MockFrameInfo("c.py", 1, "child_work")])],
            awaited_by=[
                MockCoroInfo(task_name=2, call_stack=[MockFrameInfo("a.py", 5, "a_await")]),  # Parent A
                MockCoroInfo(task_name=3, call_stack=[MockFrameInfo("b.py", 5, "b_await")]),  # Parent B
            ]
        )
        parent_a = MockTaskInfo(
            task_id=2,
            task_name="A",
            coroutine_stack=[MockCoroInfo(task_name="A", call_stack=[MockFrameInfo("a.py", 10, "a_work")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[MockFrameInfo("root.py", 5, "root_spawn")])]
        )
        parent_b = MockTaskInfo(
            task_id=3,
            task_name="B",
            coroutine_stack=[MockCoroInfo(task_name="B", call_stack=[MockFrameInfo("b.py", 10, "b_work")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[MockFrameInfo("root.py", 5, "root_spawn")])]
        )
        root = MockTaskInfo(
            task_id=1,
            task_name="Root",
            coroutine_stack=[MockCoroInfo(task_name="Root", call_stack=[MockFrameInfo("root.py", 20, "main")])],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=777, awaited_by=[child, parent_a, parent_b, root])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Should get 1 stack: Child->A->Root (picks parent with lowest ID: 2)
        self.assertEqual(len(stacks), 1, "Diamond should create only 1 path, picking first sorted parent")
        # Verify the single stack
        frames, thread_id, leaf_id = stacks[0]
        self.assertEqual(leaf_id, 4)
        self.assertEqual(thread_id, 777)
        func_names = [f.funcname for f in frames]
        # Stack should contain child, parent A (id=2, first when sorted), and root
        self.assertIn("child_work", func_names)
        self.assertIn("a_work", func_names, "Should use parent A (id=2, first when sorted)")
        self.assertNotIn("b_work", func_names, "Should not include parent B")
        self.assertIn("main", func_names)
        # Verify Child task is annotated with parent count
        self.assertIn("Child (2 parents)", func_names, "Child task should be annotated with parent count")

    def test_empty_coroutine_stack(self):
        """Test _build_linear_stacks: handles empty coroutine_stack (frame-append branch skipped)."""
        collector = PstatsCollector(sample_interval_usec=1000)
        # Task with no coroutine_stack
        task = MockTaskInfo(
            task_id=99,
            task_name="EmptyStack",
            coroutine_stack=[],  # Empty!
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=111, awaited_by=[task])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        self.assertEqual(len(stacks), 1)
        frames, _, _ = stacks[0]
        # Should only have task marker, no function frames
        func_names = [f.funcname for f in frames]
        self.assertEqual(len(func_names), 1, "Should only have task marker")
        self.assertIn("EmptyStack", func_names)

    def test_orphaned_parent_with_no_frames_collected(self):
        """Test _build_linear_stacks: orphaned parent with empty frames still yields a marker."""
        collector = PstatsCollector(sample_interval_usec=1000)
        # Leaf that doesn't exist in task_map (should not happen normally, but test robustness)
        # We'll create a scenario where the leaf_id is present but empty
        # Task references non-existent parent, and has no coroutine_stack
        orphan = MockTaskInfo(
            task_id=88,
            task_name="Orphan",
            coroutine_stack=[],  # No frames
            awaited_by=[MockCoroInfo(task_name=999, call_stack=[])]  # Parent doesn't exist
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=222, awaited_by=[orphan])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # Should yield because we have the task marker even with no function frames
        self.assertEqual(len(stacks), 1)
        frames, _, leaf_id = stacks[0]
        self.assertEqual(leaf_id, 88)
        # Has task marker but no function frames
        self.assertGreater(len(frames), 0, "Should have at least task marker")

    def test_frame_ordering(self):
        """Test _build_linear_stacks: frames are collected in correct order (leaf->root).

        Task graph (3-level chain):
            Root (id=1)    <- root_bottom, root_top
              |
            Middle (id=2)  <- mid_bottom, mid_top
              |
            Leaf (id=3)    <- leaf_bottom, leaf_top

        Expected frame order: leaf_bottom, leaf_top, mid_bottom, mid_top, root_bottom, root_top
        (stack is built bottom-up: leaf frames first, then parent frames).
        """
        collector = PstatsCollector(sample_interval_usec=1000)
        leaf = MockTaskInfo(
            task_id=3,
            task_name="Leaf",
            coroutine_stack=[
                MockCoroInfo(task_name="Leaf", call_stack=[
                    MockFrameInfo("leaf.py", 1, "leaf_bottom"),
                    MockFrameInfo("leaf.py", 2, "leaf_top"),
                ])
            ],
            awaited_by=[MockCoroInfo(task_name=2, call_stack=[])]
        )
        middle = MockTaskInfo(
            task_id=2,
            task_name="Middle",
            coroutine_stack=[
                MockCoroInfo(task_name="Middle", call_stack=[
                    MockFrameInfo("mid.py", 1, "mid_bottom"),
                    MockFrameInfo("mid.py", 2, "mid_top"),
                ])
            ],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[])]
        )
        root = MockTaskInfo(
            task_id=1,
            task_name="Root",
            coroutine_stack=[
                MockCoroInfo(task_name="Root", call_stack=[
                    MockFrameInfo("root.py", 1, "root_bottom"),
                    MockFrameInfo("root.py", 2, "root_top"),
                ])
            ],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=333, awaited_by=[leaf, middle, root])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        self.assertEqual(len(stacks), 1)
        frames, _, _ = stacks[0]
        func_names = [f.funcname for f in frames]
        # Order should be: leaf frames, leaf marker, middle frames, middle marker, root frames, root marker
        leaf_bottom_idx = func_names.index("leaf_bottom")
        leaf_top_idx = func_names.index("leaf_top")
        mid_bottom_idx = func_names.index("mid_bottom")
        root_bottom_idx = func_names.index("root_bottom")
        # Verify leaf comes before middle comes before root
        self.assertLess(leaf_bottom_idx, leaf_top_idx, "Leaf frames in order")
        self.assertLess(leaf_top_idx, mid_bottom_idx, "Leaf before middle")
        self.assertLess(mid_bottom_idx, root_bottom_idx, "Middle before root")

    def test_complex_multi_parent_convergence(self):
        """Test _build_linear_stacks: multiple leaves with same parents pick deterministically.

        Tests that when multiple leaves have multiple parents, each leaf picks the same
        parent (smallest id) and all leaves are annotated with parent count.

        Task graph structure (both leaves awaited by both A and B)::

                   Root (id=1)
                   /        \\
              A (id=2)    B (id=3)
               |  \\        /  |
               |   \\      /   |
               |    \\    /    |
               |     \\  /     |
               |     /\\       |
               |    /  \\      |
            LeafX (id=4)  LeafY (id=5)

        Expected behavior: Both leaves pick parent A (lowest id=2) for their stack path.
        Result: 2 stacks, both going through A -> Root (B is skipped).
        """
        collector = PstatsCollector(sample_interval_usec=1000)
        leaf_x = MockTaskInfo(
            task_id=4,
            task_name="LeafX",
            coroutine_stack=[MockCoroInfo(task_name="LeafX", call_stack=[MockFrameInfo("x.py", 1, "x")])],
            awaited_by=[
                MockCoroInfo(task_name=2, call_stack=[]),
                MockCoroInfo(task_name=3, call_stack=[]),
            ]
        )
        leaf_y = MockTaskInfo(
            task_id=5,
            task_name="LeafY",
            coroutine_stack=[MockCoroInfo(task_name="LeafY", call_stack=[MockFrameInfo("y.py", 1, "y")])],
            awaited_by=[
                MockCoroInfo(task_name=2, call_stack=[]),
                MockCoroInfo(task_name=3, call_stack=[]),
            ]
        )
        parent_a = MockTaskInfo(
            task_id=2,
            task_name="A",
            coroutine_stack=[MockCoroInfo(task_name="A", call_stack=[MockFrameInfo("a.py", 1, "a")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[])]
        )
        parent_b = MockTaskInfo(
            task_id=3,
            task_name="B",
            coroutine_stack=[MockCoroInfo(task_name="B", call_stack=[MockFrameInfo("b.py", 1, "b")])],
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[])]
        )
        root = MockTaskInfo(
            task_id=1,
            task_name="Root",
            coroutine_stack=[MockCoroInfo(task_name="Root", call_stack=[MockFrameInfo("r.py", 1, "root")])],
            awaited_by=[]
        )
        awaited_info_list = [MockAwaitedInfo(thread_id=444, awaited_by=[leaf_x, leaf_y, parent_a, parent_b, root])]
        stacks = list(collector._iter_async_frames(awaited_info_list))
        # 2 leaves, each picks same parent (A, id=2) = 2 paths
        self.assertEqual(len(stacks), 2, "Should create 2 paths: X->A->Root, Y->A->Root")
        # Verify both leaves pick parent A (id=2, first when sorted)
        leaf_ids_seen = set()
        for frames, _, leaf_id in stacks:
            leaf_ids_seen.add(leaf_id)
            func_names = [f.funcname for f in frames]
            # Both stacks should go through parent A only
            self.assertIn("a", func_names, "Should use parent A (id=2, first when sorted)")
            self.assertNotIn("b", func_names, "Should not include parent B")
            self.assertIn("root", func_names, "Should reach root")
            # Check for parent count annotation on the leaf
            if leaf_id == 4:
                self.assertIn("x", func_names)
                self.assertIn("LeafX (2 parents)", func_names, "LeafX should be annotated with parent count")
            elif leaf_id == 5:
                self.assertIn("y", func_names)
                self.assertIn("LeafY (2 parents)", func_names, "LeafY should be annotated with parent count")
        # Both leaves should be represented
        self.assertEqual(leaf_ids_seen, {4, 5}, "Both LeafX and LeafY should have paths")
class TestFlamegraphCollectorAsync(unittest.TestCase):
    """Test FlamegraphCollector with async frames."""

    @staticmethod
    def _single_frame_task(task_id, name, filename, lineno, funcname,
                           awaited_by=()):
        """Build a MockTaskInfo whose coroutine stack holds exactly one frame."""
        return MockTaskInfo(
            task_id=task_id,
            task_name=name,
            coroutine_stack=[
                MockCoroInfo(
                    task_name=name,
                    call_stack=[MockFrameInfo(filename, lineno, funcname)],
                )
            ],
            awaited_by=list(awaited_by),
        )

    def test_flamegraph_with_async_frames(self):
        """Test FlamegraphCollector correctly processes async task frames."""
        from profiling.sampling.stack_collector import FlamegraphCollector

        collector = FlamegraphCollector(sample_interval_usec=1000)
        # Async task tree: RootTask awaits ChildTask.
        child = self._single_frame_task(
            2, "ChildTask", "child.py", 10, "child_work",
            awaited_by=[MockCoroInfo(task_name=1, call_stack=[])],
        )
        root = self._single_frame_task(1, "RootTask", "root.py", 20, "root_work")
        collector.collect(
            [MockAwaitedInfo(thread_id=100, awaited_by=[child, root])]
        )
        # At least one sample must have been recorded at the tree root...
        self.assertGreater(collector._total_samples, 0)
        self.assertGreater(collector._root["samples"], 0)
        # ...and the originating thread id must be tracked.
        self.assertIn(100, collector._all_threads)

    def test_flamegraph_with_task_markers(self):
        """Test FlamegraphCollector includes <task> boundary markers."""
        from profiling.sampling.stack_collector import FlamegraphCollector

        collector = FlamegraphCollector(sample_interval_usec=1000)
        task = self._single_frame_task(42, "MyTask", "work.py", 5, "do_work")
        collector.collect([MockAwaitedInfo(thread_id=200, awaited_by=[task])])

        # Walk the flamegraph tree iteratively, looking for a <task> marker.
        marker = None
        pending = [collector._root]
        while pending and marker is None:
            node = pending.pop()
            for func, child_node in node.get("children", {}).items():
                if func[0] == "<task>":
                    marker = func
                    break
                pending.append(child_node)

        self.assertIsNotNone(marker, "Should have <task> marker in tree")
        self.assertEqual(marker[0], "<task>")
        self.assertIn("MyTask", marker[2])

    def test_flamegraph_multiple_async_samples(self):
        """Test FlamegraphCollector aggregates multiple async samples correctly."""
        from profiling.sampling.stack_collector import FlamegraphCollector

        collector = FlamegraphCollector(sample_interval_usec=1000)
        task = self._single_frame_task(1, "Task", "work.py", 10, "work")
        snapshot = [MockAwaitedInfo(thread_id=300, awaited_by=[task])]
        # Feed the identical snapshot five times; counts must accumulate.
        for _ in range(5):
            collector.collect(snapshot)
        self.assertEqual(collector._sample_count, 5)
        self.assertEqual(collector._total_samples, 5)
class TestAsyncAwareParameterFlow(unittest.TestCase):
    """Integration tests for async_aware parameter flow from CLI to unwinder."""

    def test_sample_function_accepts_async_aware(self):
        """Test that sample() function accepts async_aware parameter."""
        import inspect
        from profiling.sampling.sample import sample
        self.assertIn("async_aware", inspect.signature(sample).parameters)

    def test_sample_live_function_accepts_async_aware(self):
        """Test that sample_live() function accepts async_aware parameter."""
        import inspect
        from profiling.sampling.sample import sample_live
        self.assertIn("async_aware", inspect.signature(sample_live).parameters)

    def test_sample_profiler_sample_accepts_async_aware(self):
        """Test that SampleProfiler.sample() accepts async_aware parameter."""
        import inspect
        from profiling.sampling.sample import SampleProfiler
        self.assertIn(
            "async_aware", inspect.signature(SampleProfiler.sample).parameters
        )

    @staticmethod
    def _leaf_task(task_id, name, filename, lineno, funcname):
        """Build a MockTaskInfo with a one-frame coroutine stack and no waiters."""
        return MockTaskInfo(
            task_id=task_id,
            task_name=name,
            coroutine_stack=[
                MockCoroInfo(
                    task_name=name,
                    call_stack=[MockFrameInfo(filename, lineno, funcname)],
                )
            ],
            awaited_by=[],
        )

    def test_async_aware_all_sees_sleeping_and_running_tasks(self):
        """Test async_aware='all' captures both sleeping and CPU-running tasks."""
        # Sleeping (awaiting) task and CPU-running (active) task, both
        # reported by get_all_awaited_by.
        sleeping = self._leaf_task(1, "SleepingTask", "sleeper.py", 10, "sleep_work")
        running = self._leaf_task(2, "RunningTask", "runner.py", 20, "cpu_work")
        collector = PstatsCollector(sample_interval_usec=1000)
        collector.collect(
            [MockAwaitedInfo(thread_id=100, awaited_by=[sleeping, running])]
        )
        collector.create_stats()

        # Frames of both tasks must be visible in the stats.
        self.assertIn(("sleeper.py", 10, "sleep_work"), collector.stats)
        self.assertIn(("runner.py", 20, "cpu_work"), collector.stats)

        # <task> boundary markers must be present and name both tasks.
        marker_names = [k[2] for k in collector.stats if k[0] == "<task>"]
        self.assertGreater(
            len(marker_names), 0, "Should have <task> markers in stats"
        )
        self.assertTrue(
            any("SleepingTask" in name for name in marker_names),
            "SleepingTask should be in task markers",
        )
        self.assertTrue(
            any("RunningTask" in name for name in marker_names),
            "RunningTask should be in task markers",
        )

    def test_async_aware_running_sees_only_running_task(self):
        """Test async_aware='running' only shows the currently running task stack."""
        # get_async_stack_trace only reports the task that is running.
        running = self._leaf_task(2, "RunningTask", "runner.py", 20, "cpu_work")
        collector = PstatsCollector(sample_interval_usec=1000)
        collector.collect(
            [MockAwaitedInfo(thread_id=100, awaited_by=[running])]
        )
        collector.create_stats()

        self.assertIn(("runner.py", 20, "cpu_work"), collector.stats)
        # The sleeping task was never part of the input, so it must be absent.
        self.assertNotIn(("sleeper.py", 10, "sleep_work"), collector.stats)

        marker_names = [k[2] for k in collector.stats if k[0] == "<task>"]
        self.assertGreater(
            len(marker_names), 0, "Should have <task> markers in stats"
        )
        self.assertTrue(
            any("RunningTask" in name for name in marker_names),
            "RunningTask should be in task markers",
        )
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View file

@ -547,3 +547,165 @@ def test_sort_options(self):
mock_sample.assert_called_once()
mock_sample.reset_mock()
def test_async_aware_flag_defaults_to_running(self):
    """Test --async-aware flag enables async profiling with default 'running' mode."""
    argv = ["profiling.sampling.cli", "attach", "12345", "--async-aware"]
    with (
        mock.patch("sys.argv", argv),
        mock.patch("profiling.sampling.cli.sample") as sample_mock,
    ):
        from profiling.sampling.cli import main
        main()
    sample_mock.assert_called_once()
    # Bare --async-aware implies the default async mode "running".
    self.assertEqual(sample_mock.call_args[1].get("async_aware"), "running")
def test_async_aware_with_async_mode_all(self):
    """Test --async-aware with --async-mode all."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--async-mode", "all"]
    with (
        mock.patch("sys.argv", argv),
        mock.patch("profiling.sampling.cli.sample") as sample_mock,
    ):
        from profiling.sampling.cli import main
        main()
    sample_mock.assert_called_once()
    # --async-mode all must be forwarded to sample() as async_aware="all".
    self.assertEqual(sample_mock.call_args[1].get("async_aware"), "all")
def test_async_aware_default_is_none(self):
    """Test async_aware defaults to None when --async-aware not specified."""
    argv = ["profiling.sampling.cli", "attach", "12345"]
    with (
        mock.patch("sys.argv", argv),
        mock.patch("profiling.sampling.cli.sample") as sample_mock,
    ):
        from profiling.sampling.cli import main
        main()
    sample_mock.assert_called_once()
    # Without the flag, async-aware profiling stays disabled (None).
    self.assertIsNone(sample_mock.call_args[1].get("async_aware"))
def test_async_mode_invalid_choice(self):
    """Test --async-mode with invalid choice raises error."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--async-mode", "invalid"]
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", io.StringIO()),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports invalid choices with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
def test_async_mode_requires_async_aware(self):
    """Test --async-mode without --async-aware raises error."""
    argv = ["profiling.sampling.cli", "attach", "12345", "--async-mode", "all"]
    stderr_buf = io.StringIO()
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", stderr_buf),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports usage errors with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
    self.assertIn("--async-mode requires --async-aware", stderr_buf.getvalue())
def test_async_aware_incompatible_with_native(self):
    """Test --async-aware is incompatible with --native."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--native"]
    stderr_buf = io.StringIO()
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", stderr_buf),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports usage errors with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
    message = stderr_buf.getvalue()
    self.assertIn("--native", message)
    self.assertIn("incompatible with --async-aware", message)
def test_async_aware_incompatible_with_no_gc(self):
    """Test --async-aware is incompatible with --no-gc."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--no-gc"]
    stderr_buf = io.StringIO()
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", stderr_buf),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports usage errors with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
    message = stderr_buf.getvalue()
    self.assertIn("--no-gc", message)
    self.assertIn("incompatible with --async-aware", message)
def test_async_aware_incompatible_with_both_native_and_no_gc(self):
    """Test --async-aware is incompatible with both --native and --no-gc."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--native", "--no-gc"]
    stderr_buf = io.StringIO()
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", stderr_buf),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports usage errors with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
    # Both offending options should be listed in one combined error.
    message = stderr_buf.getvalue()
    self.assertIn("--native", message)
    self.assertIn("--no-gc", message)
    self.assertIn("incompatible with --async-aware", message)
def test_async_aware_incompatible_with_mode(self):
    """Test --async-aware is incompatible with --mode (non-wall)."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--mode", "cpu"]
    stderr_buf = io.StringIO()
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", stderr_buf),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports usage errors with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
    # The error text names the selected mode, not just the flag.
    message = stderr_buf.getvalue()
    self.assertIn("--mode=cpu", message)
    self.assertIn("incompatible with --async-aware", message)
def test_async_aware_incompatible_with_all_threads(self):
    """Test --async-aware is incompatible with --all-threads."""
    argv = ["profiling.sampling.cli", "attach", "12345",
            "--async-aware", "--all-threads"]
    stderr_buf = io.StringIO()
    with (
        mock.patch("sys.argv", argv),
        mock.patch("sys.stderr", stderr_buf),
        self.assertRaises(SystemExit) as exit_cm,
    ):
        from profiling.sampling.cli import main
        main()
    # argparse reports usage errors with exit status 2.
    self.assertEqual(exit_cm.exception.code, 2)
    message = stderr_buf.getvalue()
    self.assertIn("--all-threads", message)
    self.assertIn("incompatible with --async-aware", message)

View file

@ -780,3 +780,128 @@ def test_live_incompatible_with_pstats_default_values(self):
from profiling.sampling.cli import main
main()
self.assertNotEqual(cm.exception.code, 0)
@requires_subprocess()
@skip_if_not_supported
@unittest.skipIf(
    sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
    "Test only runs on Linux with process_vm_readv support",
)
class TestAsyncAwareProfilingIntegration(unittest.TestCase):
    """Integration tests for async-aware profiling mode."""

    @classmethod
    def setUpClass(cls):
        # Target program: an asyncio app with three sleeping leaf tasks and
        # one CPU-bound leaf task under a common supervisor, so the two
        # async-aware modes ("all" vs "running") produce observably
        # different sample distributions.
        cls.async_script = '''
import asyncio
async def sleeping_leaf():
    """Leaf task that just sleeps - visible in 'all' mode."""
    for _ in range(50):
        await asyncio.sleep(0.02)
async def cpu_leaf():
    """Leaf task that does CPU work - visible in both modes."""
    total = 0
    for _ in range(200):
        for i in range(10000):
            total += i * i
        await asyncio.sleep(0)
    return total
async def supervisor():
    """Middle layer that spawns leaf tasks."""
    tasks = [
        asyncio.create_task(sleeping_leaf(), name="Sleeper-0"),
        asyncio.create_task(sleeping_leaf(), name="Sleeper-1"),
        asyncio.create_task(sleeping_leaf(), name="Sleeper-2"),
        asyncio.create_task(cpu_leaf(), name="Worker"),
    ]
    await asyncio.gather(*tasks)
async def main():
    await supervisor()
if __name__ == "__main__":
    asyncio.run(main())
'''

    def _collect_async_samples(self, async_aware_mode):
        """Helper to collect samples and count function occurrences.

        Returns a dict mapping function names to their sample counts; the
        special key "_total" holds the overall number of samples.
        """
        with test_subprocess(self.async_script) as subproc:
            try:
                # skip_idle=False: keep samples taken while the event loop
                # is waiting, so sleeping tasks can show up in "all" mode.
                collector = CollapsedStackCollector(1000, skip_idle=False)
                profiling.sampling.sample.sample(
                    subproc.process.pid,
                    collector,
                    duration_sec=SHORT_TIMEOUT,
                    async_aware=async_aware_mode,
                )
            except PermissionError:
                self.skipTest("Insufficient permissions for remote profiling")
        # Count samples per function from collapsed stacks
        # stack_counter keys are (call_tree, thread_id) where call_tree
        # is a tuple of (file, line, func) tuples
        func_samples = {}
        total = 0
        for (call_tree, _thread_id), count in collector.stack_counter.items():
            total += count
            for _file, _line, func in call_tree:
                func_samples[func] = func_samples.get(func, 0) + count
        func_samples["_total"] = total
        return func_samples

    def test_async_aware_all_sees_sleeping_and_running_tasks(self):
        """Test that async_aware='all' captures both sleeping and CPU-running tasks.

        Task tree structure:
            main
                supervisor
                    Sleeper-0 (sleeping_leaf)
                    Sleeper-1 (sleeping_leaf)
                    Sleeper-2 (sleeping_leaf)
                    Worker (cpu_leaf)
        async_aware='all' should see ALL 4 leaf tasks in the output.
        """
        samples = self._collect_async_samples("all")
        self.assertGreater(samples["_total"], 0, "Should have collected samples")
        # Sleeping leaves are only reachable via the task tree, so seeing
        # them proves task-based reconstruction is active.
        self.assertIn("sleeping_leaf", samples)
        self.assertIn("cpu_leaf", samples)
        self.assertIn("supervisor", samples)

    def test_async_aware_running_sees_only_cpu_task(self):
        """Test that async_aware='running' only captures the actively running task.

        Task tree structure:
            main
                supervisor
                    Sleeper-0 (sleeping_leaf) - NOT visible in 'running'
                    Sleeper-1 (sleeping_leaf) - NOT visible in 'running'
                    Sleeper-2 (sleeping_leaf) - NOT visible in 'running'
                    Worker (cpu_leaf) - VISIBLE in 'running'
        async_aware='running' should only see the Worker task doing CPU work.
        """
        samples = self._collect_async_samples("running")
        total = samples["_total"]
        cpu_leaf_samples = samples.get("cpu_leaf", 0)
        self.assertGreater(total, 0, "Should have collected some samples")
        self.assertGreater(cpu_leaf_samples, 0, "cpu_leaf should appear in samples")
        # cpu_leaf should have at least 90% of samples (typically 99%+)
        # sleeping_leaf may occasionally appear with very few samples (< 1%)
        # when tasks briefly wake up to check sleep timers
        cpu_percentage = (cpu_leaf_samples / total) * 100
        self.assertGreater(cpu_percentage, 90.0,
            f"cpu_leaf should dominate samples in 'running' mode, "
            f"got {cpu_percentage:.1f}% ({cpu_leaf_samples}/{total})")

View file

@ -0,0 +1 @@
Add async-aware profiling to the Tachyon sampling profiler. The profiler now reconstructs and displays async task hierarchies in flamegraphs, making the output more actionable for users. Patch by Savannah Ostrowski and Pablo Galindo Salgado.

View file

@ -405,6 +405,7 @@ extern PyObject* unwind_stack_for_thread(
extern uintptr_t _Py_RemoteDebug_GetAsyncioDebugAddress(proc_handle_t* handle);
extern int read_async_debug(RemoteUnwinderObject *unwinder);
extern int ensure_async_debug_offsets(RemoteUnwinderObject *unwinder);
/* Task parsing */
extern PyObject *parse_task_name(RemoteUnwinderObject *unwinder, uintptr_t task_address);

View file

@ -71,6 +71,28 @@ read_async_debug(RemoteUnwinderObject *unwinder)
return result;
}
/* Make sure the asyncio debug offsets for the target process are loaded.
 *
 * Returns 0 on success (offsets available), -1 with a RuntimeError set when
 * the AsyncioDebug section cannot be read from the target process. */
int
ensure_async_debug_offsets(RemoteUnwinderObject *unwinder)
{
    // Fast path: offsets were already read successfully on a prior call.
    if (unwinder->async_debug_offsets_available) {
        return 0;
    }
    // Retry loading async debug offsets: the target process may have
    // imported asyncio since we last checked.
    if (read_async_debug(unwinder) < 0) {
        // Replace the low-level read error with a clearer, stable message;
        // the original cause is summarized via set_exception_cause below.
        PyErr_Clear();
        PyErr_SetString(PyExc_RuntimeError, "AsyncioDebug section not available");
        set_exception_cause(unwinder, PyExc_RuntimeError,
            "AsyncioDebug section unavailable - asyncio module may not be loaded in target process");
        return -1;
    }
    // Cache success so subsequent calls take the fast path.
    unwinder->async_debug_offsets_available = 1;
    return 0;
}
/* ============================================================================
* SET ITERATION FUNCTIONS
* ============================================================================ */

View file

@ -645,9 +645,7 @@ static PyObject *
_remote_debugging_RemoteUnwinder_get_all_awaited_by_impl(RemoteUnwinderObject *self)
/*[clinic end generated code: output=6a49cd345e8aec53 input=307f754cbe38250c]*/
{
if (!self->async_debug_offsets_available) {
PyErr_SetString(PyExc_RuntimeError, "AsyncioDebug section not available");
set_exception_cause(self, PyExc_RuntimeError, "AsyncioDebug section unavailable in get_all_awaited_by");
if (ensure_async_debug_offsets(self) < 0) {
return NULL;
}
@ -736,9 +734,7 @@ static PyObject *
_remote_debugging_RemoteUnwinder_get_async_stack_trace_impl(RemoteUnwinderObject *self)
/*[clinic end generated code: output=6433d52b55e87bbe input=6129b7d509a887c9]*/
{
if (!self->async_debug_offsets_available) {
PyErr_SetString(PyExc_RuntimeError, "AsyncioDebug section not available");
set_exception_cause(self, PyExc_RuntimeError, "AsyncioDebug section unavailable in get_async_stack_trace");
if (ensure_async_debug_offsets(self) < 0) {
return NULL;
}