mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
net: context aware Dialer.Dial functions
Add context aware dial functions for TCP, UDP, IP and Unix networks. Fixes #49097 Updates #59897 Change-Id: I7523452e8e463a587a852e0555cec822d8dcb3dd Reviewed-on: https://go-review.googlesource.com/c/go/+/490975 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Reviewed-by: David Chase <drchase@google.com> Reviewed-by: Sean Liao <sean@liao.dev>
This commit is contained in:
parent
6abfe7b0de
commit
2b804abf07
8 changed files with 234 additions and 29 deletions
4
api/next/49097.txt
Normal file
4
api/next/49097.txt
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pkg net, method (*Dialer) DialIP(context.Context, string, netip.Addr, netip.Addr) (*IPConn, error) #49097
|
||||
pkg net, method (*Dialer) DialTCP(context.Context, string, netip.AddrPort, netip.AddrPort) (*TCPConn, error) #49097
|
||||
pkg net, method (*Dialer) DialUDP(context.Context, string, netip.AddrPort, netip.AddrPort) (*UDPConn, error) #49097
|
||||
pkg net, method (*Dialer) DialUnix(context.Context, string, *UnixAddr, *UnixAddr) (*UnixConn, error) #49097
|
||||
1
doc/next/6-stdlib/99-minor/net/49097.md
Normal file
1
doc/next/6-stdlib/99-minor/net/49097.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
Added context aware dial functions for TCP, UDP, IP and Unix networks.
|
||||
118
src/net/dial.go
118
src/net/dial.go
|
|
@ -9,6 +9,7 @@ import (
|
|||
"internal/bytealg"
|
||||
"internal/godebug"
|
||||
"internal/nettrace"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -523,30 +524,8 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
|
|||
// See func [Dial] for a description of the network and address
|
||||
// parameters.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
deadline := d.deadline(ctx, time.Now())
|
||||
if !deadline.IsZero() {
|
||||
testHookStepTime()
|
||||
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
|
||||
subCtx, cancel := context.WithDeadline(ctx, deadline)
|
||||
defer cancel()
|
||||
ctx = subCtx
|
||||
}
|
||||
}
|
||||
if oldCancel := d.Cancel; oldCancel != nil {
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
go func() {
|
||||
select {
|
||||
case <-oldCancel:
|
||||
cancel()
|
||||
case <-subCtx.Done():
|
||||
}
|
||||
}()
|
||||
ctx = subCtx
|
||||
}
|
||||
ctx, cancel := d.dialCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
|
||||
resolveCtx := ctx
|
||||
|
|
@ -578,6 +557,97 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
|
|||
return sd.dialParallel(ctx, primaries, fallbacks)
|
||||
}
|
||||
|
||||
func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
deadline := d.deadline(ctx, time.Now())
|
||||
var cancel1, cancel2 context.CancelFunc
|
||||
if !deadline.IsZero() {
|
||||
testHookStepTime()
|
||||
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
|
||||
var subCtx context.Context
|
||||
subCtx, cancel1 = context.WithDeadline(ctx, deadline)
|
||||
ctx = subCtx
|
||||
}
|
||||
}
|
||||
if oldCancel := d.Cancel; oldCancel != nil {
|
||||
subCtx, cancel2 := context.WithCancel(ctx)
|
||||
go func() {
|
||||
select {
|
||||
case <-oldCancel:
|
||||
cancel2()
|
||||
case <-subCtx.Done():
|
||||
}
|
||||
}()
|
||||
ctx = subCtx
|
||||
}
|
||||
return ctx, func() {
|
||||
if cancel1 != nil {
|
||||
cancel1()
|
||||
}
|
||||
if cancel2 != nil {
|
||||
cancel2()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DialTCP acts like Dial for TCP networks using the provided context.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
//
|
||||
// The network must be a TCP network name; see func Dial for details.
|
||||
func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
|
||||
ctx, cancel := d.dialCtx(ctx)
|
||||
defer cancel()
|
||||
return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
|
||||
}
|
||||
|
||||
// DialUDP acts like Dial for UDP networks using the provided context.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
//
|
||||
// The network must be a UDP network name; see func Dial for details.
|
||||
func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
|
||||
ctx, cancel := d.dialCtx(ctx)
|
||||
defer cancel()
|
||||
return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
|
||||
}
|
||||
|
||||
// DialIP acts like Dial for IP networks using the provided context.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
//
|
||||
// The network must be an IP network name; see func Dial for details.
|
||||
func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
|
||||
ctx, cancel := d.dialCtx(ctx)
|
||||
defer cancel()
|
||||
return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
|
||||
}
|
||||
|
||||
// DialUnix acts like Dial for Unix networks using the provided context.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
//
|
||||
// The network must be a Unix network name; see func Dial for details.
|
||||
func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
|
||||
ctx, cancel := d.dialCtx(ctx)
|
||||
defer cancel()
|
||||
return dialUnix(ctx, d, network, laddr, raddr)
|
||||
}
|
||||
|
||||
// dialParallel races two copies of dialSerial, giving the first a
|
||||
// head start. It returns the first established connection and
|
||||
// closes the others. Otherwise it returns an error from the first
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"fmt"
|
||||
"internal/testenv"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
|
@ -1064,6 +1065,99 @@ func TestDialerControlContext(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDialContext(t *testing.T) {
|
||||
switch runtime.GOOS {
|
||||
case "plan9":
|
||||
t.Skipf("not supported on %s", runtime.GOOS)
|
||||
case "js", "wasip1":
|
||||
t.Skipf("skipping: fake net does not support Dialer.ControlContext")
|
||||
}
|
||||
|
||||
t.Run("StreamDial", func(t *testing.T) {
|
||||
var err error
|
||||
for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
|
||||
if !testableNetwork(network) {
|
||||
continue
|
||||
}
|
||||
ln := newLocalListener(t, network)
|
||||
defer ln.Close()
|
||||
var id int
|
||||
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
|
||||
id = ctx.Value("id").(int)
|
||||
return controlOnConnSetup(network, address, c)
|
||||
}}
|
||||
var c Conn
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
raddr, err := netip.ParseAddrPort(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
c, err = d.DialTCP(context.WithValue(context.Background(), "id", i+1), network, (*TCPAddr)(nil).AddrPort(), raddr)
|
||||
case "unix", "unixpacket":
|
||||
raddr, err := ResolveUnixAddr(network, ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
c, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
if id != i+1 {
|
||||
t.Errorf("%s: got id %d, want %d", network, id, i+1)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
})
|
||||
t.Run("PacketDial", func(t *testing.T) {
|
||||
var err error
|
||||
for i, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
|
||||
if !testableNetwork(network) {
|
||||
continue
|
||||
}
|
||||
c1 := newLocalPacketListener(t, network)
|
||||
if network == "unixgram" {
|
||||
defer os.Remove(c1.LocalAddr().String())
|
||||
}
|
||||
defer c1.Close()
|
||||
var id int
|
||||
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
|
||||
id = ctx.Value("id").(int)
|
||||
return controlOnConnSetup(network, address, c)
|
||||
}}
|
||||
var c2 Conn
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
raddr, err := netip.ParseAddrPort(c1.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
c2, err = d.DialUDP(context.WithValue(context.Background(), "id", i+1), network, (*UDPAddr)(nil).AddrPort(), raddr)
|
||||
case "unixgram":
|
||||
raddr, err := ResolveUnixAddr(network, c1.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
c2, err = d.DialUnix(context.WithValue(context.Background(), "id", i+1), network, nil, raddr)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
if id != i+1 {
|
||||
t.Errorf("%s: got id %d, want %d", network, id, i+1)
|
||||
}
|
||||
c2.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
|
||||
// except on non-Linux, non-mobile builders it permits the test to
|
||||
// run in -short mode.
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package net
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
|
|
@ -24,6 +25,13 @@ import (
|
|||
// BUG(mikio): On JS and Plan 9, methods and functions related
|
||||
// to IPConn are not implemented.
|
||||
|
||||
func ipAddrFromAddr(addr netip.Addr) *IPAddr {
|
||||
return &IPAddr{
|
||||
IP: addr.AsSlice(),
|
||||
Zone: addr.Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
// IPAddr represents the address of an IP end point.
|
||||
type IPAddr struct {
|
||||
IP IP
|
||||
|
|
@ -206,11 +214,18 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
|
|||
// If the IP field of raddr is nil or an unspecified IP address, the
|
||||
// local system is assumed.
|
||||
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
|
||||
return dialIP(context.Background(), nil, network, laddr, raddr)
|
||||
}
|
||||
|
||||
func dialIP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *IPAddr) (*IPConn, error) {
|
||||
if raddr == nil {
|
||||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
|
||||
}
|
||||
sd := &sysDialer{network: network, address: raddr.String()}
|
||||
c, err := sd.dialIP(context.Background(), laddr, raddr)
|
||||
if dialer != nil {
|
||||
sd.Dialer = *dialer
|
||||
}
|
||||
c, err := sd.dialIP(ctx, laddr, raddr)
|
||||
if err != nil {
|
||||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -315,6 +315,10 @@ func newTCPConn(fd *netFD, keepAliveIdle time.Duration, keepAliveCfg KeepAliveCo
|
|||
// If the IP field of raddr is nil or an unspecified IP address, the
|
||||
// local system is assumed.
|
||||
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||
return dialTCP(context.Background(), nil, network, laddr, raddr)
|
||||
}
|
||||
|
||||
func dialTCP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
|
|
@ -328,10 +332,13 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
|||
c *TCPConn
|
||||
err error
|
||||
)
|
||||
if dialer != nil {
|
||||
sd.Dialer = *dialer
|
||||
}
|
||||
if sd.MultipathTCP() {
|
||||
c, err = sd.dialMPTCP(context.Background(), laddr, raddr)
|
||||
c, err = sd.dialMPTCP(ctx, laddr, raddr)
|
||||
} else {
|
||||
c, err = sd.dialTCP(context.Background(), laddr, raddr)
|
||||
c, err = sd.dialTCP(ctx, laddr, raddr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
|
||||
|
|
|
|||
|
|
@ -285,6 +285,10 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
|
|||
// If the IP field of raddr is nil or an unspecified IP address, the
|
||||
// local system is assumed.
|
||||
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
|
||||
return dialUDP(context.Background(), nil, network, laddr, raddr)
|
||||
}
|
||||
|
||||
func dialUDP(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
default:
|
||||
|
|
@ -294,7 +298,10 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
|
|||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
|
||||
}
|
||||
sd := &sysDialer{network: network, address: raddr.String()}
|
||||
c, err := sd.dialUDP(context.Background(), laddr, raddr)
|
||||
if dialer != nil {
|
||||
sd.Dialer = *dialer
|
||||
}
|
||||
c, err := sd.dialUDP(ctx, laddr, raddr)
|
||||
if err != nil {
|
||||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -201,13 +201,20 @@ func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
|
|||
// If laddr is non-nil, it is used as the local address for the
|
||||
// connection.
|
||||
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
|
||||
return dialUnix(context.Background(), nil, network, laddr, raddr)
|
||||
}
|
||||
|
||||
func dialUnix(ctx context.Context, dialer *Dialer, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
|
||||
switch network {
|
||||
case "unix", "unixgram", "unixpacket":
|
||||
default:
|
||||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
|
||||
}
|
||||
sd := &sysDialer{network: network, address: raddr.String()}
|
||||
c, err := sd.dialUnix(context.Background(), laddr, raddr)
|
||||
if dialer != nil {
|
||||
sd.Dialer = *dialer
|
||||
}
|
||||
c, err := sd.dialUnix(ctx, laddr, raddr)
|
||||
if err != nil {
|
||||
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue