mirror of
https://github.com/golang/go.git
synced 2025-10-19 19:13:18 +00:00
net/http: reduce memory usage when hijacking
Previously, Hijack allocated a new write buffer and the existing
connection write buffer used an extra 4KiB of memory until the handler
finished and the "conn" was garbage collected. Now, hijack re-uses the
existing write buffer and re-attaches it to the raw connection to avoid
referencing the net/http "conn" after returning.
After a handler that hijacked exited, the "conn" reference in
"connReader" will now be unset. This allows all of the "conn",
"response" and "Request" to get garbage collected.
Overall, this is reducing the memory usage by 43% or 6.7KiB per hijacked
connection (see BenchmarkServerHijackMemoryUsage in an earlier revision
of the CL).
CloseNotify will continue to work _before_ the handler has exited
(i.e. while the "conn" is still referenced in "connReader"). This aligns
with the documentation of CloseNotifier:
> After the Handler has returned, there is no guarantee that the channel
> receives a value.
goos: linux
goarch: amd64
pkg: net/http
cpu: Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz
│ before │ after │
│ sec/op │ sec/op vs base │
ServerHijack-8 42.59µ ± 8% 39.47µ ± 16% ~ (p=0.481 n=10)
│ before │ after │
│ B/op │ B/op vs base │
ServerHijack-8 16.12Ki ± 0% 12.06Ki ± 0% -25.16% (p=0.000 n=10)
│ before │ after │
│ allocs/op │ allocs/op vs base │
ServerHijack-8 51.00 ± 0% 49.00 ± 0% -3.92% (p=0.000 n=10)
Change-Id: I20a37ee314ed0d47463a4657d712154e78e48138
GitHub-Last-Rev: 80f09dfa27
GitHub-Pull-Request: golang/go#70756
Reviewed-on: https://go-review.googlesource.com/c/go/+/634855
Reviewed-by: Sean Liao <sean@liao.dev>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Sean Liao <sean@liao.dev>
This commit is contained in:
parent
cac276f81a
commit
cfb78d63bd
2 changed files with 79 additions and 20 deletions
|
@ -6139,6 +6139,50 @@ func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
|
|||
<-done
|
||||
}
|
||||
|
||||
// Test that the bufio.Reader returned by Hijack yields the entire body.
|
||||
func TestServerHijackGetsFullBody(t *testing.T) {
|
||||
run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
|
||||
}
|
||||
func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
|
||||
if runtime.GOOS == "plan9" {
|
||||
t.Skip("skipping test; see https://golang.org/issue/18657")
|
||||
}
|
||||
done := make(chan struct{})
|
||||
needle := strings.Repeat("x", 100*1024) // assume: larger than net/http bufio size
|
||||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
defer close(done)
|
||||
|
||||
conn, buf, err := w.(Hijacker).Hijack()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
got := make([]byte, len(needle))
|
||||
n, err := io.ReadFull(buf.Reader, got)
|
||||
if n != len(needle) || string(got) != needle || err != nil {
|
||||
t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
|
||||
}
|
||||
})).ts
|
||||
|
||||
cn, err := net.Dial("tcp", ts.Listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cn.Close()
|
||||
buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
|
||||
buf = append(buf, []byte(needle)...)
|
||||
if _, err := cn.Write(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
// Like TestServerHijackGetsBackgroundByte above but sending a
|
||||
// immediate 1MB of data to the server to fill up the server's 4KB
|
||||
// buffer.
|
||||
|
|
|
@ -324,12 +324,14 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
|
|||
rwc = c.rwc
|
||||
rwc.SetDeadline(time.Time{})
|
||||
|
||||
buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
|
||||
if c.r.hasByte {
|
||||
if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil {
|
||||
return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err)
|
||||
}
|
||||
}
|
||||
c.bufw.Reset(rwc)
|
||||
buf = bufio.NewReadWriter(c.bufr, c.bufw)
|
||||
|
||||
c.setState(rwc, StateHijacked, runHooks)
|
||||
return
|
||||
}
|
||||
|
@ -652,10 +654,13 @@ type readResult struct {
|
|||
// read sizes) with support for selectively keeping an io.Reader.Read
|
||||
// call blocked in a background goroutine to wait for activity and
|
||||
// trigger a CloseNotifier channel.
|
||||
// After a Handler has hijacked the conn and exited, connReader behaves like a
|
||||
// proxy for the net.Conn and the aforementioned behavior is bypassed.
|
||||
type connReader struct {
|
||||
conn *conn
|
||||
rwc net.Conn // rwc is the underlying network connection.
|
||||
|
||||
mu sync.Mutex // guards following
|
||||
conn *conn // conn is nil after handler exit.
|
||||
hasByte bool
|
||||
byteBuf [1]byte
|
||||
cond *sync.Cond
|
||||
|
@ -673,6 +678,12 @@ func (cr *connReader) lock() {
|
|||
|
||||
func (cr *connReader) unlock() { cr.mu.Unlock() }
|
||||
|
||||
func (cr *connReader) releaseConn() {
|
||||
cr.lock()
|
||||
defer cr.unlock()
|
||||
cr.conn = nil
|
||||
}
|
||||
|
||||
func (cr *connReader) startBackgroundRead() {
|
||||
cr.lock()
|
||||
defer cr.unlock()
|
||||
|
@ -683,12 +694,12 @@ func (cr *connReader) startBackgroundRead() {
|
|||
return
|
||||
}
|
||||
cr.inRead = true
|
||||
cr.conn.rwc.SetReadDeadline(time.Time{})
|
||||
cr.rwc.SetReadDeadline(time.Time{})
|
||||
go cr.backgroundRead()
|
||||
}
|
||||
|
||||
func (cr *connReader) backgroundRead() {
|
||||
n, err := cr.conn.rwc.Read(cr.byteBuf[:])
|
||||
n, err := cr.rwc.Read(cr.byteBuf[:])
|
||||
cr.lock()
|
||||
if n == 1 {
|
||||
cr.hasByte = true
|
||||
|
@ -719,7 +730,7 @@ func (cr *connReader) backgroundRead() {
|
|||
// Ignore this error. It's the expected error from
|
||||
// another goroutine calling abortPendingRead.
|
||||
} else if err != nil {
|
||||
cr.handleReadError(err)
|
||||
cr.handleReadErrorLocked(err)
|
||||
}
|
||||
cr.aborted = false
|
||||
cr.inRead = false
|
||||
|
@ -734,18 +745,18 @@ func (cr *connReader) abortPendingRead() {
|
|||
return
|
||||
}
|
||||
cr.aborted = true
|
||||
cr.conn.rwc.SetReadDeadline(aLongTimeAgo)
|
||||
cr.rwc.SetReadDeadline(aLongTimeAgo)
|
||||
for cr.inRead {
|
||||
cr.cond.Wait()
|
||||
}
|
||||
cr.conn.rwc.SetReadDeadline(time.Time{})
|
||||
cr.rwc.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
|
||||
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
|
||||
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
|
||||
|
||||
// handleReadError is called whenever a Read from the client returns a
|
||||
// handleReadErrorLocked is called whenever a Read from the client returns a
|
||||
// non-nil error.
|
||||
//
|
||||
// The provided non-nil err is almost always io.EOF or a "use of
|
||||
|
@ -754,14 +765,12 @@ func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
|
|||
// development. Any error means the connection is dead and we should
|
||||
// down its context.
|
||||
//
|
||||
// It may be called from multiple goroutines.
|
||||
func (cr *connReader) handleReadError(_ error) {
|
||||
// The caller must hold connReader.mu.
|
||||
func (cr *connReader) handleReadErrorLocked(_ error) {
|
||||
if cr.conn == nil {
|
||||
return
|
||||
}
|
||||
cr.conn.cancelCtx()
|
||||
cr.closeNotify()
|
||||
}
|
||||
|
||||
// may be called from multiple goroutines.
|
||||
func (cr *connReader) closeNotify() {
|
||||
if res := cr.conn.curReq.Load(); res != nil {
|
||||
res.closeNotify()
|
||||
}
|
||||
|
@ -769,9 +778,14 @@ func (cr *connReader) closeNotify() {
|
|||
|
||||
func (cr *connReader) Read(p []byte) (n int, err error) {
|
||||
cr.lock()
|
||||
if cr.inRead {
|
||||
if cr.conn == nil {
|
||||
cr.unlock()
|
||||
if cr.conn.hijacked() {
|
||||
return cr.rwc.Read(p)
|
||||
}
|
||||
if cr.inRead {
|
||||
hijacked := cr.conn.hijacked()
|
||||
cr.unlock()
|
||||
if hijacked {
|
||||
panic("invalid Body.Read call. After hijacked, the original Request must not be used")
|
||||
}
|
||||
panic("invalid concurrent Body.Read call")
|
||||
|
@ -795,12 +809,12 @@ func (cr *connReader) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
cr.inRead = true
|
||||
cr.unlock()
|
||||
n, err = cr.conn.rwc.Read(p)
|
||||
n, err = cr.rwc.Read(p)
|
||||
|
||||
cr.lock()
|
||||
cr.inRead = false
|
||||
if err != nil {
|
||||
cr.handleReadError(err)
|
||||
cr.handleReadErrorLocked(err)
|
||||
}
|
||||
cr.remain -= int64(n)
|
||||
cr.unlock()
|
||||
|
@ -1986,7 +2000,7 @@ func (c *conn) serve(ctx context.Context) {
|
|||
c.cancelCtx = cancelCtx
|
||||
defer cancelCtx()
|
||||
|
||||
c.r = &connReader{conn: c}
|
||||
c.r = &connReader{conn: c, rwc: c.rwc}
|
||||
c.bufr = newBufioReader(c.r)
|
||||
c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
|
||||
|
||||
|
@ -2083,6 +2097,7 @@ func (c *conn) serve(ctx context.Context) {
|
|||
inFlightResponse = nil
|
||||
w.cancelCtx()
|
||||
if c.hijacked() {
|
||||
c.r.releaseConn()
|
||||
return
|
||||
}
|
||||
w.finishRequest()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue