Refactor run_pipeline() to use multiplexed I/O

Add _communicate_streams() helper function that properly multiplexes
read/write operations to prevent pipe buffer deadlocks. The helper
uses selectors on POSIX and threads on Windows, similar to
Popen.communicate().

This fixes potential deadlocks when large amounts of data flow through
the pipeline and significantly improves performance.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Gregory P. Smith using claude.ai/code 2025-11-28 01:12:12 +00:00
parent 4feb2a80e5
commit 2a11d4bf53
No known key found for this signature in database

View file

@ -320,6 +320,220 @@ def _cleanup():
DEVNULL = -3
# Helper function for multiplexed I/O, used by run_pipeline()
def _remaining_time_helper(endtime):
"""Calculate remaining time until deadline."""
if endtime is None:
return None
return endtime - _time()
def _communicate_streams(stdin=None, input_data=None, read_streams=None,
                         timeout=None, cmd_for_timeout=None):
    """Multiplex I/O: write *input_data* to *stdin*, read *read_streams*.

    Works with both file objects and raw file descriptors.  All I/O is
    done in binary mode; the caller handles text encoding.

    Args:
        stdin: Writable file object for input, or None.
        input_data: Bytes to write to stdin, or None.
        read_streams: List of readable file objects or raw fds to read from.
        timeout: Timeout in seconds, or None for no timeout.
        cmd_for_timeout: Value to use for TimeoutExpired.cmd.

    Returns:
        Dict mapping each item in read_streams to its bytes data.

    Raises:
        TimeoutExpired: If timeout expires (with partial data).
    """
    # Convert the relative timeout into an absolute deadline once, so the
    # platform implementations can re-derive "time left" repeatedly.
    endtime = None if timeout is None else _time() + timeout
    streams = read_streams or []
    # Select the platform-specific worker: threads on Windows, selectors
    # on POSIX (mirroring Popen.communicate()).
    if _mswindows:
        impl = _communicate_streams_windows
    else:
        impl = _communicate_streams_posix
    return impl(stdin, input_data, streams, endtime, timeout, cmd_for_timeout)
if _mswindows:
    def _reader_thread_func(fh, buffer):
        """Read all of *fh* into *buffer* (a one-element list).

        Runs in a daemon thread.  Appends b'' on OSError so the joining
        code always finds an entry in the buffer.
        """
        try:
            buffer.append(fh.read())
        except OSError:
            buffer.append(b'')

    def _communicate_streams_windows(stdin, input_data, read_streams,
                                     endtime, orig_timeout, cmd_for_timeout):
        """Windows implementation of _communicate_streams() using threads.

        One daemon thread drains each read stream while this thread writes
        *input_data* to *stdin*; the reader threads are then joined against
        the shared deadline.

        Raw fds are dup()ed and wrapped in a file object so fh.read()
        works.  Only our dup is closed afterwards -- the original fd stays
        owned by the caller, matching the POSIX implementation (and
        run_pipeline()'s own cleanup, which closes its raw fds itself).
        """
        threads = []
        buffers = {}
        # Start one reader thread per stream.
        for stream in read_streams:
            buf = []
            buffers[stream] = buf
            if isinstance(stream, int):
                # Wrap a dup of the raw fd; the wrapper (and the dup) is
                # ours to close once the reader finishes.
                fobj = os.fdopen(os.dup(stream), 'rb')
            else:
                fobj = stream
            t = threading.Thread(target=_reader_thread_func, args=(fobj, buf))
            t.daemon = True
            t.start()
            threads.append((stream, t, fobj))
        # Write input, tolerating a consumer that exited early.  EINVAL is
        # how Windows reports a broken pipe on some code paths.
        if stdin and input_data:
            try:
                stdin.write(input_data)
            except BrokenPipeError:
                pass
            except OSError as exc:
                if exc.errno != errno.EINVAL:
                    raise
        if stdin:
            try:
                stdin.close()
            except BrokenPipeError:
                pass
            except OSError as exc:
                if exc.errno != errno.EINVAL:
                    raise
        # Join reader threads against the deadline.
        for stream, t, fobj in threads:
            remaining = _remaining_time_helper(endtime)
            if remaining is not None and remaining < 0:
                remaining = 0
            t.join(remaining)
            if t.is_alive():
                # Deadline hit: report whatever has been read so far.  The
                # still-running daemon threads keep their file objects; we
                # cannot close those safely while a read is in progress.
                results = {s: (b[0] if b else b'') for s, b in buffers.items()}
                raise TimeoutExpired(
                    cmd_for_timeout, orig_timeout,
                    output=results.get(read_streams[0]) if read_streams else None)
            # The reader hit EOF; release the file object.  For raw-fd
            # streams this closes only our dup, never the caller's fd.
            try:
                fobj.close()
            except OSError:
                pass
        # Collect results keyed by the caller's original stream objects.
        return {stream: (buf[0] if buf else b'')
                for stream, buf in buffers.items()}
else:
    def _communicate_streams_posix(stdin, input_data, read_streams,
                                   endtime, orig_timeout, cmd_for_timeout):
        """POSIX implementation of _communicate_streams() using selectors.

        A single event loop interleaves non-blocking-style writes of
        *input_data* to *stdin* with reads from every entry in
        *read_streams*, so neither side can deadlock on a full pipe
        buffer.  Raises TimeoutExpired (with partial output) once
        *endtime* passes; raw fds in *read_streams* are left open for the
        caller, while file objects are closed on completion.
        """
        # Normalize read_streams: build mapping of fd -> (original_key, chunks)
        # so results can be keyed by whatever object the caller passed in.
        fd_info = {}  # fd -> (original_stream, chunks_list)
        for stream in read_streams:
            if isinstance(stream, int):
                fd = stream
            else:
                fd = stream.fileno()
            fd_info[fd] = (stream, [])
        # Prepare stdin.  Flush any buffered data first; if there is no
        # input to send, close stdin immediately so the child sees EOF.
        stdin_fd = None
        if stdin:
            try:
                stdin.flush()
            except BrokenPipeError:
                pass  # child already exited; treated as benign
            if input_data:
                stdin_fd = stdin.fileno()
            else:
                try:
                    stdin.close()
                except BrokenPipeError:
                    pass
        # Prepare input data.  A memoryview makes the per-chunk slicing
        # below zero-copy.
        input_offset = 0
        input_view = memoryview(input_data) if input_data else None
        with _PopenSelector() as selector:
            if stdin_fd is not None and input_data:
                selector.register(stdin_fd, selectors.EVENT_WRITE)
            for fd in fd_info:
                selector.register(fd, selectors.EVENT_READ)
            # Loop until every fd has been unregistered (input fully
            # written, all outputs at EOF).
            while selector.get_map():
                remaining = _remaining_time_helper(endtime)
                if remaining is not None and remaining < 0:
                    # Timed out - collect partial results
                    results = {orig: b''.join(chunks)
                               for fd, (orig, chunks) in fd_info.items()}
                    raise TimeoutExpired(
                        cmd_for_timeout, orig_timeout,
                        output=results.get(read_streams[0]) if read_streams else None)
                ready = selector.select(remaining)
                # Check timeout after select as well: select() may have
                # returned ready fds exactly at the deadline.
                if endtime is not None and _time() > endtime:
                    results = {orig: b''.join(chunks)
                               for fd, (orig, chunks) in fd_info.items()}
                    raise TimeoutExpired(
                        cmd_for_timeout, orig_timeout,
                        output=results.get(read_streams[0]) if read_streams else None)
                for key, events in ready:
                    if key.fd == stdin_fd:
                        # Write at most _PIPE_BUF bytes so the write cannot
                        # block even on a nearly-full pipe.
                        chunk = input_view[input_offset:input_offset + _PIPE_BUF]
                        try:
                            input_offset += os.write(key.fd, chunk)
                        except BrokenPipeError:
                            # Reader went away; stop sending input.
                            selector.unregister(key.fd)
                            try:
                                stdin.close()
                            except BrokenPipeError:
                                pass
                        else:
                            if input_offset >= len(input_data):
                                # All input delivered; close to signal EOF.
                                selector.unregister(key.fd)
                                try:
                                    stdin.close()
                                except BrokenPipeError:
                                    pass
                    elif key.fd in fd_info:
                        # Read a chunk from an output stream; empty read
                        # means EOF for that fd.
                        data = os.read(key.fd, 32768)
                        if not data:
                            selector.unregister(key.fd)
                        else:
                            fd_info[key.fd][1].append(data)
        # Build results: map original stream keys to joined data
        results = {}
        for fd, (orig_stream, chunks) in fd_info.items():
            results[orig_stream] = b''.join(chunks)
            # Close file objects (but not raw fds - caller manages those)
            if not isinstance(orig_stream, int):
                try:
                    orig_stream.close()
                except OSError:
                    pass
        return results
