net: use closesocket when closing socket os.File's on Windows

The WSASocket documentation states that the returned socket must be
closed by calling closesocket instead of CloseHandle. The different
File methods on the net package return an os.File that is not aware
that it should use closesocket. Ideally, os.NewFile should detect that
the passed handle is a socket and use the appropriate close function,
but there is no reliable way to detect that a handle is a socket on
Windows (see CL 671455).

To work around this, we add a hidden function to the os package that
can be used to return an os.File that uses closesocket. This approach
is the same as used on Unix, which also uses a hidden function for other
purposes.

While here, fix a potential issue with FileConn, which was using File.Fd
rather than File.SyscallConn to get the handle. This could result in the
File being closed and garbage collected before the syscall was made.

Fixes #73683.

Change-Id: I179405f34c63cbbd555d8119e0f77157c670eb3e
Reviewed-on: https://go-review.googlesource.com/c/go/+/672195
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
qmuntal 2025-05-13 13:31:22 +02:00 committed by Quim Muntal
parent a197a471b9
commit 3be537e663
5 changed files with 236 additions and 195 deletions

View file

@ -318,7 +318,7 @@ type FD struct {
// message based socket connection.
ZeroReadIsEOF bool
// Whether this is a file rather than a network socket.
// Whether the handle is owned by os.File.
isFile bool
// The kind of this file.
@ -368,6 +368,7 @@ const (
kindFile
kindConsole
kindPipe
kindFileNet
)
// Init initializes the FD. The Sysfd field should already be set.
@ -388,6 +389,8 @@ func (fd *FD) Init(net string, pollable bool) error {
fd.kind = kindConsole
case "pipe":
fd.kind = kindPipe
case "file+net":
fd.kind = kindFileNet
default:
// We don't actually care about the various network types.
fd.kind = kindNet
@ -453,7 +456,7 @@ func (fd *FD) destroy() error {
fd.pd.close()
var err error
switch fd.kind {
case kindNet:
case kindNet, kindFileNet:
// The net package uses the CloseFunc variable for testing.
err = CloseFunc(fd.Sysfd)
default:
@ -494,7 +497,7 @@ func (fd *FD) Read(buf []byte) (int, error) {
return 0, err
}
defer fd.readUnlock()
if fd.isFile {
if fd.kind == kindFile {
fd.l.Lock()
defer fd.l.Unlock()
}
@ -747,7 +750,7 @@ func (fd *FD) Write(buf []byte) (int, error) {
return 0, err
}
defer fd.writeUnlock()
if fd.isFile {
if fd.kind == kindFile {
fd.l.Lock()
defer fd.l.Unlock()
}

View file

@ -233,6 +233,9 @@ func (fd *netFD) accept() (*netFD, error) {
return netfd, nil
}
// Defined in os package.
func newWindowsFile(h syscall.Handle, name string) *os.File
func (fd *netFD) dup() (*os.File, error) {
// Disassociate the IOCP from the socket,
// it is not safe to share a duplicated handle
@ -251,5 +254,8 @@ func (fd *netFD) dup() (*os.File, error) {
if err != nil {
return nil, err
}
return os.NewFile(uintptr(h), fd.name()), nil
// All WSASocket calls must be match with a syscall.Closesocket call,
// but os.NewFile calls syscall.CloseHandle instead. We need to use
// a hidden function so that the returned file is aware of this fact.
return newWindowsFile(h, fd.name()), nil
}

View file

@ -34,89 +34,90 @@ func TestFileConn(t *testing.T) {
}
for _, tt := range fileConnTests {
if !testableNetwork(tt.network) {
t.Logf("skipping %s test", tt.network)
continue
}
var network, address string
switch tt.network {
case "udp":
c := newLocalPacketListener(t, tt.network)
defer c.Close()
network = c.LocalAddr().Network()
address = c.LocalAddr().String()
default:
handler := func(ls *localServer, ln Listener) {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
var b [1]byte
c.Read(b[:])
t.Run(tt.network, func(t *testing.T) {
if !testableNetwork(tt.network) {
t.Skipf("skipping %s test", tt.network)
}
ls := newLocalServer(t, tt.network)
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
var network, address string
switch tt.network {
case "udp":
c := newLocalPacketListener(t, tt.network)
defer c.Close()
network = c.LocalAddr().Network()
address = c.LocalAddr().String()
default:
handler := func(ls *localServer, ln Listener) {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
var b [1]byte
c.Read(b[:])
}
ls := newLocalServer(t, tt.network)
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
}
network = ls.Listener.Addr().Network()
address = ls.Listener.Addr().String()
}
c1, err := Dial(network, address)
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
network = ls.Listener.Addr().Network()
address = ls.Listener.Addr().String()
}
addr := c1.LocalAddr()
c1, err := Dial(network, address)
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
var f *os.File
switch c1 := c1.(type) {
case *TCPConn:
f, err = c1.File()
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
}
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
t.Fatal(err)
}
addr := c1.LocalAddr()
var f *os.File
switch c1 := c1.(type) {
case *TCPConn:
f, err = c1.File()
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
c2, err := FileConn(f)
if err := f.Close(); err != nil {
t.Error(err)
}
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
t.Fatal(err)
}
defer c2.Close()
c2, err := FileConn(f)
if err := f.Close(); err != nil {
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
t.Fatal(err)
}
defer c2.Close()
if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
if !reflect.DeepEqual(c2.LocalAddr(), addr) {
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
t.Fatal(err)
}
if !reflect.DeepEqual(c2.LocalAddr(), addr) {
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
})
}
}
@ -135,81 +136,82 @@ func TestFileListener(t *testing.T) {
}
for _, tt := range fileListenerTests {
if !testableNetwork(tt.network) {
t.Logf("skipping %s test", tt.network)
continue
}
ln1 := newLocalListener(t, tt.network)
switch tt.network {
case "unix", "unixpacket":
defer os.Remove(ln1.Addr().String())
}
addr := ln1.Addr()
var (
f *os.File
err error
)
switch ln1 := ln1.(type) {
case *TCPListener:
f, err = ln1.File()
case *UnixListener:
f, err = ln1.File()
}
switch tt.network {
case "unix", "unixpacket":
defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
default:
if err := ln1.Close(); err != nil {
t.Error(err)
t.Run(tt.network, func(t *testing.T) {
if !testableNetwork(tt.network) {
t.Skipf("skipping %s test", tt.network)
}
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
ln2, err := FileListener(f)
if err := f.Close(); err != nil {
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
ln1 := newLocalListener(t, tt.network)
switch tt.network {
case "unix", "unixpacket":
defer os.Remove(ln1.Addr().String())
}
t.Fatal(err)
}
defer ln2.Close()
addr := ln1.Addr()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
var (
f *os.File
err error
)
switch ln1 := ln1.(type) {
case *TCPListener:
f, err = ln1.File()
case *UnixListener:
f, err = ln1.File()
}
switch tt.network {
case "unix", "unixpacket":
defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
default:
if err := ln1.Close(); err != nil {
t.Error(err)
}
}
if err != nil {
if perr := parseDialError(err); perr != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
ln2, err := FileListener(f)
if err := f.Close(); err != nil {
t.Error(err)
return
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
defer ln2.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
t.Error(err)
return
}
c.Close()
}()
c, err := ln2.Accept()
if err != nil {
if perr := parseAcceptError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
c.Close()
}()
c, err := ln2.Accept()
if err != nil {
if perr := parseAcceptError(err); perr != nil {
t.Error(perr)
wg.Wait()
if !reflect.DeepEqual(ln2.Addr(), addr) {
t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
}
t.Fatal(err)
}
c.Close()
wg.Wait()
if !reflect.DeepEqual(ln2.Addr(), addr) {
t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
}
})
}
}
@ -227,62 +229,63 @@ func TestFilePacketConn(t *testing.T) {
}
for _, tt := range filePacketConnTests {
if !testableNetwork(tt.network) {
t.Logf("skipping %s test", tt.network)
continue
}
c1 := newLocalPacketListener(t, tt.network)
switch tt.network {
case "unixgram":
defer os.Remove(c1.LocalAddr().String())
}
addr := c1.LocalAddr()
var (
f *os.File
err error
)
switch c1 := c1.(type) {
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
t.Run(tt.network, func(t *testing.T) {
if !testableNetwork(tt.network) {
t.Skipf("skipping %s test", tt.network)
}
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
c2, err := FilePacketConn(f)
if err := f.Close(); err != nil {
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
c1 := newLocalPacketListener(t, tt.network)
switch tt.network {
case "unixgram":
defer os.Remove(c1.LocalAddr().String())
}
t.Fatal(err)
}
defer c2.Close()
addr := c1.LocalAddr()
if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
var (
f *os.File
err error
)
switch c1 := c1.(type) {
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
t.Fatal(err)
}
if !reflect.DeepEqual(c2.LocalAddr(), addr) {
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
}
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
c2, err := FilePacketConn(f)
if err := f.Close(); err != nil {
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
defer c2.Close()
if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
if !reflect.DeepEqual(c2.LocalAddr(), addr) {
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
})
}
}

View file

@ -22,9 +22,26 @@ func dupSocket(h syscall.Handle) (syscall.Handle, error) {
}
func dupFileSocket(f *os.File) (syscall.Handle, error) {
// The resulting handle should not be associated to an IOCP, else the IO operations
// will block an OS thread, and that's not what net package users expect.
h, err := dupSocket(syscall.Handle(f.Fd()))
// Call Fd to disassociate the IOCP from the handle,
// it is not safe to share a duplicated handle
// that is associated with IOCP.
// Don't use the returned fd, as it might be closed
// if f happens to be the last reference to the file.
f.Fd()
sc, err := f.SyscallConn()
if err != nil {
return 0, err
}
var h syscall.Handle
var syserr error
err = sc.Control(func(fd uintptr) {
h, syserr = dupSocket(syscall.Handle(fd))
})
if err != nil {
err = syserr
}
if err != nil {
return 0, err
}

View file

@ -92,6 +92,18 @@ func newFileFromNewFile(fd uintptr, name string) *File {
return newFile(h, name, "file", nonBlocking)
}
// net_newWindowsFile is a hidden entry point called by net.conn.File.
// This is used so that the File.pfd.close method calls [syscall.Closesocket]
// instead of [syscall.CloseHandle].
//
//go:linkname net_newWindowsFile net.newWindowsFile
func net_newWindowsFile(h syscall.Handle, name string) *File {
if h == syscall.InvalidHandle {
panic("invalid FD")
}
return newFile(h, name, "file+net", true)
}
func epipecheck(file *File, e error) {
}