net: context aware Dialer.Dial functions

Add context aware dial functions for TCP, UDP, IP and Unix networks.

Fixes #49097
Updates #59897

Change-Id: I7523452e8e463a587a852e0555cec822d8dcb3dd
Reviewed-on: https://go-review.googlesource.com/c/go/+/490975
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
This commit is contained in:
Michael Fraenkel 2023-04-30 09:12:27 -06:00 committed by Sean Liao
parent 6abfe7b0de
commit 2b804abf07
8 changed files with 234 additions and 29 deletions

4
api/next/49097.txt Normal file
View file

@ -0,0 +1,4 @@
pkg net, method (*Dialer) DialIP(context.Context, string, netip.Addr, netip.Addr) (*IPConn, error) #49097
pkg net, method (*Dialer) DialTCP(context.Context, string, netip.AddrPort, netip.AddrPort) (*TCPConn, error) #49097
pkg net, method (*Dialer) DialUDP(context.Context, string, netip.AddrPort, netip.AddrPort) (*UDPConn, error) #49097
pkg net, method (*Dialer) DialUnix(context.Context, string, *UnixAddr, *UnixAddr) (*UnixConn, error) #49097

View file

@ -0,0 +1 @@
Added context aware dial functions for TCP, UDP, IP and Unix networks.

View file

@ -9,6 +9,7 @@ import (
"internal/bytealg"
"internal/godebug"
"internal/nettrace"
"net/netip"
"syscall"
"time"
)
@ -523,30 +524,8 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
// See func [Dial] for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
if ctx == nil {
panic("nil context")
}
deadline := d.deadline(ctx, time.Now())
if !deadline.IsZero() {
testHookStepTime()
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
subCtx, cancel := context.WithDeadline(ctx, deadline)
defer cancel()
ctx = subCtx
}
}
if oldCancel := d.Cancel; oldCancel != nil {
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-oldCancel:
cancel()
case <-subCtx.Done():
}
}()
ctx = subCtx
}
ctx, cancel := d.dialCtx(ctx)
defer cancel()
// Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
resolveCtx := ctx
@ -578,6 +557,97 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
return sd.dialParallel(ctx, primaries, fallbacks)
}
func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
panic("nil context")
}
deadline := d.deadline(ctx, time.Now())
var cancel1, cancel2 context.CancelFunc
if !deadline.IsZero() {
testHookStepTime()
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
var subCtx context.Context
subCtx, cancel1 = context.WithDeadline(ctx, deadline)
ctx = subCtx
}
}
if oldCancel := d.Cancel; oldCancel != nil {
subCtx, cancel2 := context.WithCancel(ctx)
go func() {
select {
case <-oldCancel:
cancel2()
case <-subCtx.Done():
}
}()
ctx = subCtx
}
return ctx, func() {
if cancel1 != nil {
cancel1()
}
if cancel2 != nil {
cancel2()
}
}
}
// DialTCP acts like Dial for TCP networks using the provided context.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The network must be a TCP network name; see func Dial for details.
func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
ctx, cancel := d.dialCtx(ctx)
defer cancel()
return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
}
// DialUDP acts like Dial for UDP networks using the provided context.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The network must be a UDP network name; see func Dial for details.
func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
ctx, cancel := d.dialCtx(ctx)
defer cancel()
return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
}
// DialIP acts like Dial for IP networks using the provided context.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The network must be an IP network name; see func Dial for details.
func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
ctx, cancel := d.dialCtx(ctx)
defer cancel()
return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
}
// DialUnix acts like Dial for Unix networks using the provided context.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The network must be a Unix network name; see func Dial for details.
func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
ctx, cancel := d.dialCtx(ctx)
defer cancel()
return dialUnix(ctx, d, network, laddr, raddr)
}
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first

View file

