mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 23:01:06 +00:00 
			
		
		
		
	reverseproxy: Close hijacked conns on reload/quit (#4895)
* reverseproxy: Close hijacked conns on reload/quit We also send a Close control message to both ends of WebSocket connections. I have tested this many times in my dev environment with consistent success, although the variety of scenarios was limited. * Oops... actually call Close() this time * CloseMessage --> closeMessage Co-authored-by: Francis Lavoie <lavofr@gmail.com> * Use httpguts, duh * Use map instead of sync.Map Co-authored-by: Francis Lavoie <lavofr@gmail.com>
This commit is contained in:
		
							parent
							
								
									d3c3fa10bd
								
							
						
					
					
						commit
						66476d8c8f
					
				
					 2 changed files with 103 additions and 5 deletions
				
			
		| 
						 | 
					@ -192,6 +192,10 @@ type Handler struct {
 | 
				
			||||||
	// Holds the handle_response Caddyfile tokens while adapting
 | 
						// Holds the handle_response Caddyfile tokens while adapting
 | 
				
			||||||
	handleResponseSegments []*caddyfile.Dispenser
 | 
						handleResponseSegments []*caddyfile.Dispenser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Stores upgraded requests (hijacked connections) for proper cleanup
 | 
				
			||||||
 | 
						connections   map[io.ReadWriteCloser]openConnection
 | 
				
			||||||
 | 
						connectionsMu *sync.Mutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ctx    caddy.Context
 | 
						ctx    caddy.Context
 | 
				
			||||||
	logger *zap.Logger
 | 
						logger *zap.Logger
 | 
				
			||||||
	events *caddyevents.App
 | 
						events *caddyevents.App
 | 
				
			||||||
| 
						 | 
					@ -214,6 +218,8 @@ func (h *Handler) Provision(ctx caddy.Context) error {
 | 
				
			||||||
	h.events = eventAppIface.(*caddyevents.App)
 | 
						h.events = eventAppIface.(*caddyevents.App)
 | 
				
			||||||
	h.ctx = ctx
 | 
						h.ctx = ctx
 | 
				
			||||||
	h.logger = ctx.Logger(h)
 | 
						h.logger = ctx.Logger(h)
 | 
				
			||||||
 | 
						h.connections = make(map[io.ReadWriteCloser]openConnection)
 | 
				
			||||||
 | 
						h.connectionsMu = new(sync.Mutex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
 | 
						// verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
 | 
				
			||||||
	for i, v := range h.Upstreams {
 | 
						for i, v := range h.Upstreams {
 | 
				
			||||||
| 
						 | 
					@ -407,16 +413,34 @@ func (h *Handler) Provision(ctx caddy.Context) error {
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Cleanup cleans up the resources made by h during provisioning.
 | 
					// Cleanup cleans up the resources made by h.
 | 
				
			||||||
func (h *Handler) Cleanup() error {
 | 
					func (h *Handler) Cleanup() error {
 | 
				
			||||||
	// TODO: Close keepalive connections on reload? https://github.com/caddyserver/caddy/pull/2507/files#diff-70219fd88fe3f36834f474ce6537ed26R762
 | 
						// close hijacked connections (both to client and backend)
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						h.connectionsMu.Lock()
 | 
				
			||||||
 | 
						for _, oc := range h.connections {
 | 
				
			||||||
 | 
							if oc.gracefulClose != nil {
 | 
				
			||||||
 | 
								// this is potentially blocking while we have the lock on the connections
 | 
				
			||||||
 | 
								// map, but that should be OK since the server has in theory shut down
 | 
				
			||||||
 | 
								// and we are no longer using the connections map
 | 
				
			||||||
 | 
								gracefulErr := oc.gracefulClose()
 | 
				
			||||||
 | 
								if gracefulErr != nil && err == nil {
 | 
				
			||||||
 | 
									err = gracefulErr
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							closeErr := oc.conn.Close()
 | 
				
			||||||
 | 
							if closeErr != nil && err == nil {
 | 
				
			||||||
 | 
								err = closeErr
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.connectionsMu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// remove hosts from our config from the pool
 | 
						// remove hosts from our config from the pool
 | 
				
			||||||
	for _, upstream := range h.Upstreams {
 | 
						for _, upstream := range h.Upstreams {
 | 
				
			||||||
		_, _ = hosts.Delete(upstream.String())
 | 
							_, _ = hosts.Delete(upstream.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
 | 
					func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,6 +20,7 @@ package reverseproxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/binary"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"mime"
 | 
						"mime"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
| 
						 | 
					@ -27,6 +28,7 @@ import (
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"go.uber.org/zap"
 | 
						"go.uber.org/zap"
 | 
				
			||||||
 | 
						"golang.org/x/net/http/httpguts"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
 | 
					func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
 | 
				
			||||||
| 
						 | 
					@ -97,8 +99,26 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	errc := make(chan error, 1)
 | 
						// Ensure the hijacked client connection, and the new connection established
 | 
				
			||||||
 | 
						// with the backend, are both closed in the event of a server shutdown. This
 | 
				
			||||||
 | 
						// is done by registering them. We also try to gracefully close connections
 | 
				
			||||||
 | 
						// we recognize as websockets.
 | 
				
			||||||
 | 
						gracefulClose := func(conn io.ReadWriteCloser) func() error {
 | 
				
			||||||
 | 
							if isWebsocket(req) {
 | 
				
			||||||
 | 
								return func() error {
 | 
				
			||||||
 | 
									return writeCloseControl(conn)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						deleteFrontConn := h.registerConnection(conn, gracefulClose(conn))
 | 
				
			||||||
 | 
						deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn))
 | 
				
			||||||
 | 
						defer deleteFrontConn()
 | 
				
			||||||
 | 
						defer deleteBackConn()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	spc := switchProtocolCopier{user: conn, backend: backConn}
 | 
						spc := switchProtocolCopier{user: conn, backend: backConn}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						errc := make(chan error, 1)
 | 
				
			||||||
	go spc.copyToBackend(errc)
 | 
						go spc.copyToBackend(errc)
 | 
				
			||||||
	go spc.copyFromBackend(errc)
 | 
						go spc.copyFromBackend(errc)
 | 
				
			||||||
	<-errc
 | 
						<-errc
 | 
				
			||||||
| 
						 | 
					@ -209,6 +229,60 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// registerConnection holds onto conn so it can be closed in the event
 | 
				
			||||||
 | 
					// of a server shutdown. This is useful because hijacked connections or
 | 
				
			||||||
 | 
					// connections dialed to backends don't close when server is shut down.
 | 
				
			||||||
 | 
					// The caller should call the returned delete() function when the
 | 
				
			||||||
 | 
					// connection is done to remove it from memory.
 | 
				
			||||||
 | 
					func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
 | 
				
			||||||
 | 
						h.connectionsMu.Lock()
 | 
				
			||||||
 | 
						h.connections[conn] = openConnection{conn, gracefulClose}
 | 
				
			||||||
 | 
						h.connectionsMu.Unlock()
 | 
				
			||||||
 | 
						return func() {
 | 
				
			||||||
 | 
							h.connectionsMu.Lock()
 | 
				
			||||||
 | 
							delete(h.connections, conn)
 | 
				
			||||||
 | 
							h.connectionsMu.Unlock()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// writeCloseControl sends a best-effort Close control message to the given
 | 
				
			||||||
 | 
					// WebSocket connection. Thanks to @pascaldekloe who provided inspiration
 | 
				
			||||||
 | 
					// from his simple implementation of this I was able to learn from at:
 | 
				
			||||||
 | 
					// github.com/pascaldekloe/websocket.
 | 
				
			||||||
 | 
					func writeCloseControl(conn io.Writer) error {
 | 
				
			||||||
 | 
						// https://github.com/pascaldekloe/websocket/blob/32050af67a5d/websocket.go#L119
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var reason string // max 123 bytes (control frame payload limit is 125; status code takes 2)
 | 
				
			||||||
 | 
						const goingAway uint16 = 1001
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO: we might need to ensure we are the exclusive writer by this point (io.Copy is stopped)?
 | 
				
			||||||
 | 
						var writeBuf [127]byte
 | 
				
			||||||
 | 
						const closeMessage = 8
 | 
				
			||||||
 | 
						const finalBit = 1 << 7
 | 
				
			||||||
 | 
						writeBuf[0] = closeMessage | finalBit
 | 
				
			||||||
 | 
						writeBuf[1] = byte(len(reason) + 2)
 | 
				
			||||||
 | 
						binary.BigEndian.PutUint16(writeBuf[2:4], goingAway)
 | 
				
			||||||
 | 
						copy(writeBuf[4:], reason)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// simply best-effort, but return error for logging purposes
 | 
				
			||||||
 | 
						_, err := conn.Write(writeBuf[:4+len(reason)])
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// isWebsocket returns true if r looks to be an upgrade request for WebSockets.
 | 
				
			||||||
 | 
					// It is a fairly naive check.
 | 
				
			||||||
 | 
					func isWebsocket(r *http.Request) bool {
 | 
				
			||||||
 | 
						return httpguts.HeaderValuesContainsToken(r.Header["Connection"], "upgrade") &&
 | 
				
			||||||
 | 
							httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// openConnection maps an open connection to
 | 
				
			||||||
 | 
					// an optional function for graceful close.
 | 
				
			||||||
 | 
					type openConnection struct {
 | 
				
			||||||
 | 
						conn          io.ReadWriteCloser
 | 
				
			||||||
 | 
						gracefulClose func() error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type writeFlusher interface {
 | 
					type writeFlusher interface {
 | 
				
			||||||
	io.Writer
 | 
						io.Writer
 | 
				
			||||||
	http.Flusher
 | 
						http.Flusher
 | 
				
			||||||
| 
						 | 
					@ -265,7 +339,7 @@ func (m *maxLatencyWriter) stop() {
 | 
				
			||||||
// switchProtocolCopier exists so goroutines proxying data back and
 | 
					// switchProtocolCopier exists so goroutines proxying data back and
 | 
				
			||||||
// forth have nice names in stacks.
 | 
					// forth have nice names in stacks.
 | 
				
			||||||
type switchProtocolCopier struct {
 | 
					type switchProtocolCopier struct {
 | 
				
			||||||
	user, backend io.ReadWriter
 | 
						user, backend io.ReadWriteCloser
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
 | 
					func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue