net/http: support non-tls.Conn TLS connections

Historically, x/net/http2 has supported working with user-provided
net.Conns that implement the ConnectionState method from *tls.Conn.

Howver, net/http has required that user-provided net.Conns be a
concrete *tls.Conn for it to observe the TLS state of the connection.

CL 440795 made the net/http Server set Request.TLS when the net.Conn
has a ConnectionState method, but did not affect protocol negotiation.
(So a non-*tls.Conn can never negotiate HTTP/2.)

CL 765360 made the net/http Server also able to negotiate HTTP/2
when using a non-*tls.Conn net.Conn with a ConnectionState method.
This was part of a change to support net/http/internal/http2 tests,
not an intentional change to functionality in net/http.

Finish off the job by making the net/http Transport negotiate HTTP/2
when using a non-*tls.Conn net.Conn with a ConnectionState method.

Document the new support.

Fixes #21753

Change-Id: Ia9854f90df3044fcc855fce105006a186a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/772961
Reviewed-by: Nicholas Husin <nsh@golang.org>
Reviewed-by: Nicholas Husin <husin@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: 王伟强 <wwqgtxx@gmail.com>
TryBot-Bypass: Damien Neil <dneil@google.com>
This commit is contained in:
Damien Neil 2026-03-24 09:41:49 -07:00 committed by Gopher Robot
parent 887f38afa9
commit 15b9fc2659
4 changed files with 120 additions and 20 deletions

View file

@ -0,0 +1,3 @@
[Transport] and [Server] support TLS ALPN protocol negotiation on
user-provided [net.Conn] connections which implement a
`ConnectionState() tls.ConnectionState` method.

View file

@ -1923,3 +1923,88 @@ func testEarlyHintsRequest(t *testing.T, mode testMode) {
t.Errorf("Read body %q; want Hello", body)
}
}
// TestClientServerTLSConnWrapper verifies that the Transport and Server can
// negotiate an HTTP/2 connection using a net.Conn that has a
// "ConnectionState() tls.ConnectionState" method but is not a *tls.Conn.
func TestClientServerTLSConnWrapper(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
protocols := &Protocols{}
protocols.SetHTTP1(true)
protocols.SetHTTP2(true)
li := fakeNetListen()
server := &Server{
Handler: HandlerFunc(func(w ResponseWriter, r *Request) {
if r.TLS == nil {
t.Fatal("server request has no TLS ConnectionState")
}
}),
Protocols: protocols,
}
defer server.Close()
go server.Serve(&testListener{
accept: func() (net.Conn, error) {
conn, err := li.Accept()
if err != nil {
return nil, err
}
return &testTLSConn{
Conn: conn,
state: tls.ConnectionState{
Version: tls.VersionTLS13,
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
NegotiatedProtocol: "h2",
},
}, nil
},
close: li.Close,
addr: li.Addr(),
})
tr := &Transport{
DialTLS: func(network, address string) (net.Conn, error) {
return &testTLSConn{
Conn: li.connect(),
state: tls.ConnectionState{
Version: tls.VersionTLS13,
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
NegotiatedProtocol: "h2",
},
}, nil
},
Protocols: protocols,
}
req, _ := NewRequest("GET", "https://example.tld", nil)
resp, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("response status %v, want 200", resp.StatusCode)
}
if resp.TLS == nil {
t.Fatal("server request has no TLS ConnectionState")
}
})
}
type testListener struct {
accept func() (net.Conn, error)
close func() error
addr net.Addr
}
func (li *testListener) Accept() (net.Conn, error) { return li.accept() }
func (li *testListener) Close() error { return li.close() }
func (li *testListener) Addr() net.Addr { return li.addr }
type testTLSConn struct {
net.Conn
state tls.ConnectionState
}
func (c *testTLSConn) Handshake() error { return nil }
func (c *testTLSConn) ConnectionState() tls.ConnectionState { return c.state }

View file

@ -1915,10 +1915,6 @@ func isCommonNetReadError(err error) bool {
return false
}
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
// Serve a new connection.
func (c *conn) serve(ctx context.Context) {
if ra := c.rwc.RemoteAddr(); ra != nil {
@ -1947,6 +1943,12 @@ func (c *conn) serve(ctx context.Context) {
}
}()
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
type handshakeContexter interface {
HandshakeContext(ctx context.Context) error
}
if connStater, ok := c.rwc.(connectionStater); ok {
tlsTO := c.server.tlsHandshakeTimeout()
if tlsTO > 0 {
@ -1954,9 +1956,6 @@ func (c *conn) serve(ctx context.Context) {
c.rwc.SetReadDeadline(dl)
c.rwc.SetWriteDeadline(dl)
}
type handshakeContexter interface {
HandshakeContext(ctx context.Context) error
}
var err error
if handshaker, ok := c.rwc.(handshakeContexter); ok {
err = handshaker.HandshakeContext(ctx)
@ -2989,8 +2988,9 @@ func (mux *ServeMux) registerErr(patstr string, handler Handler) error {
// The handler is typically nil, in which case [DefaultServeMux] is used.
//
// HTTP/2 support is only enabled if the Listener returns [*tls.Conn]
// connections and they were configured with "h2" in the TLS
// Config.NextProtos.
// connections or connections which implement the same ConnectionState
// method as *tls.Conn, and the connection state indicates that the "h2"
// protocol was negotiated by ALPN.
//
// Serve always returns a non-nil error.
func Serve(l net.Listener, handler Handler) error {

View file

@ -166,6 +166,9 @@ type Transport struct {
// requests and the TLSClientConfig and TLSHandshakeTimeout
// are ignored. The returned net.Conn is assumed to already be
// past the TLS handshake.
//
// To support ALPN protocol negotiation, the returned net.Conn should be
// a *tls.Conn or implement the same ConnectionState method as *tls.Conn.
DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
@ -1886,20 +1889,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod, isClientConn
if err != nil {
return nil, wrapErr(err)
}
if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below
// depends on it for knowing the connection state.
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
type handshaker interface {
HandshakeContext(context.Context) error
}
if cstater, ok := pconn.conn.(connectionStater); ok {
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
if err := tc.HandshakeContext(ctx); err != nil {
go pconn.conn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
if handshaker, ok := cstater.(handshaker); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below
// depends on it for knowing the connection state.
if err := handshaker.HandshakeContext(ctx); err != nil {
go pconn.conn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
return nil, err
}
return nil, err
}
cs := tc.ConnectionState()
cs := cstater.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
@ -2095,8 +2106,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod, isClientConn
}
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
alt := next(cm.targetAddr, pconn.conn.(*tls.Conn))
tlsConn, tlsConnOK := pconn.conn.(*tls.Conn)
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; tlsConnOK && ok {
alt := next(cm.targetAddr, tlsConn)
if e, ok := alt.(erringRoundTripper); ok {
// pconn.conn was closed by next (http2configureTransports.upgradeFn).
return nil, e.RoundTripErr()