reverseproxy: Fix H2C dialer using new stdlib DialTLSContext (#4951)

This commit is contained in:
Francis Lavoie 2022-08-12 15:11:13 -04:00 committed by GitHub
parent 91ab0e6066
commit 922d9f5c25
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 68 deletions

View file

@ -128,28 +128,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
}
h.Transport = rt
// if h2c is enabled, configure its transport (std lib http.Transport
// does not "HTTP/2 over cleartext TCP")
if sliceContains(h.Versions, "h2c") {
// crafting our own http2.Transport doesn't allow us to utilize
// most of the customizations/preferences on the http.Transport,
// because, for some reason, only http2.ConfigureTransport()
// is allowed to set the unexported field that refers to a base
// http.Transport config; oh well
h2t := &http2.Transport{
// kind of a hack, but for plaintext/H2C requests, pretend to dial TLS
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
// TODO: no context, thus potentially wrong dial info
return net.Dial(network, addr)
},
AllowHTTP: true,
}
if h.Compression != nil {
h2t.DisableCompression = !*h.Compression
}
h.h2cTransport = h2t
}
return nil
}
@ -194,35 +172,38 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
}
// Set up the dialer to pull the correct information from the context
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
// the proper dialing information should be embedded into the request's context
if dialInfo, ok := GetDialInfo(ctx); ok {
network = dialInfo.Network
address = dialInfo.Address
}
conn, err := dialer.DialContext(ctx, network, address)
if err != nil {
// identify this error as one that occurred during
// dialing, which can be important when trying to
// decide whether to retry a request
return nil, DialError{err}
}
// if read/write timeouts are configured and this is a TCP connection, enforce the timeouts
// by wrapping the connection with our own type
if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
conn = &tcpRWTimeoutConn{
TCPConn: tcpConn,
readTimeout: time.Duration(h.ReadTimeout),
writeTimeout: time.Duration(h.WriteTimeout),
logger: caddyCtx.Logger(h),
}
}
return conn, nil
}
rt := &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
// the proper dialing information should be embedded into the request's context
if dialInfo, ok := GetDialInfo(ctx); ok {
network = dialInfo.Network
address = dialInfo.Address
}
conn, err := dialer.DialContext(ctx, network, address)
if err != nil {
// identify this error as one that occurred during
// dialing, which can be important when trying to
// decide whether to retry a request
return nil, DialError{err}
}
// if read/write timeouts are configured and this is a TCP connection, enforce the timeouts
// by wrapping the connection with our own type
if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
conn = &tcpRWTimeoutConn{
TCPConn: tcpConn,
readTimeout: time.Duration(h.ReadTimeout),
writeTimeout: time.Duration(h.WriteTimeout),
logger: caddyCtx.Logger(h),
}
}
return conn, nil
},
DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
@ -260,6 +241,27 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
}
// if h2c is enabled, configure its transport (std lib http.Transport
// does not "HTTP/2 over cleartext TCP")
if sliceContains(h.Versions, "h2c") {
// crafting our own http2.Transport doesn't allow us to utilize
// most of the customizations/preferences on the http.Transport,
// because, for some reason, only http2.ConfigureTransport()
// is allowed to set the unexported field that refers to a base
// http.Transport config; oh well
h2t := &http2.Transport{
// kind of a hack, but for plaintext/H2C requests, pretend to dial TLS
DialTLSContext: func(ctx context.Context, network, address string, _ *tls.Config) (net.Conn, error) {
return dialContext(ctx, network, address)
},
AllowHTTP: true,
}
if h.Compression != nil {
h2t.DisableCompression = !*h.Compression
}
h.h2cTransport = h2t
}
return rt, nil
}