"""Tests for sampling profiler core functionality.""" import io from unittest import mock import unittest try: import _remote_debugging # noqa: F401 from profiling.sampling.sample import SampleProfiler from profiling.sampling.pstats_collector import PstatsCollector except ImportError: raise unittest.SkipTest( "Test only runs when _remote_debugging is available" ) from test.support import force_not_colorized_test_class def print_sampled_stats(stats, sort=-1, limit=None, show_summary=True, sample_interval_usec=100): """Helper function to maintain compatibility with old test API. This wraps the new PstatsCollector.print_stats() API to work with the existing test infrastructure. """ # Create a mock collector that populates stats correctly collector = PstatsCollector(sample_interval_usec=sample_interval_usec) # Override create_stats to populate self.stats with the provided stats def mock_create_stats(): collector.stats = stats.stats collector.create_stats = mock_create_stats # Call the new print_stats method collector.print_stats(sort=sort, limit=limit, show_summary=show_summary) class TestSampleProfiler(unittest.TestCase): """Test the SampleProfiler class.""" def test_sample_profiler_initialization(self): """Test SampleProfiler initialization with various parameters.""" # Mock RemoteUnwinder to avoid permission issues with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock.MagicMock() # Test basic initialization profiler = SampleProfiler( pid=12345, sample_interval_usec=1000, all_threads=False ) self.assertEqual(profiler.pid, 12345) self.assertEqual(profiler.sample_interval_usec, 1000) self.assertEqual(profiler.all_threads, False) # Test with all_threads=True profiler = SampleProfiler( pid=54321, sample_interval_usec=5000, all_threads=True ) self.assertEqual(profiler.pid, 54321) self.assertEqual(profiler.sample_interval_usec, 5000) self.assertEqual(profiler.all_threads, True) def test_sample_profiler_sample_method_timing(self): """Test that the sample method respects duration and handles timing correctly.""" # Mock the unwinder to avoid needing a real process mock_unwinder = mock.MagicMock() mock_unwinder.get_stack_trace.return_value = [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ] with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder profiler = SampleProfiler( pid=12345, sample_interval_usec=100000, all_threads=False ) # 100ms interval # Mock collector mock_collector = mock.MagicMock() # Mock time to control the sampling loop start_time = 1000.0 times = [ start_time + i * 0.1 for i in range(12) ] # 0, 0.1, 0.2, ..., 1.1 seconds with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): profiler.sample(mock_collector, duration_sec=1) result = output.getvalue() # Should have captured approximately 10 samples (1 second / 0.1 second interval) self.assertIn("Captured", result) self.assertIn("samples", result) # Verify collector was called multiple times self.assertGreaterEqual(mock_collector.collect.call_count, 5) self.assertLessEqual(mock_collector.collect.call_count, 11) def test_sample_profiler_error_handling(self): """Test that the sample method handles errors gracefully.""" # Mock unwinder that raises errors mock_unwinder = mock.MagicMock() error_sequence = [ RuntimeError("Process died"), [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ], UnicodeDecodeError("utf-8", b"", 0, 1, "invalid"), [ ( 1, [ mock.MagicMock( filename="test.py", lineno=20, funcname="test_func2", ) ], ) ], OSError("Permission denied"), ] mock_unwinder.get_stack_trace.side_effect = error_sequence with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder profiler = SampleProfiler( pid=12345, sample_interval_usec=10000, all_threads=False ) mock_collector = mock.MagicMock() # Control timing to run exactly 5 samples times = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06] with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): profiler.sample(mock_collector, duration_sec=0.05) result = output.getvalue() # Should report error rate self.assertIn("Error rate:", result) self.assertIn("%", result) # Collector should have been called only for successful samples (should be > 0) self.assertGreater(mock_collector.collect.call_count, 0) self.assertLessEqual(mock_collector.collect.call_count, 3) def test_sample_profiler_missed_samples_warning(self): """Test that the profiler warns about missed samples when sampling is too slow.""" mock_unwinder = mock.MagicMock() mock_unwinder.get_stack_trace.return_value = [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ] with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder # Use very short interval that we'll miss profiler = SampleProfiler( pid=12345, sample_interval_usec=1000, all_threads=False ) # 1ms interval mock_collector = mock.MagicMock() # Simulate slow sampling where we miss many samples times = [ 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, ] # Extra time points to avoid StopIteration with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): profiler.sample(mock_collector, duration_sec=0.5) result = output.getvalue() # Should warn about missed samples self.assertIn("Warning: missed", result) self.assertIn("samples from the expected total", result) def test_sample_profiler_keyboard_interrupt(self): mock_unwinder = mock.MagicMock() mock_unwinder.get_stack_trace.side_effect = [ [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ], KeyboardInterrupt(), ] with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder profiler = SampleProfiler( pid=12345, sample_interval_usec=10000, all_threads=False ) mock_collector = mock.MagicMock() times = [0.0, 0.01, 0.02, 0.03, 0.04] with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): try: profiler.sample(mock_collector, duration_sec=1.0) except KeyboardInterrupt: self.fail( "KeyboardInterrupt was not handled by the profiler" ) result = output.getvalue() self.assertIn("Interrupted by user.", result) self.assertIn("Captured", result) self.assertIn("samples", result) self.assertNotIn("Warning: missed", result) @force_not_colorized_test_class class TestPrintSampledStats(unittest.TestCase): """Test the print_sampled_stats function.""" def setUp(self): """Set up test data.""" # Mock stats data self.mock_stats = mock.MagicMock() self.mock_stats.stats = { ("file1.py", 10, "func1"): ( 100, 100, 0.5, 0.5, {}, ), # cc, nc, tt, ct, callers ("file2.py", 20, "func2"): (50, 50, 0.25, 0.3, {}), ("file3.py", 30, "func3"): (200, 200, 1.5, 2.0, {}), ("file4.py", 40, "func4"): ( 10, 10, 0.001, 0.001, {}, ), # millisecond range ("file5.py", 50, "func5"): ( 5, 5, 0.000001, 0.000002, {}, ), # microsecond range } def test_print_sampled_stats_basic(self): """Test basic print_sampled_stats functionality.""" # Capture output with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(self.mock_stats, sample_interval_usec=100) result = output.getvalue() # Check header is present self.assertIn("Profile Stats:", result) self.assertIn("nsamples", result) self.assertIn("tottime", result) self.assertIn("cumtime", result) # Check functions are present self.assertIn("func1", result) self.assertIn("func2", result) self.assertIn("func3", result) def test_print_sampled_stats_sorting(self): """Test different sorting options.""" # Test sort by calls with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=0, sample_interval_usec=100 ) result = output.getvalue() lines = result.strip().split("\n") # Find the data lines (skip header) data_lines = [l for l in lines if "file" in l and ".py" in l] # func3 should be first (200 calls) self.assertIn("func3", data_lines[0]) # Test sort by time with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=1, sample_interval_usec=100 ) result = output.getvalue() lines = result.strip().split("\n") data_lines = [l for l in lines if "file" in l and ".py" in l] # func3 should be first (1.5s time) self.assertIn("func3", data_lines[0]) def test_print_sampled_stats_limit(self): """Test limiting output rows.""" with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, limit=2, sample_interval_usec=100 ) result = output.getvalue() # Count function entries in the main stats section (not in summary) lines = result.split("\n") # Find where the main stats section ends (before summary) main_section_lines = [] for line in lines: if "Summary of Interesting Functions:" in line: break main_section_lines.append(line) # Count function entries only in main section func_count = sum( 1 for line in main_section_lines if "func" in line and ".py" in line ) self.assertEqual(func_count, 2) def test_print_sampled_stats_time_units(self): """Test proper time unit selection.""" with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(self.mock_stats, sample_interval_usec=100) result = output.getvalue() # Should use seconds for the header since max time is > 1s self.assertIn("tottime (s)", result) self.assertIn("cumtime (s)", result) # Test with only microsecond-range times micro_stats = mock.MagicMock() micro_stats.stats = { ("file1.py", 10, "func1"): (100, 100, 0.000005, 0.000010, {}), } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(micro_stats, sample_interval_usec=100) result = output.getvalue() # Should use microseconds self.assertIn("tottime (μs)", result) self.assertIn("cumtime (μs)", result) def test_print_sampled_stats_summary(self): """Test summary section generation.""" with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, show_summary=True, sample_interval_usec=100, ) result = output.getvalue() # Check summary sections are present self.assertIn("Summary of Interesting Functions:", result) self.assertIn( "Functions with Highest Direct/Cumulative Ratio (Hot Spots):", result, ) self.assertIn( "Functions with Highest Call Frequency (Indirect Calls):", result ) self.assertIn( "Functions with Highest Call Magnification (Cumulative/Direct):", result, ) def test_print_sampled_stats_no_summary(self): """Test disabling summary output.""" with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, show_summary=False, sample_interval_usec=100, ) result = output.getvalue() # Summary should not be present self.assertNotIn("Summary of Interesting Functions:", result) def test_print_sampled_stats_empty_stats(self): """Test with empty stats.""" empty_stats = mock.MagicMock() empty_stats.stats = {} with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(empty_stats, sample_interval_usec=100) result = output.getvalue() # Should print message about no samples self.assertIn("No samples were collected.", result) def test_print_sampled_stats_sample_percentage_sorting(self): """Test sample percentage sorting options.""" # Add a function with high sample percentage (more direct calls than func3's 200) self.mock_stats.stats[("expensive.py", 60, "expensive_func")] = ( 300, # direct calls (higher than func3's 200) 300, # cumulative calls 1.0, # total time 1.0, # cumulative time {}, ) # Test sort by sample percentage with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=3, sample_interval_usec=100 ) # sample percentage result = output.getvalue() lines = result.strip().split("\n") data_lines = [l for l in lines if ".py" in l and "func" in l] # expensive_func should be first (highest sample percentage) self.assertIn("expensive_func", data_lines[0]) def test_print_sampled_stats_with_recursive_calls(self): """Test print_sampled_stats with recursive calls where nc != cc.""" # Create stats with recursive calls (nc != cc) recursive_stats = mock.MagicMock() recursive_stats.stats = { # (direct_calls, cumulative_calls, tt, ct, callers) - recursive function ("recursive.py", 10, "factorial"): ( 5, # direct_calls 10, # cumulative_calls (appears more times in stack due to recursion) 0.5, 0.6, {}, ), ("normal.py", 20, "normal_func"): ( 3, # direct_calls 3, # cumulative_calls (same as direct for non-recursive) 0.2, 0.2, {}, ), } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(recursive_stats, sample_interval_usec=100) result = output.getvalue() # Should display recursive calls as "5/10" format self.assertIn("5/10", result) # nc/cc format for recursive calls self.assertIn("3", result) # just nc for non-recursive calls self.assertIn("factorial", result) self.assertIn("normal_func", result) def test_print_sampled_stats_with_zero_call_counts(self): """Test print_sampled_stats with zero call counts to trigger division protection.""" # Create stats with zero call counts zero_stats = mock.MagicMock() zero_stats.stats = { ("file.py", 10, "zero_calls"): (0, 0, 0.0, 0.0, {}), # Zero calls ("file.py", 20, "normal_func"): ( 5, 5, 0.1, 0.1, {}, ), # Normal function } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(zero_stats, sample_interval_usec=100) result = output.getvalue() # Should handle zero call counts gracefully self.assertIn("zero_calls", result) self.assertIn("zero_calls", result) self.assertIn("normal_func", result) def test_print_sampled_stats_sort_by_name(self): """Test sort by function name option.""" with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=-1, sample_interval_usec=100 ) # sort by name result = output.getvalue() lines = result.strip().split("\n") # Find the data lines (skip header and summary) # Data lines start with whitespace and numbers, and contain filename:lineno(function) data_lines = [] for line in lines: # Skip header lines and summary sections if ( line.startswith(" ") and "(" in line and ")" in line and not line.startswith( " 1." ) # Skip summary lines that start with times and not line.startswith( " 0." ) # Skip summary lines that start with times and not "per call" in line # Skip summary lines and not "calls" in line # Skip summary lines and not "total time" in line # Skip summary lines and not "cumulative time" in line ): # Skip summary lines data_lines.append(line) # Extract just the function names for comparison func_names = [] import re for line in data_lines: # Function name is between the last ( and ), accounting for ANSI color codes match = re.search(r"\(([^)]+)\)$", line) if match: func_name = match.group(1) # Remove ANSI color codes func_name = re.sub(r"\x1b\[[0-9;]*m", "", func_name) func_names.append(func_name) # Verify we extracted function names and they are sorted self.assertGreater( len(func_names), 0, "Should have extracted some function names" ) self.assertEqual( func_names, sorted(func_names), f"Function names {func_names} should be sorted alphabetically", ) def test_print_sampled_stats_with_zero_time_functions(self): """Test summary sections with functions that have zero time.""" # Create stats with zero-time functions zero_time_stats = mock.MagicMock() zero_time_stats.stats = { ("file1.py", 10, "zero_time_func"): ( 5, 5, 0.0, 0.0, {}, ), # Zero time ("file2.py", 20, "normal_func"): ( 3, 3, 0.1, 0.1, {}, ), # Normal time } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( zero_time_stats, show_summary=True, sample_interval_usec=100, ) result = output.getvalue() # Should handle zero-time functions gracefully in summary self.assertIn("Summary of Interesting Functions:", result) self.assertIn("zero_time_func", result) self.assertIn("normal_func", result) def test_print_sampled_stats_with_malformed_qualified_names(self): """Test summary generation with function names that don't contain colons.""" # Create stats with function names that would create malformed qualified names malformed_stats = mock.MagicMock() malformed_stats.stats = { # Function name without clear module separation ("no_colon_func", 10, "func"): (3, 3, 0.1, 0.1, {}), ("", 20, "empty_filename_func"): (2, 2, 0.05, 0.05, {}), ("normal.py", 30, "normal_func"): (5, 5, 0.2, 0.2, {}), } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( malformed_stats, show_summary=True, sample_interval_usec=100, ) result = output.getvalue() # Should handle malformed names gracefully in summary aggregation self.assertIn("Summary of Interesting Functions:", result) # All function names should appear somewhere in the output self.assertIn("func", result) self.assertIn("empty_filename_func", result) self.assertIn("normal_func", result) def test_print_sampled_stats_with_recursive_call_stats_creation(self): """Test create_stats with recursive call data to trigger total_rec_calls branch.""" collector = PstatsCollector(sample_interval_usec=1000000) # 1 second # Simulate recursive function data where total_rec_calls would be set # We need to manually manipulate the collector result to test this branch collector.result = { ("recursive.py", 10, "factorial"): { "total_rec_calls": 3, # Non-zero recursive calls "direct_calls": 5, "cumulative_calls": 10, }, ("normal.py", 20, "normal_func"): { "total_rec_calls": 0, # Zero recursive calls "direct_calls": 2, "cumulative_calls": 5, }, } collector.create_stats() # Check that recursive calls are handled differently from non-recursive factorial_stats = collector.stats[("recursive.py", 10, "factorial")] normal_stats = collector.stats[("normal.py", 20, "normal_func")] # factorial should use cumulative_calls (10) as nc self.assertEqual( factorial_stats[1], 10 ) # nc should be cumulative_calls self.assertEqual(factorial_stats[0], 5) # cc should be direct_calls # normal_func should use cumulative_calls as nc self.assertEqual(normal_stats[1], 5) # nc should be cumulative_calls self.assertEqual(normal_stats[0], 2) # cc should be direct_calls