diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index e21af8b1595..f454dcdbed3 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -7166,3 +7166,79 @@ func TestError(t *testing.T) { t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff") } } + +func TestServerReadAfterWriteHeader100Continue(t *testing.T) { + run(t, testServerReadAfterWriteHeader100Continue) +} +func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) { + body := []byte("body") + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(200) + NewResponseController(w).Flush() + io.ReadAll(r.Body) + w.Write(body) + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("io.ReadAll(res.Body) = %v", err) + } + if !bytes.Equal(got, body) { + t.Fatalf("response body = %q, want %q", got, body) + } +} + +func TestServerReadAfterHandlerDone100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerDone100Continue) +} +func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) { + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + res.Body.Close() + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} + +func TestServerReadAfterHandlerAbort100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerAbort100Continue) +} +func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) { + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + panic(ErrAbortHandler) + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err == nil { + res.Body.Close() + } + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} diff --git a/src/net/http/server.go b/src/net/http/server.go index a50b20b7da9..b76c8695671 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -425,7 +425,6 @@ type response struct { reqBody io.ReadCloser cancelCtx context.CancelFunc // when ServeHTTP exits wroteHeader bool // a non-1xx header has been (logically) written - wroteContinue bool // 100 Continue response was written wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" @@ -436,8 +435,8 @@ type response struct { // These two fields together synchronize the body reader (the // expectContinueReader, which wants to write 100 Continue) // against the main writer. - canWriteContinue atomic.Bool writeContinueMu sync.Mutex + canWriteContinue atomic.Bool w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -565,6 +564,14 @@ func (w *response) requestTooLarge() { } } +// disableWriteContinue stops Request.Body.Read from sending an automatic 100-Continue. +// If a 100-Continue is being written, it waits for it to complete before continuing. +func (w *response) disableWriteContinue() { + w.writeContinueMu.Lock() + w.canWriteContinue.Store(false) + w.writeContinueMu.Unlock() +} + // writerOnly hides an io.Writer value's optional ReadFrom method // from io.Copy. type writerOnly struct { @@ -917,8 +924,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { return 0, ErrBodyReadAfterClose } w := ecr.resp - if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() { - w.wroteContinue = true + if w.canWriteContinue.Load() { w.writeContinueMu.Lock() if w.canWriteContinue.Load() { w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") @@ -1159,18 +1165,17 @@ func (w *response) WriteHeader(code int) { } checkWriteHeaderCode(code) + if code < 101 || code > 199 { + // Sending a 100 Continue or any non-1xx header disables the + // automatically-sent 100 Continue from Request.Body.Read. + w.disableWriteContinue() + } + // Handle informational headers. // // We shouldn't send any further headers after 101 Switching Protocols, // so it takes the non-informational path. if code >= 100 && code <= 199 && code != StatusSwitchingProtocols { - // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() - if code == 100 && w.canWriteContinue.Load() { - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() - } - writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) // Per RFC 8297 we must not clear the current header map @@ -1378,14 +1383,20 @@ func (cw *chunkWriter) writeHeader(p []byte) { // // If full duplex mode has been enabled with ResponseController.EnableFullDuplex, // then leave the request body alone. + // + // We don't take this path when w.closeAfterReply is set. + // We may not need to consume the request to get ready for the next one + // (since we're closing the conn), but a client which sends a full request + // before reading a response may deadlock in this case. + // This behavior has been present since CL 5268043 (2011), however, + // so it doesn't seem to be causing problems. if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex { var discard, tooBig bool switch bdy := w.req.Body.(type) { case *expectContinueReader: - if bdy.resp.wroteContinue { - discard = true - } + // We only get here if we have already fully consumed the request body + // (see above). case *body: bdy.mu.Lock() switch { @@ -1626,13 +1637,8 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er } if w.canWriteContinue.Load() { - // Body reader wants to write 100 Continue but hasn't yet. - // Tell it not to. The store must be done while holding the lock - // because the lock makes sure that there is not an active write - // this very moment. - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() + // Body reader wants to write 100 Continue but hasn't yet. Tell it not to. + w.disableWriteContinue() } if !w.wroteHeader { @@ -1900,6 +1906,7 @@ func (c *conn) serve(ctx context.Context) { } if inFlightResponse != nil { inFlightResponse.cancelCtx() + inFlightResponse.disableWriteContinue() } if !c.hijacked() { if inFlightResponse != nil { @@ -2106,6 +2113,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.handlerDone.Load() { panic("net/http: Hijack called after ServeHTTP finished") } + w.disableWriteContinue() if w.wroteHeader { w.cw.flush() }