net/http/httptest: add NewTestServer with in-memory network

NewTestServer takes a *testing.T. It uses it to register a
Cleanup function to shut down the server at the end of a test,
and to fail the test when a server handler panics.

NewTestServer uses an in-memory fake network by default,
suitable for use with the testing/synctest package.
The user can choose to use the loopback network instead
by calling Start or StartTLS.

For #76608

Change-Id: Ic99e304f1a809a4727aaa9f7316f60ec6a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/769521
Reviewed-by: Nicholas Husin <husin@google.com>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicholas Husin <nsh@golang.org>
Auto-Submit: Damien Neil <dneil@google.com>
This commit is contained in:
Damien Neil 2026-04-18 17:11:02 -04:00 committed by Gopher Robot
parent a871fd3732
commit 813b317cc9
5 changed files with 404 additions and 59 deletions

1
api/next/76608.txt Normal file
View file

@ -0,0 +1 @@
pkg net/http/httptest, func NewTestServer(testing.TB, http.Handler) *Server #76608

View file

@ -0,0 +1,2 @@
[NewTestServer] creates a [Server] configured to use an in-memory
fake network suitable for use with the [testing/synctest] package.

View file

@ -685,12 +685,6 @@ var depsRules = `
net/http, net/http/internal/ascii
< net/http/cookiejar, net/http/httputil;
NET, internal/gate
< internal/nettest;
net/http, flag
< net/http/httptest;
net/http, regexp
< net/http/cgi
< net/http/fcgi;
@ -735,6 +729,12 @@ var depsRules = `
testing, crypto/rand
< testing/cryptotest;
NET, internal/gate
< internal/nettest;
net/http, flag, internal/nettest, testing
< net/http/httptest;
FMT, crypto/sha256, encoding/binary, encoding/json,
go/ast, go/parser, go/token,
internal/godebug, math/rand, encoding/hex

View file

