mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
net/http/httputil: make ReverseProxy automatically proxy WebSocket requests
Fixes #26937 Change-Id: I6cdc1bad4cf476cd2ea1462b53444eccd8841e14 Reviewed-on: https://go-review.googlesource.com/c/146437 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
This commit is contained in:
parent
de578dcdd6
commit
ee55f0856a
4 changed files with 161 additions and 12 deletions
|
|
@ -436,7 +436,7 @@ var pkgDeps = map[string][]string{
|
||||||
"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509",
|
"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509",
|
||||||
"golang_org/x/net/http/httpguts",
|
"golang_org/x/net/http/httpguts",
|
||||||
},
|
},
|
||||||
"net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
|
"net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal", "golang_org/x/net/http/httpguts"},
|
||||||
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
|
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
|
||||||
"net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"},
|
"net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"},
|
||||||
"net/rpc/jsonrpc": {"L4", "NET", "encoding/json", "net/rpc"},
|
"net/rpc/jsonrpc": {"L4", "NET", "encoding/json", "net/rpc"},
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
|
@ -16,6 +17,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang_org/x/net/http/httpguts"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||||
|
|
@ -199,6 +202,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
p.Director(outreq)
|
p.Director(outreq)
|
||||||
outreq.Close = false
|
outreq.Close = false
|
||||||
|
|
||||||
|
reqUpType := upgradeType(outreq.Header)
|
||||||
removeConnectionHeaders(outreq.Header)
|
removeConnectionHeaders(outreq.Header)
|
||||||
|
|
||||||
// Remove hop-by-hop headers to the backend. Especially
|
// Remove hop-by-hop headers to the backend. Especially
|
||||||
|
|
@ -221,6 +225,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
outreq.Header.Del(h)
|
outreq.Header.Del(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// After stripping all the hop-by-hop connection headers above, add back any
|
||||||
|
// necessary for protocol upgrades, such as for websockets.
|
||||||
|
if reqUpType != "" {
|
||||||
|
outreq.Header.Set("Connection", "Upgrade")
|
||||||
|
outreq.Header.Set("Upgrade", reqUpType)
|
||||||
|
}
|
||||||
|
|
||||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||||
// If we aren't the first proxy retain prior
|
// If we aren't the first proxy retain prior
|
||||||
// X-Forwarded-For information as a comma+space
|
// X-Forwarded-For information as a comma+space
|
||||||
|
|
@ -237,6 +248,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||||
|
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||||
|
p.handleUpgradeResponse(rw, outreq, res)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
removeConnectionHeaders(res.Header)
|
removeConnectionHeaders(res.Header)
|
||||||
|
|
||||||
for _, h := range hopHeaders {
|
for _, h := range hopHeaders {
|
||||||
|
|
@ -463,3 +480,67 @@ func (m *maxLatencyWriter) stop() {
|
||||||
m.t.Stop()
|
m.t.Stop()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func upgradeType(h http.Header) string {
|
||||||
|
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.ToLower(h.Get("Upgrade"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||||
|
reqUpType := upgradeType(req.Header)
|
||||||
|
resUpType := upgradeType(res.Header)
|
||||||
|
if reqUpType != resUpType {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hj, ok := rw.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||||
|
if !ok {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer backConn.Close()
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||||
|
if err := res.Write(brw); err != nil {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc := make(chan error, 1)
|
||||||
|
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
|
go spc.copyToBackend(errc)
|
||||||
|
go spc.copyFromBackend(errc)
|
||||||
|
<-errc
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// switchProtocolCopier exists so goroutines proxying data back and
|
||||||
|
// forth have nice names in stacks.
|
||||||
|
type switchProtocolCopier struct {
|
||||||
|
user, backend io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||||
|
_, err := io.Copy(c.user, c.backend)
|
||||||
|
errc <- err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||||
|
_, err := io.Copy(c.backend, c.user)
|
||||||
|
errc <- err
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -153,15 +153,20 @@ func TestReverseProxy(t *testing.T) {
|
||||||
func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
|
func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
|
||||||
const fakeConnectionToken = "X-Fake-Connection-Token"
|
const fakeConnectionToken = "X-Fake-Connection-Token"
|
||||||
const backendResponse = "I am the backend"
|
const backendResponse = "I am the backend"
|
||||||
|
|
||||||
|
// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
|
||||||
|
// in the Request's Connection header.
|
||||||
|
const someConnHeader = "X-Some-Conn-Header"
|
||||||
|
|
||||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if c := r.Header.Get(fakeConnectionToken); c != "" {
|
if c := r.Header.Get(fakeConnectionToken); c != "" {
|
||||||
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
|
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
|
||||||
}
|
}
|
||||||
if c := r.Header.Get("Upgrade"); c != "" {
|
if c := r.Header.Get(someConnHeader); c != "" {
|
||||||
t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
|
t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
|
||||||
}
|
}
|
||||||
w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
|
w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken)
|
||||||
w.Header().Set("Upgrade", "should be deleted")
|
w.Header().Set(someConnHeader, "should be deleted")
|
||||||
w.Header().Set(fakeConnectionToken, "should be deleted")
|
w.Header().Set(fakeConnectionToken, "should be deleted")
|
||||||
io.WriteString(w, backendResponse)
|
io.WriteString(w, backendResponse)
|
||||||
}))
|
}))
|
||||||
|
|
@ -173,15 +178,15 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
|
||||||
proxyHandler := NewSingleHostReverseProxy(backendURL)
|
proxyHandler := NewSingleHostReverseProxy(backendURL)
|
||||||
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
proxyHandler.ServeHTTP(w, r)
|
proxyHandler.ServeHTTP(w, r)
|
||||||
if c := r.Header.Get("Upgrade"); c != "original value" {
|
if c := r.Header.Get(someConnHeader); c != "original value" {
|
||||||
t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
|
t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value")
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
defer frontend.Close()
|
defer frontend.Close()
|
||||||
|
|
||||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||||
getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
|
getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken)
|
||||||
getReq.Header.Set("Upgrade", "original value")
|
getReq.Header.Set(someConnHeader, "original value")
|
||||||
getReq.Header.Set(fakeConnectionToken, "should be deleted")
|
getReq.Header.Set(fakeConnectionToken, "should be deleted")
|
||||||
res, err := frontend.Client().Do(getReq)
|
res, err := frontend.Client().Do(getReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -195,8 +200,8 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
|
||||||
if got, want := string(bodyBytes), backendResponse; got != want {
|
if got, want := string(bodyBytes), backendResponse; got != want {
|
||||||
t.Errorf("got body %q; want %q", got, want)
|
t.Errorf("got body %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
if c := res.Header.Get("Upgrade"); c != "" {
|
if c := res.Header.Get(someConnHeader); c != "" {
|
||||||
t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
|
t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
|
||||||
}
|
}
|
||||||
if c := res.Header.Get(fakeConnectionToken); c != "" {
|
if c := res.Header.Get(fakeConnectionToken); c != "" {
|
||||||
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
|
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
|
||||||
|
|
@ -980,3 +985,66 @@ func TestSelectFlushInterval(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyWebSocket(t *testing.T) {
|
||||||
|
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if upgradeType(r.Header) != "websocket" {
|
||||||
|
t.Error("unexpected backend request")
|
||||||
|
http.Error(w, "unexpected request", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c, _, err := w.(http.Hijacker).Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
|
||||||
|
bs := bufio.NewScanner(c)
|
||||||
|
if !bs.Scan() {
|
||||||
|
t.Errorf("backend failed to read line from client: %v", bs.Err())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(c, "backend got %q\n", bs.Text())
|
||||||
|
}))
|
||||||
|
defer backendServer.Close()
|
||||||
|
|
||||||
|
backURL, _ := url.Parse(backendServer.URL)
|
||||||
|
rproxy := NewSingleHostReverseProxy(backURL)
|
||||||
|
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
|
||||||
|
|
||||||
|
frontendProxy := httptest.NewServer(rproxy)
|
||||||
|
defer frontendProxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
|
||||||
|
c := frontendProxy.Client()
|
||||||
|
res, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res.StatusCode != 101 {
|
||||||
|
t.Fatalf("status = %v; want 101", res.Status)
|
||||||
|
}
|
||||||
|
if upgradeType(res.Header) != "websocket" {
|
||||||
|
t.Fatalf("not websocket upgrade; got %#v", res.Header)
|
||||||
|
}
|
||||||
|
rwc, ok := res.Body.(io.ReadWriteCloser)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
|
||||||
|
}
|
||||||
|
defer rwc.Close()
|
||||||
|
|
||||||
|
io.WriteString(rwc, "Hello\n")
|
||||||
|
bs := bufio.NewScanner(rwc)
|
||||||
|
if !bs.Scan() {
|
||||||
|
t.Fatalf("Scan: %v", bs.Err())
|
||||||
|
}
|
||||||
|
got := bs.Text()
|
||||||
|
want := `backend got "Hello"`
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("got %#q, want %#q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1714,7 +1714,7 @@ func (pc *persistConn) readLoop() {
|
||||||
alive = false
|
alive = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasBody {
|
if !hasBody || bodyWritable {
|
||||||
pc.t.setReqCanceler(rc.req, nil)
|
pc.t.setReqCanceler(rc.req, nil)
|
||||||
|
|
||||||
// Put the idle conn back into the pool before we send the response
|
// Put the idle conn back into the pool before we send the response
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue