net: allow zero value destination address in WriteMsgUDPAddrPort

The existing address validity checks already cover both connected and
non-connected sockets. Pass a nil sockaddr just like WriteMsgUDP, when
the address is zero value.

TestWriteToUDP is extended to cover the netip APIs.

Fixes #74841

Change-Id: I2708e7747e224958198fe7abb3fcd8d59bc5a88a
Reviewed-on: https://go-review.googlesource.com/c/go/+/692437
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
database64128 2025-08-02 00:42:27 +08:00 committed by Sean Liao
parent afc51ed007
commit 49a2f3ed87
4 changed files with 47 additions and 33 deletions

View file

@ -260,9 +260,6 @@ func addrPortToSockaddrInet6(ap netip.AddrPort) (syscall.SockaddrInet6, error) {
// to an IPv4-mapped IPv6 address. // to an IPv4-mapped IPv6 address.
// The error message is kept consistent with ipToSockaddrInet6. // The error message is kept consistent with ipToSockaddrInet6.
addr := ap.Addr() addr := ap.Addr()
if !addr.IsValid() {
return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: addr.String()}
}
sa := syscall.SockaddrInet6{ sa := syscall.SockaddrInet6{
Addr: addr.As16(), Addr: addr.As16(),
Port: int(ap.Port()), Port: int(ap.Port()),

View file

@ -1082,11 +1082,19 @@ func (ffd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int
return n, 0, err return n, 0, err
} }
func (ffd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { func (ffd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa4 *syscall.SockaddrInet4) (n int, oobn int, err error) {
var sa syscall.Sockaddr
if sa4 != nil {
sa = sa4
}
return ffd.writeMsg(p, oob, sa) return ffd.writeMsg(p, oob, sa)
} }
func (ffd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { func (ffd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa6 *syscall.SockaddrInet6) (n int, oobn int, err error) {
var sa syscall.Sockaddr
if sa6 != nil {
sa = sa6
}
return ffd.writeMsg(p, oob, sa) return ffd.writeMsg(p, oob, sa)
} }

View file

@ -186,17 +186,25 @@ func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn
switch c.fd.family { switch c.fd.family {
case syscall.AF_INET: case syscall.AF_INET:
var sap *syscall.SockaddrInet4
if addr.IsValid() {
sa, err := addrPortToSockaddrInet4(addr) sa, err := addrPortToSockaddrInet4(addr)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
return c.fd.writeMsgInet4(b, oob, &sa) sap = &sa
}
return c.fd.writeMsgInet4(b, oob, sap)
case syscall.AF_INET6: case syscall.AF_INET6:
var sap *syscall.SockaddrInet6
if addr.IsValid() {
sa, err := addrPortToSockaddrInet6(addr) sa, err := addrPortToSockaddrInet6(addr)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
return c.fd.writeMsgInet6(b, oob, &sa) sap = &sa
}
return c.fd.writeMsgInet6(b, oob, sap)
default: default:
return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()} return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
} }

View file

@ -141,36 +141,37 @@ func testWriteToConn(t *testing.T, raddr string) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
rap := ra.AddrPort()
assertErrWriteToConnected := func(t *testing.T, err error) {
t.Helper()
if e, ok := err.(*OpError); !ok || e.Err != ErrWriteToConnected {
t.Errorf("got %v; want ErrWriteToConnected", err)
}
}
b := []byte("CONNECTED-MODE SOCKET") b := []byte("CONNECTED-MODE SOCKET")
_, err = c.(*UDPConn).WriteToUDPAddrPort(b, rap)
assertErrWriteToConnected(t, err)
_, err = c.(*UDPConn).WriteToUDP(b, ra) _, err = c.(*UDPConn).WriteToUDP(b, ra)
if err == nil { assertErrWriteToConnected(t, err)
t.Fatal("should fail")
}
if err != nil && err.(*OpError).Err != ErrWriteToConnected {
t.Fatalf("should fail as ErrWriteToConnected: %v", err)
}
_, err = c.(*UDPConn).WriteTo(b, ra) _, err = c.(*UDPConn).WriteTo(b, ra)
if err == nil { assertErrWriteToConnected(t, err)
t.Fatal("should fail")
}
if err != nil && err.(*OpError).Err != ErrWriteToConnected {
t.Fatalf("should fail as ErrWriteToConnected: %v", err)
}
_, err = c.Write(b) _, err = c.Write(b)
if err != nil { if err != nil {
t.Fatal(err) t.Errorf("c.Write(b) = %v; want nil", err)
} }
_, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra) _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra)
if err == nil { assertErrWriteToConnected(t, err)
t.Fatal("should fail")
}
if err != nil && err.(*OpError).Err != ErrWriteToConnected {
t.Fatalf("should fail as ErrWriteToConnected: %v", err)
}
_, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil) _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Errorf("c.WriteMsgUDP(b, nil, nil) = %v; want nil", err)
}
_, _, err = c.(*UDPConn).WriteMsgUDPAddrPort(b, nil, rap)
assertErrWriteToConnected(t, err)
_, _, err = c.(*UDPConn).WriteMsgUDPAddrPort(b, nil, netip.AddrPort{})
if err != nil {
t.Errorf("c.WriteMsgUDPAddrPort(b, nil, netip.AddrPort{}) = %v; want nil", err)
} }
} }