@ -12,26 +12,123 @@ import (
"crypto/x509"
"flag"
"fmt"
"internal/nettest"
"log"
"net"
"net/http"
"net/http/internal/testcert"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
_ "unsafe" // for linkname
)
// A Server is an HTTP server listening on a system-chosen port on the
// local loopback interface, for use in end-to-end HTTP tests.
// A Server is an HTTP server for use in end-to-end HTTP tests.
//
// Most tests should create a server with [NewTestServer].
// The [Server.Client] method returns a client which sends requests to the test server.
//
// // Create a test server and send a request to it.
// server := httptest.NewTestServer(t, handler)
// resp, err := server.Client().Get("http://www.example.com/")
//
// # Configuration
//
// Tests may change a Server's configuration prior to using it.
// The configuration must not be changed after the first call to
// [Server.Client], [Server.Start], or [Server.StartTLS].
//
// // Configure a test server before using.
// server := httptest.NewTestServer(t, handler)
// server.Config.MaxHeaderBytes = 1024
// resp, err := server.Client().Get("http://www.example.com/")
//
// # Tests
//
// Servers created with [NewTestServer] will:
//
// - Fail the test if the server handler panics with
// any value other than [http.ErrAbortHandler].
// - Register a Cleanup function to shut down the server at the end of the test.
//
// Servers created in any other way must be manually shut down with [Server.Close].
//
// # In-Memory Network
//
// A Server may use an in-memory network implementation or
// listen on a local network loopback interface.
// Most tests should use the in-memory network,
// which avoids port exhaustion and other transient networking issues
// and is suitable for use with the [testing/synctest] package.
//
// To use the in-memory network, create a server with [NewTestServer].
// Do not call [Server.Start] or [Server.StartTLS].
//
// When using the in-memory network, the [http.Client] returned by [Server.Client]
// is configured to send all requests to the server.
// The client will direct HTTP and HTTPS requests,
// regardless of destination address or hostname, to the server.
// Requests do not need to use [Server.URL] as the base URL.
//
// server := httptest.NewTestServer(t, handler)
// client := server.Client()
//
// // All of these requests are sent to the test server.
// // https:// requests use TLS over the in-memory network.
// _, _ = client.Get("http://www.example.com/")
// _, _ = client.Get("https://go.dev/")
// _, _ = client.Get("http://10.0.0.1/")
//
// The [Server.Listener] field is not set when using the in-memory network.
//
// # Loopback Network
//
// To listen on a loopback interface, call [Server.Start] or [Server.StartTLS].
// The server will listen on a system-chosen port.
//
// Loopback servers serve one of HTTP (when started with [Server.Start])
// or HTTPS (when started with [Server.StartTLS]).
//
// When using the loopback network, the [http.Client] returned by [Server.Client]
// is configured to send requests with a hostname of "example.com" or a subdomain
// of ".example.com" to the server.
//
// Requests may also be sent to the server's loopback address.
// The [Server.URL] field is set to a base URL containing the server's address.
//
// server := httptest.NewTestServer(t, handler)
// server.Start()
// client := server.Client()
//
// // This request is sent to the test server.
// _, _ = server.Client().Get(server.URL + "/")
//
// // This request (using http.DefaultClient) is also sent to the test server,
// // since server.URL contains the server's local IP address.
// _, _ = http.Get(server.URL + "/")
type Server struct {
URL string // base URL of form http://ipaddr:port with no trailing slash
// URL is the base URL of the server, of the form http://address:port
// with no trailing slash.
//
// It is set by the first call to Client, Start, or StartTLS.
//
// For servers listening on loopback, the address is the loopback IP address
// of the server.
//
// For servers using the in-memory network, this address is "example.com".
// Requests sent to servers using the in-memory network may use any address.
// It is not necessary to use this base URL.
URL string
// Listener is the network listener for servers listening on loopback.
// It is not set for servers using the in-memory network.
Listener net.Listener
// EnableHTTP2 controls whether HTTP/2 is enabled
// on the server. It must be set between calling
// NewUnstartedServer and calling Server.StartTLS.
// EnableHTTP2 controls whether HTTP/2 is enabled on the server.
// It must be set before calling Client, Start, or StartTLS.
EnableHTTP2 bool
// TLS is the optional TLS configuration, populated with a new config
@ -39,16 +136,24 @@ type Server struct {
// is called, existing fields are copied into the new config.
TLS *tls.Config
// Config may be changed after calling NewUnstartedServer and
// before Start or StartTLS.
// Config may be changed before calling Client, Start, or StartTLS.
Config *http.Server
t testing.TB
// certificate is a parsed version of the TLS config certificate, if present.
certificate *x509.Certificate
// startOnce is used to start fakenet servers once.
startOnce sync.Once
// started indicates whether the server has been started.
started bool
// Fake network listeners, one for HTTP and one for HTTPS.
fakeListener *nettest.Listener
fakeTLSListener *nettest.Listener
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
@ -62,6 +167,51 @@ type Server struct {
client *http.Client
}
// NewTestServer returns a new [Server] for a test.
// The server will use an in-memory network implementation by default.
//
// If the handler is nil, the server will serve 500 responses to all requests.
// It will not use [http.DefaultServeMux].
//
// See the [Server] documentation for more details.
func NewTestServer(t testing.TB, handler http.Handler) *Server {
s := &Server{
t: t,
Config: &http.Server{Handler: testServerHandler{t: t, h: handler}},
}
t.Cleanup(func() {
s.Close()
})
return s
}
type testServerHandler struct {
t testing.TB
h http.Handler
}
func (h testServerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer func() {
if err := recover(); err != nil {
if err != http.ErrAbortHandler {
// This is the same logging http.Server would do,
// but we can put it into the test output rather than stderr.
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
h.t.Errorf("httptest: panic in server handler: %v\n%s", err, buf)
}
// Convert panic to ErrAbortHandler to suppress http.Server's logging.
panic(http.ErrAbortHandler)
}
}()
if h.h != nil {
h.h.ServeHTTP(w, req)
} else {
w.WriteHeader(500)
}
}
func newLocalListener() net.Listener {
if serveFlag != "" {
l, err := net.Listen("tcp", serveFlag)
@ -105,20 +255,30 @@ func strSliceContainsPrefix(v []string, pre string) bool {
return false
}
// NewServer starts and returns a new [Server].
// The caller should call Close when finished, to shut it down.
// NewServer starts and returns a new [Server] listening on a
// local network loopback interface.
// This is equivalent to calling [NewUnstartedServer] followed by [Server.Start].
//
// The caller should call [Server.Close] when finished, to shut it down.
//
// Most users should use [NewTestServer] instead.
// See the [Server] documentation for details.
func NewServer(handler http.Handler) *Server {
ts := NewUnstartedServer(handler)
ts.Start()
return ts
}
// NewUnstartedServer returns a new [Server] but doesn't start it.
// NewUnstartedServer returns a new [Server] listening on a
// local network loopback interface. It does not start the server.
//
// After changing its configuration, the caller should call Start or
// StartTLS.
// After changing the server's configuration, the caller should
// call [Server.Start] or [Server.StartTLS].
//
// The caller should call Close when finished, to shut it down.
// The caller should call [Server.Close] when finished, to shut it down.
//
// Most users should use [NewTestServer] instead.
// See the [Server] documentation for details.
func NewUnstartedServer(handler http.Handler) *Server {
return &Server{
Listener: newLocalListener(),
@ -126,13 +286,38 @@ func NewUnstartedServer(handler http.Handler) *Server {
}
}
// Start starts a server from NewUnstartedServer.
func (s *Server) Start() {
func (s *Server) startCommon(useLoopback bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.started {
panic("Server already started")
}
if s.closed {
panic("Start of closed Server")
}
s.started = true
if s.t != nil && useLoopback {
// We're being called from Start or StartTLS.
// Don't try to start the server again when Client is called.
s.startOnce.Do(func() {})
// NewTestServer servers create their listener at start time.
//
// We might want to permit the user to provide their own Listener
// in the future. For now, we panic.
if s.Listener != nil {
panic("Server.Listener is unexpectedly set")
}
s.Listener = newLocalListener()
}
s.wrap()
}
// Start starts a server on a local loopback network interface.
//
// The server should have been created by [NewTestServer] or [NewUnstartedServer].
func (s *Server) Start() {
s.startCommon(true)
tr := &http.Transport{}
s.client = &http.Client{Transport: tr}
@ -153,25 +338,17 @@ func (s *Server) Start() {
return dialer.DialContext(ctx, network, addr)
}
s.URL = "http://" + s.Listener.Addr().String()
s.goServe()
s.goServe(s.Listener)
if serveFlag != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {}
}
}
// StartTLS starts TLS on a server from NewUnstartedServer.
func (s *Server) StartTLS() {
if s.started {
panic("Server already started")
}
s.started = true
s.wrap()
s.client = &http.Client{}
func (s *Server) initTLS() (tlsClientConfig *tls.Config, err error) {
cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
return nil, err
}
existingConfig := s.TLS
@ -192,15 +369,30 @@ func (s *Server) StartTLS() {
}
s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
return nil, err
}
certpool := x509.NewCertPool()
certpool.AddCert(s.certificate)
return &tls.Config{
RootCAs: certpool,
}, nil
}
// Start starts TLS on a server on a local loopback network interface.
//
// The server should have been created by [NewTestServer] or [NewUnstartedServer].
func (s *Server) StartTLS() {
s.startCommon(true)
s.client = &http.Client{}
tlsClientConfig, err := s.initTLS()
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
TLSClientConfig: tlsClientConfig,
ForceAttemptHTTP2: s.EnableHTTP2,
}
s.client.Transport = tr
@ -220,11 +412,49 @@ func (s *Server) StartTLS() {
}
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.goServe()
s.goServe(s.Listener)
}
// NewTLSServer starts and returns a new [Server] using TLS.
// The caller should call Close when finished, to shut it down.
func (s *Server) startFakeNet() {
s.startCommon(false)
s.client = &http.Client{}
tlsClientConfig, err := s.initTLS()
if err != nil {
panic(fmt.Sprintf("httptest: NewTestServer: %v", err))
}
tr := &http.Transport{
TLSClientConfig: tlsClientConfig,
ForceAttemptHTTP2: s.EnableHTTP2,
}
s.client.Transport = tr
s.fakeListener = nettest.NewListener()
s.fakeTLSListener = nettest.NewListener()
// Set InsecureSkipVerify rather than depending on a specific server hostname.
tr.TLSClientConfig.InsecureSkipVerify = true
tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return s.fakeListener.NewConn(), nil
}
tr.DialTLSContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return tls.Client(s.fakeTLSListener.NewConn(), tr.TLSClientConfig), nil
}
s.URL = "http://example.com"
s.goServe(s.fakeListener)
s.goServe(tls.NewListener(s.fakeTLSListener, s.TLS))
}
// NewTLSServer starts and returns a new [Server] using TLS and listening on a
// local network loopback interface.
// This is equivalent to calling [NewUnstartedServer] followed by [Server.StartTLS].
//
// The caller should call [Server.Close] when finished, to shut it down.
//
// Most users should use [NewTestServer] instead.
// See the [Server] documentation for details.
func NewTLSServer(handler http.Handler) *Server {
ts := NewUnstartedServer(handler)
ts.StartTLS()
@ -244,6 +474,10 @@ func (s *Server) Close() {
if s.Listener != nil {
s.Listener.Close()
}
if s.fakeListener != nil {
s.fakeListener.Close()
s.fakeTLSListener.Close()
}
s.Config.SetKeepAlivesEnabled(false)
for c, st := range s.conns {
// Force-close any idle connections (those between
@ -338,18 +572,18 @@ func (s *Server) Certificate() *x509.Certificate {
// Client returns an HTTP client configured for making requests to the server.
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on [Server.Close].
// Use Server.URL as the base URL to send requests to the server.
// The returned client will also redirect any requests to "example.com"
// or its subdomains to the server.
func (s *Server) Client() *http.Client {
if s.t != nil {
s.startOnce.Do(s.startFakeNet)
}
return s.client
}
func (s *Server) goServe() {
func (s *Server) goServe(li net.Listener) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.Config.Serve(s.Listener)
s.Config.Serve(li)
}()
}

View file

@ -6,35 +6,58 @@ package httptest
import (
"bufio"
"internal/testenv"
"io"
"net"
"net/http"
"os"
"regexp"
"strings"
"sync"
"testing"
"testing/synctest"
)
type newServerFunc func(http.Handler) *Server
type newServerFunc func(*testing.T, http.Handler) *Server
var newServers = map[string]newServerFunc{
"NewServer": NewServer,
"NewTLSServer": NewTLSServer,
"NewServer": func(t *testing.T, h http.Handler) *Server {
return NewServer(h)
},
"NewTLSServer": func(t *testing.T, h http.Handler) *Server {
return NewTLSServer(h)
},
// The manual variants of newServer create a Server manually by only filling
// in the exported fields of Server.
"NewServerManual": func(h http.Handler) *Server {
"NewServerManual": func(t *testing.T, h http.Handler) *Server {
ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
ts.Start()
return ts
},
"NewTLSServerManual": func(h http.Handler) *Server {
"NewTLSServerManual": func(t *testing.T, h http.Handler) *Server {
ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
ts.StartTLS()
return ts
},
"NewTestServerMemory": func(t *testing.T, h http.Handler) *Server {
return NewTestServer(t, h)
},
"NewTestServerLoopback": func(t *testing.T, h http.Handler) *Server {
ts := NewTestServer(t, h)
ts.Start()
return ts
},
"NewTestServerLoopbackTLS": func(t *testing.T, h http.Handler) *Server {
ts := NewTestServer(t, h)
ts.StartTLS()
return ts
},
}
func TestServer(t *testing.T) {
for _, name := range []string{"NewServer", "NewServerManual"} {
for _, name := range []string{"NewServer", "NewServerManual", "NewTestServerLoopback"} {
t.Run(name, func(t *testing.T) {
newServer := newServers[name]
t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
@ -44,7 +67,7 @@ func TestServer(t *testing.T) {
t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
})
}
for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
for _, name := range []string{"NewTLSServer", "NewTLSServerManual", "NewTestServerMemory", "NewTestServerLoopbackTLS"} {
t.Run(name, func(t *testing.T) {
newServer := newServers[name]
t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
@ -54,7 +77,7 @@ func TestServer(t *testing.T) {
}
func testServer(t *testing.T, newServer newServerFunc) {
ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
defer ts.Close()
@ -74,7 +97,7 @@ func testServer(t *testing.T, newServer newServerFunc) {
// Issue 12781
func testGetAfterClose(t *testing.T, newServer newServerFunc) {
ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
@ -101,7 +124,7 @@ func testGetAfterClose(t *testing.T, newServer newServerFunc) {
}
func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
dial := func() net.Conn {
@ -131,7 +154,7 @@ func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
// Issue 14290
func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
var s *Server
s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s = newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.CloseClientConnections()
}))
defer s.Close()
@ -145,7 +168,7 @@ func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
// Tests that the Server.Client method works and returns an http.Client that can hit
// NewTLSServer without cert warnings.
func testServerClient(t *testing.T, newTLSServer newServerFunc) {
ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := newTLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
defer ts.Close()
@ -167,7 +190,7 @@ func testServerClient(t *testing.T, newTLSServer newServerFunc) {
// Tests that the Server.Client.Transport interface is implemented
// by a *http.Transport.
func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
client := ts.Client()
@ -179,7 +202,7 @@ func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
// Tests that the TLS Server.Client.Transport interface is implemented
// by a *http.Transport.
func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := newTLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
client := ts.Client()
@ -330,3 +353,88 @@ func TestClientExampleCom(t *testing.T) {
})
}
}
func TestServerInMemoryNetwork(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
ts := NewTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
for _, u := range []string{
"http://example.tld/",
"https://example.tld/",
"https://go.dev/",
"http://127.0.0.1/",
"http://[::1]/",
"https://127.0.0.1/",
} {
resp, err := ts.Client().Get(u)
if err != nil {
t.Errorf("Get(%q): %v", u, err)
continue
}
resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("Get(%q): Response.StatusCode = %v, want 200", u, resp.StatusCode)
}
if gotTLS, wantTLS := resp.TLS != nil, strings.HasPrefix(u, "https://"); gotTLS != wantTLS {
t.Errorf("Get(%q): TLS: %v; want %v", u, gotTLS, wantTLS)
}
}
})
}
func TestServerNilHandler(t *testing.T) {
ts := NewTestServer(t, nil)
resp, err := ts.Client().Get("http://example.tld/")
if err != nil {
t.Fatalf("Get: %v", err)
}
resp.Body.Close()
if got, want := resp.StatusCode, 500; got != want {
t.Errorf("Response.StatusCode = %v, want %v", got, want)
}
}
func TestServerPanicErrAbortHandler(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
ts := NewTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(http.ErrAbortHandler)
}))
resp, err := ts.Client().Get("http://example.com/")
if err == nil {
resp.Body.Close()
t.Errorf("request succeeded; want failure")
}
})
}
func TestServerPanicFailsTest(t *testing.T) {
runTest(t, func() {
synctest.Test(t, func(t *testing.T) {
ts := NewTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("PANIC MESSAGE")
}))
ts.Client().Get("http://example.com/")
})
}, `--- FAIL: TestServerPanicFailsTest.*
.*: httptest: panic in server handler: PANIC MESSAGE
`)
}
func runTest(t *testing.T, f func(), pattern string) {
if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" {
f()
return
}
t.Helper()
re := regexp.MustCompile(pattern)
testenv.MustHaveExec(t)
cmd := testenv.Command(t, testenv.Executable(t), "-test.run=^"+regexp.QuoteMeta(t.Name())+"$", "-test.count=1")
cmd = testenv.CleanCmdEnv(cmd)
cmd.Env = append(cmd.Env, "GO_WANT_HELPER_PROCESS=1")
out, _ := cmd.CombinedOutput()
if !re.Match(out) {
t.Errorf("got output:\n%s\nwant matching:\n%s", out, pattern)
}
}