net/http: add connections back that haven't been canceled

Issue #41600 fixed the issue when a second request canceled a connection
while the first request was still in roundTrip.
This uncovered a second issue where a request was being canceled (in
roundtrip) but the connection was put back into the idle pool for a
subsequent request.
The fix is the similar except its now in readLoop instead of roundTrip.
A persistent connection is only added back if it successfully removed
the cancel function; otherwise we know the roundTrip has started
cancelRequest.

Fixes #42942

Change-Id: Ia56add20880ccd0c1ab812d380d8628e45f6f44c
Reviewed-on: https://go-review.googlesource.com/c/go/+/274973
Trust: Dmitri Shuralyov <dmitshur@golang.org>
Trust: Damien Neil <dneil@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Michael Fraenkel 2020-12-02 17:07:27 -07:00 committed by Dmitri Shuralyov
parent 6fa06d960b
commit 854a2f8e01

View file

@ -784,7 +784,8 @@ func (t *Transport) CancelRequest(req *Request) {
} }
// Cancel an in-flight request, recording the error value. // Cancel an in-flight request, recording the error value.
func (t *Transport) cancelRequest(key cancelKey, err error) { // Returns whether the request was canceled.
func (t *Transport) cancelRequest(key cancelKey, err error) bool {
t.reqMu.Lock() t.reqMu.Lock()
cancel := t.reqCanceler[key] cancel := t.reqCanceler[key]
delete(t.reqCanceler, key) delete(t.reqCanceler, key)
@ -792,6 +793,8 @@ func (t *Transport) cancelRequest(key cancelKey, err error) {
if cancel != nil { if cancel != nil {
cancel(err) cancel(err)
} }
return cancel != nil
} }
// //
@ -2127,18 +2130,17 @@ func (pc *persistConn) readLoop() {
} }
if !hasBody || bodyWritable { if !hasBody || bodyWritable {
pc.t.setReqCanceler(rc.cancelKey, nil) replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
// Put the idle conn back into the pool before we send the response // Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll // so if they process it quickly and make another request, they'll
// get this same conn. But we use the unbuffered channel 'rc' // get this same conn. But we use the unbuffered channel 'rc'
// to guarantee that persistConn.roundTrip got out of its select // to guarantee that persistConn.roundTrip got out of its select
// potentially waiting for this persistConn to close. // potentially waiting for this persistConn to close.
// but after
alive = alive && alive = alive &&
!pc.sawEOF && !pc.sawEOF &&
pc.wroteRequest() && pc.wroteRequest() &&
tryPutIdleConn(trace) replaced && tryPutIdleConn(trace)
if bodyWritable { if bodyWritable {
closeErr = errCallerOwnsConn closeErr = errCallerOwnsConn
@ -2200,12 +2202,12 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death) // reading the response body. (or for cancellation or death)
select { select {
case bodyEOF := <-waitForBodyRead: case bodyEOF := <-waitForBodyRead:
pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive && alive = alive &&
bodyEOF && bodyEOF &&
!pc.sawEOF && !pc.sawEOF &&
pc.wroteRequest() && pc.wroteRequest() &&
tryPutIdleConn(trace) replaced && tryPutIdleConn(trace)
if bodyEOF { if bodyEOF {
eofc <- struct{}{} eofc <- struct{}{}
} }
@ -2600,6 +2602,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
cancelChan := req.Request.Cancel cancelChan := req.Request.Cancel
ctxDoneChan := req.Context().Done() ctxDoneChan := req.Context().Done()
pcClosed := pc.closech pcClosed := pc.closech
canceled := false
for { for {
testHookWaitResLoop() testHookWaitResLoop()
select { select {
@ -2621,8 +2624,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
} }
case <-pcClosed: case <-pcClosed:
pcClosed = nil pcClosed = nil
// check if we are still using the connection if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
if pc.t.replaceReqCanceler(req.cancelKey, nil) {
if debugRoundTrip { if debugRoundTrip {
req.logf("closech recv: %T %#v", pc.closed, pc.closed) req.logf("closech recv: %T %#v", pc.closed, pc.closed)
} }
@ -2646,10 +2648,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
} }
return re.res, nil return re.res, nil
case <-cancelChan: case <-cancelChan:
pc.t.cancelRequest(req.cancelKey, errRequestCanceled) canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil cancelChan = nil
case <-ctxDoneChan: case <-ctxDoneChan:
pc.t.cancelRequest(req.cancelKey, req.Context().Err()) canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil cancelChan = nil
ctxDoneChan = nil ctxDoneChan = nil
} }