diff --git a/src/net/http/client.go b/src/net/http/client.go index dd099bb3165..faac5d4e2e7 100644 --- a/src/net/http/client.go +++ b/src/net/http/client.go @@ -20,7 +20,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "time" ) @@ -66,10 +65,15 @@ type Client struct { // // A Timeout of zero means no timeout. // - // The Client's Transport must support the CancelRequest - // method or Client will return errors when attempting to make - // a request with Get, Head, Post, or Do. Client's default - // Transport (DefaultTransport) supports CancelRequest. + // The Client cancels requests to the underlying Transport + // using the Request.Cancel mechanism. Requests passed + // to Client.Do may still set Request.Cancel; both will + // cancel the request. + // + // For compatibility, the Client will also use the deprecated + // CancelRequest method on Transport if found. New + // RoundTripper implementations should use Request.Cancel + // instead of implementing CancelRequest. Timeout time.Duration } @@ -142,13 +146,13 @@ type readClose struct { io.Closer } -func (c *Client) send(req *Request) (*Response, error) { +func (c *Client) send(req *Request, deadline time.Time) (*Response, error) { if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) } } - resp, err := send(req, c.transport()) + resp, err := send(req, c.transport(), deadline) if err != nil { return nil, err } @@ -180,13 +184,20 @@ func (c *Client) send(req *Request) (*Response, error) { // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { method := valueOrDefault(req.Method, "GET") - if method == "" || method == "GET" || method == "HEAD" { + if method == "GET" || method == "HEAD" { return c.doFollowingRedirects(req, shouldRedirectGet) } if method == "POST" || method == "PUT" { return c.doFollowingRedirects(req, shouldRedirectPost) } - return c.send(req) + return c.send(req, c.deadline()) +} + +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} } func (c *Client) transport() RoundTripper { @@ -198,8 +209,10 @@ func (c *Client) transport() RoundTripper { // send issues an HTTP request. // Caller should close resp.Body when done reading from it. -func send(req *Request, t RoundTripper) (resp *Response, err error) { - if t == nil { +func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) { + req := ireq // req is either the original request, or a modified fork + + if rt == nil { req.closeBody() return nil, errors.New("http: no Client.Transport or DefaultTransport") } @@ -214,20 +227,39 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { return nil, errors.New("http: Request.RequestURI can't be set in client requests.") } + // forkReq forks req into a shallow clone of ireq the first + // time it's called. + forkReq := func() { + if ireq == req { + req = new(Request) + *req = *ireq // shallow clone + } + } + // Most the callers of send (Get, Post, et al) don't need // Headers, leaving it uninitialized. We guarantee to the // Transport that this has been initialized, though. if req.Header == nil { + forkReq() req.Header = make(Header) } if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { username := u.Username() password, _ := u.Password() + forkReq() + req.Header = cloneHeader(ireq.Header) req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } - resp, err = t.RoundTrip(req) + + if !deadline.IsZero() { + forkReq() + } + stopTimer, wasCanceled := setRequestCancel(req, rt, deadline) + + resp, err := rt.RoundTrip(req) if err != nil { + stopTimer() if resp != nil { log.Printf("RoundTripper returned a response & error; ignoring response") } @@ -241,9 +273,76 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { } return nil, err } + if !deadline.IsZero() { + resp.Body = &cancelTimerBody{ + stop: stopTimer, + rc: resp.Body, + reqWasCanceled: wasCanceled, + } + } return resp, nil } +// setRequestCancel sets the Cancel field of req, if deadline is +// non-zero. The RoundTripper's type is used to determine whether the legacy +// CancelRequest behavior should be used. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) { + if deadline.IsZero() { + return nop, alwaysFalse + } + + initialReqCancel := req.Cancel // the user's original Request.Cancel, if any + + cancel := make(chan struct{}) + req.Cancel = cancel + + wasCanceled = func() bool { + select { + case <-cancel: + return true + default: + return false + } + } + + doCancel := func() { + // The new way: + close(cancel) + + // The legacy compatibility way, used only + // for RoundTripper implementations written + // before Go 1.5 or Go 1.6. + type canceler interface { + CancelRequest(*Request) + } + switch v := rt.(type) { + case *Transport, *http2Transport: + // Do nothing. The net/http package's transports + // support the new Request.Cancel channel + case canceler: + v.CancelRequest(req) + } + } + + stopTimerCh := make(chan struct{}) + var once sync.Once + stopTimer = func() { once.Do(func() { close(stopTimerCh) }) } + + timer := time.NewTimer(deadline.Sub(time.Now())) + go func() { + select { + case <-initialReqCancel: + doCancel() + case <-timer.C: + doCancel() + case <-stopTimerCh: + timer.Stop() + } + }() + + return stopTimer, wasCanceled +} + // See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt // "To receive authorization, the client sends the userid and password, // separated by a single colon (":") character, within a base64 @@ -338,28 +437,8 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo return nil, errors.New("http: nil Request.URL") } - var reqmu sync.Mutex // guards req req := ireq - - var timer *time.Timer - var atomicWasCanceled int32 // atomic bool (1 or 0) - var wasCanceled = alwaysFalse - if c.Timeout > 0 { - wasCanceled = func() bool { return atomic.LoadInt32(&atomicWasCanceled) != 0 } - type canceler interface { - CancelRequest(*Request) - } - tr, ok := c.transport().(canceler) - if !ok { - return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport()) - } - timer = time.AfterFunc(c.Timeout, func() { - atomic.StoreInt32(&atomicWasCanceled, 1) - reqmu.Lock() - defer reqmu.Unlock() - tr.CancelRequest(req) - }) - } + deadline := c.deadline() urlStr := "" // next relative or absolute URL to fetch (after first request) redirectFailed := false @@ -388,14 +467,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo break } } - reqmu.Lock() req = nreq - reqmu.Unlock() } urlStr = req.URL.String() - if resp, err = c.send(req); err != nil { - if wasCanceled() { + if resp, err = c.send(req, deadline); err != nil { + if !deadline.IsZero() && !time.Now().Before(deadline) { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", timeout: true, @@ -420,22 +497,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo via = append(via, req) continue } - if timer != nil { - resp.Body = &cancelTimerBody{ - t: timer, - rc: resp.Body, - reqWasCanceled: wasCanceled, - } - } return resp, nil } - method := ireq.Method - if method == "" { - method = "GET" - } + method := valueOrDefault(ireq.Method, "GET") urlErr := &url.Error{ - Op: method[0:1] + strings.ToLower(method[1:]), + Op: method[:1] + strings.ToLower(method[1:]), URL: urlStr, Err: err, } @@ -548,30 +615,35 @@ func (c *Client) Head(url string) (resp *Response, err error) { } // cancelTimerBody is an io.ReadCloser that wraps rc with two features: -// 1) on Read EOF or Close, the timer t is Stopped, +// 1) on Read error or close, the stop func is called. // 2) On Read failure, if reqWasCanceled is true, the error is wrapped and // marked as net.Error that hit its timeout. type cancelTimerBody struct { - t *time.Timer + stop func() // stops the time.Timer waiting to cancel the request rc io.ReadCloser reqWasCanceled func() bool } func (b *cancelTimerBody) Read(p []byte) (n int, err error) { n, err = b.rc.Read(p) + if err == nil { + return n, nil + } + b.stop() if err == io.EOF { - b.t.Stop() - } else if err != nil && b.reqWasCanceled() { - return n, &httpError{ + return n, err + } + if b.reqWasCanceled() { + err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while reading body)", timeout: true, } } - return + return n, err } func (b *cancelTimerBody) Close() error { err := b.rc.Close() - b.t.Stop() + b.stop() return err } diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 9d3444c89af..cfad71e0295 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -922,10 +922,7 @@ func TestBasicAuthHeadersPreserved(t *testing.T) { } func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } -func TestClientTimeout_h2(t *testing.T) { - t.Skip("skipping in http2 mode; golang.org/issue/13540") - testClientTimeout(t, h2Mode) -} +func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } func testClientTimeout(t *testing.T, h2 bool) { if testing.Short() { @@ -999,10 +996,7 @@ func testClientTimeout(t *testing.T, h2 bool) { } func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } -func TestClientTimeout_Headers_h2(t *testing.T) { - t.Skip("skipping in http2 mode; golang.org/issue/13540") - testClientTimeout_Headers(t, h2Mode) -} +func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) } // Client.Timeout firing before getting to the body func testClientTimeout_Headers(t *testing.T, h2 bool) { diff --git a/src/net/http/header.go b/src/net/http/header.go index d847b131184..049f32f27dc 100644 --- a/src/net/http/header.go +++ b/src/net/http/header.go @@ -211,3 +211,13 @@ func hasToken(v, token string) bool { func isTokenBoundary(b byte) bool { return b == ' ' || b == ',' || b == '\t' } + +func cloneHeader(h Header) Header { + h2 := make(Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 67b29150411..8d9e58cc2ef 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -372,6 +372,8 @@ func (t *Transport) CloseIdleConnections() { // CancelRequest cancels an in-flight request by closing its connection. // CancelRequest should only be called after RoundTrip has returned. +// +// Deprecated: Use Request.Cancel instead. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() cancel := t.reqCanceler[req]