# XXX This function is only used by multiprocessing and the test suite,
# but it's here so that it can be imported when Python is compiled without
# threads.
@ -781,54 +995,70 @@ def run_pipeline(*commands, input=None, capture_output=False, timeout=None,
first_proc = processes[0]
last_proc = processes[-1]
# Handle communication with timeout
start_time = _time() if timeout is not None else None
# Write input to first process if provided
if input is not None and first_proc.stdin is not None:
try:
first_proc.stdin.write(input)
except BrokenPipeError:
pass # First process may have exited early
finally:
first_proc.stdin.close()
# Calculate deadline for timeout (used throughout)
if timeout is not None:
endtime = _time() + timeout
else:
endtime = None
# Determine if we're in text mode
text_mode = kwargs.get('text') or kwargs.get('encoding') or kwargs.get('errors')
encoding = kwargs.get('encoding')
errors_param = kwargs.get('errors', 'strict')
if text_mode and encoding is None:
encoding = locale.getencoding()
# Read output from the last process
stdout = None
stderr = None
# Encode input if in text mode
input_data = input
if input_data is not None and text_mode:
input_data = input_data.encode(encoding, errors_param)
# Read stdout if we created a pipe for it (capture_output or stdout=PIPE)
# Build list of streams to read from
read_streams = []
if last_proc.stdout is not None:
stdout = last_proc.stdout.read()
# Read stderr from the shared pipe
read_streams.append(last_proc.stdout)
if stderr_read_fd is not None:
stderr = os.read(stderr_read_fd, 1024 * 1024 * 10) # Up to 10MB
# Keep reading until EOF
while True:
chunk = os.read(stderr_read_fd, 65536)
if not chunk:
break
stderr += chunk
read_streams.append(stderr_read_fd)
# Calculate remaining timeout
def remaining_timeout():
if timeout is None:
return None
elapsed = _time() - start_time
remaining = timeout - elapsed
if remaining <= 0:
raise TimeoutExpired(commands, timeout, stdout, stderr)
return remaining
# Use multiplexed I/O to handle stdin/stdout/stderr concurrently
# This avoids deadlocks from pipe buffer limits
stdin_stream = first_proc.stdin if input is not None else None
# Wait for all processes to complete
try:
results = _communicate_streams(
stdin=stdin_stream,
input_data=input_data,
read_streams=read_streams,
timeout=_remaining_time_helper(endtime),
cmd_for_timeout=commands,
)
except TimeoutExpired:
# Kill all processes on timeout
for p in processes:
if p.poll() is None:
p.kill()
for p in processes:
p.wait()
raise
# Extract results
stdout = results.get(last_proc.stdout)
stderr = results.get(stderr_read_fd)
# Decode stdout if in text mode (Popen text mode only applies to
# streams it creates, but we read via _communicate_streams which
# always returns bytes)
if text_mode and stdout is not None:
stdout = stdout.decode(encoding, errors_param)
if text_mode and stderr is not None:
stderr = stderr.decode(encoding, errors_param)
# Wait for all processes to complete (use remaining time from deadline)
returncodes = []
for proc in processes:
try:
proc.wait(timeout=remaining_timeout())
remaining = _remaining_time_helper(endtime)
proc.wait(timeout=remaining)
except TimeoutExpired:
# Kill all processes on timeout
for p in processes:
@ -839,16 +1069,6 @@ def remaining_timeout():
raise TimeoutExpired(commands, timeout, stdout, stderr)
returncodes.append(proc.returncode)
# Handle text mode conversion for stderr (stdout is already handled
# by Popen when text=True). stderr is always read as bytes since
# we use os.pipe() directly.
if text_mode and stderr is not None:
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', 'strict')
if encoding is None:
encoding = locale.getencoding()
stderr = stderr.decode(encoding, errors)
result = PipelineResult(commands, returncodes, stdout, stderr)
if check and any(rc != 0 for rc in returncodes):
@ -867,7 +1087,7 @@ def remaining_timeout():
proc.stdin.close()
if proc.stdout and not proc.stdout.closed:
proc.stdout.close()
# Close stderr pipe file descriptors
# Close stderr pipe file descriptor
if stderr_read_fd is not None:
try:
os.close(stderr_read_fd)