@ -11,6 +11,7 @@ import (
"fmt"
"internal/testenv"
"io"
"net/netip"
"os"
"runtime"
"strings"
@ -1064,6 +1065,99 @@ func TestDialerControlContext(t *testing.T) {
})
}
func TestDialContext(t *testing.T) {
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
case "js", "wasip1":
t.Skipf("skipping: fake net does not support Dialer.ControlContext")
}
t.Run("StreamDial", func(t *testing.T) {
var err error
for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
if !testableNetwork(network) {
continue
}
ln := newLocalListener(t, network)
defer ln.Close()
var id int
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
id = ctx.Value("id").(int)
return controlOnConnSetup(network, address, c)
}}
var c Conn
switch network {
case "tcp", "tcp4", "tcp6":
raddr, err := netip.ParseAddrPort(ln.Addr().String())
if err != nil {
t.Error(err)
continue
}
c, err = d.DialTCP(context.WithValue(context.Background(), "id", i+1), network, (*TCPAddr)(nil).AddrPort(), raddr)
case "unix", "unixpacket":
raddr, err := ResolveUnixAddr(network, ln.Addr().String())
if err != nil {
t.Error(err)
continue
}
c, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
}
if err != nil {
t.Error(err)
continue
}
if id != i+1 {
t.Errorf("%s: got id %d, want %d", network, id, i+1)
}
c.Close()
}
})
t.Run("PacketDial", func(t *testing.T) {
var err error
for i, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
if !testableNetwork(network) {
continue
}
c1 := newLocalPacketListener(t, network)
if network == "unixgram" {
defer os.Remove(c1.LocalAddr().String())
}
defer c1.Close()
var id int
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
id = ctx.Value("id").(int)
return controlOnConnSetup(network, address, c)
}}
var c2 Conn
switch network {
case "udp", "udp4", "udp6":
raddr, err := netip.ParseAddrPort(c1.LocalAddr().String())
if err != nil {
t.Error(err)
continue
}
c2, err = d.DialUDP(context.WithValue(context.Background(), "id", i+1), network, (*UDPAddr)(nil).AddrPort(), raddr)
case "unixgram":
raddr, err := ResolveUnixAddr(network, c1.LocalAddr().String())
if err != nil {
t.Error(err)
continue
}
c2, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
}
if err != nil {
t.Error(err)
continue
}
if id != i+1 {
t.Errorf("%s: got id %d, want %d", network, id, i+1)
}
c2.Close()
}
})
}
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
// except on non-Linux, non-mobile builders it permits the test to
// run in -short mode.

View file

@ -6,6 +6,7 @@ package net
import (
"context"
"net/netip"
"syscall"
)
@ -24,6 +25,13 @@ import (
// BUG(mikio): On JS and Plan 9, methods and functions related
// to IPConn are not implemented.
func ipAddrFromAddr(addr netip.Addr) *IPAddr {
return &IPAddr{
IP: addr.AsSlice(),
Zone: addr.Zone(),
}
}
// IPAddr represents the address of an IP end point.
type IPAddr struct {
IP IP
@ -206,11 +214,18 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
return dialIP(context.Background(), nil, network, laddr, raddr)
}
func dialIP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *IPAddr) (*IPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialIP(context.Background(), laddr, raddr)
if dialer != nil {
sd.Dialer = *dialer
}
c, err := sd.dialIP(ctx, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}

View file

@ -315,6 +315,10 @@ func newTCPConn(fd *netFD, keepAliveIdle time.Duration, keepAliveCfg KeepAliveCo
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
return dialTCP(context.Background(), nil, network, laddr, raddr)
}
func dialTCP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
@ -328,10 +332,13 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
c *TCPConn
err error
)
if dialer != nil {
sd.Dialer = *dialer
}
if sd.MultipathTCP() {
c, err = sd.dialMPTCP(context.Background(), laddr, raddr)
c, err = sd.dialMPTCP(ctx, laddr, raddr)
} else {
c, err = sd.dialTCP(context.Background(), laddr, raddr)
c, err = sd.dialTCP(ctx, laddr, raddr)
}
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}

View file

@ -285,6 +285,10 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
return dialUDP(context.Background(), nil, network, laddr, raddr)
}
func dialUDP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
switch network {
case "udp", "udp4", "udp6":
default:
@ -294,7 +298,10 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialUDP(context.Background(), laddr, raddr)
if dialer != nil {
sd.Dialer = *dialer
}
c, err := sd.dialUDP(ctx, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}

View file

@ -201,13 +201,20 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
// If laddr is non-nil, it is used as the local address for the
// connection.
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
return dialUnix(context.Background(), nil, network, laddr, raddr)
}
func dialUnix(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
switch network {
case "unix", "unixgram", "unixpacket":
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialUnix(context.Background(), laddr, raddr)
if dialer != nil {
sd.Dialer = *dialer
}
c, err := sd.dialUnix(ctx, laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}