mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
net/http/httptest: redirect example.com requests to server
The default server cert used by NewServer already includes example.com in its DNSNames, and by default, the client's RootCA configuration means it won't trust a response from the real example.com. Fixes #31054 Change-Id: I0686977e5ffe2c2f22f3fc09a47ee8ecc44765db Reviewed-on: https://go-review.googlesource.com/c/go/+/666855 Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Carlos Amedee <carlos@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
parent
d86ec92499
commit
f2db0dca0b
3 changed files with 71 additions and 2 deletions
2
doc/next/6-stdlib/99-minor/net/http/httptest/31054.md
Normal file
2
doc/next/6-stdlib/99-minor/net/http/httptest/31054.md
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
The HTTP client returned by [Server.Client] will now redirect requests for
|
||||||
|
`example.com` and any subdomains to the server being tested.
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
package httptest
|
package httptest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"flag"
|
"flag"
|
||||||
|
|
@ -126,8 +127,24 @@ func (s *Server) Start() {
|
||||||
if s.URL != "" {
|
if s.URL != "" {
|
||||||
panic("Server already started")
|
panic("Server already started")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.client == nil {
|
if s.client == nil {
|
||||||
s.client = &http.Client{Transport: &http.Transport{}}
|
tr := &http.Transport{}
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
// User code may set either of Dial or DialContext, with DialContext taking precedence.
|
||||||
|
// We set DialContext here to preserve any context values that are passed in,
|
||||||
|
// but fall back to Dial if the user has set it.
|
||||||
|
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if tr.Dial != nil {
|
||||||
|
return tr.Dial(network, addr)
|
||||||
|
}
|
||||||
|
if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") {
|
||||||
|
addr = s.Listener.Addr().String()
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
s.client = &http.Client{Transport: tr}
|
||||||
|
|
||||||
}
|
}
|
||||||
s.URL = "http://" + s.Listener.Addr().String()
|
s.URL = "http://" + s.Listener.Addr().String()
|
||||||
s.wrap()
|
s.wrap()
|
||||||
|
|
@ -173,12 +190,23 @@ func (s *Server) StartTLS() {
|
||||||
}
|
}
|
||||||
certpool := x509.NewCertPool()
|
certpool := x509.NewCertPool()
|
||||||
certpool.AddCert(s.certificate)
|
certpool.AddCert(s.certificate)
|
||||||
s.client.Transport = &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
RootCAs: certpool,
|
RootCAs: certpool,
|
||||||
},
|
},
|
||||||
ForceAttemptHTTP2: s.EnableHTTP2,
|
ForceAttemptHTTP2: s.EnableHTTP2,
|
||||||
}
|
}
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if tr.Dial != nil {
|
||||||
|
return tr.Dial(network, addr)
|
||||||
|
}
|
||||||
|
if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") {
|
||||||
|
addr = s.Listener.Addr().String()
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
s.client.Transport = tr
|
||||||
s.Listener = tls.NewListener(s.Listener, s.TLS)
|
s.Listener = tls.NewListener(s.Listener, s.TLS)
|
||||||
s.URL = "https://" + s.Listener.Addr().String()
|
s.URL = "https://" + s.Listener.Addr().String()
|
||||||
s.wrap()
|
s.wrap()
|
||||||
|
|
@ -300,6 +328,8 @@ func (s *Server) Certificate() *x509.Certificate {
|
||||||
// It is configured to trust the server's TLS test certificate and will
|
// It is configured to trust the server's TLS test certificate and will
|
||||||
// close its idle connections on [Server.Close].
|
// close its idle connections on [Server.Close].
|
||||||
// Use Server.URL as the base URL to send requests to the server.
|
// Use Server.URL as the base URL to send requests to the server.
|
||||||
|
// The returned client will also redirect any requests to "example.com"
|
||||||
|
// or its subdomains to the server.
|
||||||
func (s *Server) Client() *http.Client {
|
func (s *Server) Client() *http.Client {
|
||||||
return s.client
|
return s.client
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -293,3 +293,40 @@ func TestTLSServerWithHTTP2(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientExampleCom(t *testing.T) {
|
||||||
|
modes := []struct {
|
||||||
|
proto string
|
||||||
|
host string
|
||||||
|
}{
|
||||||
|
{"http", "example.com"},
|
||||||
|
{"http", "foo.example.com"},
|
||||||
|
{"https", "example.com"},
|
||||||
|
{"https", "foo.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range modes {
|
||||||
|
t.Run(tt.proto+" "+tt.host, func(t *testing.T) {
|
||||||
|
cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("requested-hostname", r.Host)
|
||||||
|
}))
|
||||||
|
switch tt.proto {
|
||||||
|
case "https":
|
||||||
|
cst.EnableHTTP2 = true
|
||||||
|
cst.StartTLS()
|
||||||
|
default:
|
||||||
|
cst.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cst.Close()
|
||||||
|
|
||||||
|
res, err := cst.Client().Get(tt.proto + "://" + tt.host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to make request: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := res.Header.Get("requested-hostname"), tt.host; got != want {
|
||||||
|
t.Fatalf("Requested hostname mismatch\ngot: %q\nwant: %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue