net: use closesocket when closing socket os.File's on Windows

The WSASocket documentation states that the returned socket must be
closed by calling closesocket instead of CloseHandle. The different
File methods on the net package return an os.File that is not aware
that it should use closesocket. Ideally, os.NewFile should detect that
the passed handle is a socket and use the appropriate close function,
but there is no reliable way to detect that a handle is a socket on
Windows (see CL 671455).

To work around this, we add a hidden function to the os package that
can be used to return an os.File that uses closesocket. This approach
is the same as used on Unix, which also uses a hidden function for other
purposes.

While here, fix a potential issue with FileConn, which was using File.Fd
rather than File.SyscallConn to get the handle. This could result in the
File being closed and garbage collected before the syscall was made.

Fixes #73683.

Change-Id: I179405f34c63cbbd555d8119e0f77157c670eb3e
Reviewed-on: https://go-review.googlesource.com/c/go/+/672195
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
qmuntal 2025-05-13 13:31:22 +02:00 committed by Quim Muntal
parent a197a471b9
commit 3be537e663
5 changed files with 236 additions and 195 deletions

View file

@ -318,7 +318,7 @@ type FD struct {
// message based socket connection. // message based socket connection.
ZeroReadIsEOF bool ZeroReadIsEOF bool
// Whether this is a file rather than a network socket. // Whether the handle is owned by os.File.
isFile bool isFile bool
// The kind of this file. // The kind of this file.
@ -368,6 +368,7 @@ const (
kindFile kindFile
kindConsole kindConsole
kindPipe kindPipe
kindFileNet
) )
// Init initializes the FD. The Sysfd field should already be set. // Init initializes the FD. The Sysfd field should already be set.
@ -388,6 +389,8 @@ func (fd *FD) Init(net string, pollable bool) error {
fd.kind = kindConsole fd.kind = kindConsole
case "pipe": case "pipe":
fd.kind = kindPipe fd.kind = kindPipe
case "file+net":
fd.kind = kindFileNet
default: default:
// We don't actually care about the various network types. // We don't actually care about the various network types.
fd.kind = kindNet fd.kind = kindNet
@ -453,7 +456,7 @@ func (fd *FD) destroy() error {
fd.pd.close() fd.pd.close()
var err error var err error
switch fd.kind { switch fd.kind {
case kindNet: case kindNet, kindFileNet:
// The net package uses the CloseFunc variable for testing. // The net package uses the CloseFunc variable for testing.
err = CloseFunc(fd.Sysfd) err = CloseFunc(fd.Sysfd)
default: default:
@ -494,7 +497,7 @@ func (fd *FD) Read(buf []byte) (int, error) {
return 0, err return 0, err
} }
defer fd.readUnlock() defer fd.readUnlock()
if fd.isFile { if fd.kind == kindFile {
fd.l.Lock() fd.l.Lock()
defer fd.l.Unlock() defer fd.l.Unlock()
} }
@ -747,7 +750,7 @@ func (fd *FD) Write(buf []byte) (int, error) {
return 0, err return 0, err
} }
defer fd.writeUnlock() defer fd.writeUnlock()
if fd.isFile { if fd.kind == kindFile {
fd.l.Lock() fd.l.Lock()
defer fd.l.Unlock() defer fd.l.Unlock()
} }

View file

@ -233,6 +233,9 @@ func (fd *netFD) accept() (*netFD, error) {
return netfd, nil return netfd, nil
} }
// Defined in os package.
func newWindowsFile(h syscall.Handle, name string) *os.File
func (fd *netFD) dup() (*os.File, error) { func (fd *netFD) dup() (*os.File, error) {
// Disassociate the IOCP from the socket, // Disassociate the IOCP from the socket,
// it is not safe to share a duplicated handle // it is not safe to share a duplicated handle
@ -251,5 +254,8 @@ func (fd *netFD) dup() (*os.File, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return os.NewFile(uintptr(h), fd.name()), nil // All WSASocket calls must be match with a syscall.Closesocket call,
// but os.NewFile calls syscall.CloseHandle instead. We need to use
// a hidden function so that the returned file is aware of this fact.
return newWindowsFile(h, fd.name()), nil
} }

View file

@ -34,89 +34,90 @@ func TestFileConn(t *testing.T) {
} }
for _, tt := range fileConnTests { for _, tt := range fileConnTests {
if !testableNetwork(tt.network) { t.Run(tt.network, func(t *testing.T) {
t.Logf("skipping %s test", tt.network) if !testableNetwork(tt.network) {
continue t.Skipf("skipping %s test", tt.network)
}
var network, address string
switch tt.network {
case "udp":
c := newLocalPacketListener(t, tt.network)
defer c.Close()
network = c.LocalAddr().Network()
address = c.LocalAddr().String()
default:
handler := func(ls *localServer, ln Listener) {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
var b [1]byte
c.Read(b[:])
} }
ls := newLocalServer(t, tt.network)
defer ls.teardown() var network, address string
if err := ls.buildup(handler); err != nil { switch tt.network {
case "udp":
c := newLocalPacketListener(t, tt.network)
defer c.Close()
network = c.LocalAddr().Network()
address = c.LocalAddr().String()
default:
handler := func(ls *localServer, ln Listener) {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
var b [1]byte
c.Read(b[:])
}
ls := newLocalServer(t, tt.network)
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
}
network = ls.Listener.Addr().Network()
address = ls.Listener.Addr().String()
}
c1, err := Dial(network, address)
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err) t.Fatal(err)
} }
network = ls.Listener.Addr().Network() addr := c1.LocalAddr()
address = ls.Listener.Addr().String()
}
c1, err := Dial(network, address) var f *os.File
if err != nil { switch c1 := c1.(type) {
if perr := parseDialError(err); perr != nil { case *TCPConn:
t.Error(perr) f, err = c1.File()
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
}
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
} }
t.Fatal(err)
}
addr := c1.LocalAddr()
var f *os.File c2, err := FileConn(f)
switch c1 := c1.(type) { if err := f.Close(); err != nil {
case *TCPConn: t.Error(err)
f, err = c1.File()
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
} }
t.Error(err) if err != nil {
} if perr := parseCommonError(err); perr != nil {
if err != nil { t.Error(perr)
if perr := parseCommonError(err); perr != nil { }
t.Error(perr) t.Fatal(err)
} }
t.Fatal(err) defer c2.Close()
}
c2, err := FileConn(f) if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
if err := f.Close(); err != nil { if perr := parseWriteError(err); perr != nil {
t.Error(err) t.Error(perr)
} }
if err != nil { t.Fatal(err)
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
} }
t.Fatal(err) if !reflect.DeepEqual(c2.LocalAddr(), addr) {
} t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
defer c2.Close()
if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
} }
t.Fatal(err) })
}
if !reflect.DeepEqual(c2.LocalAddr(), addr) {
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
} }
} }
@ -135,81 +136,82 @@ func TestFileListener(t *testing.T) {
} }
for _, tt := range fileListenerTests { for _, tt := range fileListenerTests {
if !testableNetwork(tt.network) { t.Run(tt.network, func(t *testing.T) {
t.Logf("skipping %s test", tt.network) if !testableNetwork(tt.network) {
continue t.Skipf("skipping %s test", tt.network)
}
ln1 := newLocalListener(t, tt.network)
switch tt.network {
case "unix", "unixpacket":
defer os.Remove(ln1.Addr().String())
}
addr := ln1.Addr()
var (
f *os.File
err error
)
switch ln1 := ln1.(type) {
case *TCPListener:
f, err = ln1.File()
case *UnixListener:
f, err = ln1.File()
}
switch tt.network {
case "unix", "unixpacket":
defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
default:
if err := ln1.Close(); err != nil {
t.Error(err)
} }
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
ln2, err := FileListener(f) ln1 := newLocalListener(t, tt.network)
if err := f.Close(); err != nil { switch tt.network {
t.Error(err) case "unix", "unixpacket":
} defer os.Remove(ln1.Addr().String())
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
} }
t.Fatal(err) addr := ln1.Addr()
}
defer ln2.Close()
var wg sync.WaitGroup var (
wg.Add(1) f *os.File
go func() { err error
defer wg.Done() )
c, err := Dial(ln2.Addr().Network(), ln2.Addr().String()) switch ln1 := ln1.(type) {
case *TCPListener:
f, err = ln1.File()
case *UnixListener:
f, err = ln1.File()
}
switch tt.network {
case "unix", "unixpacket":
defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
default:
if err := ln1.Close(); err != nil {
t.Error(err)
}
}
if err != nil { if err != nil {
if perr := parseDialError(err); perr != nil { if perr := parseCommonError(err); perr != nil {
t.Error(perr) t.Error(perr)
} }
t.Fatal(err)
}
ln2, err := FileListener(f)
if err := f.Close(); err != nil {
t.Error(err) t.Error(err)
return }
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
defer ln2.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
if err != nil {
if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
t.Error(err)
return
}
c.Close()
}()
c, err := ln2.Accept()
if err != nil {
if perr := parseAcceptError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
} }
c.Close() c.Close()
}() wg.Wait()
c, err := ln2.Accept() if !reflect.DeepEqual(ln2.Addr(), addr) {
if err != nil { t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
if perr := parseAcceptError(err); perr != nil {
t.Error(perr)
} }
t.Fatal(err) })
}
c.Close()
wg.Wait()
if !reflect.DeepEqual(ln2.Addr(), addr) {
t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
}
} }
} }
@ -227,62 +229,63 @@ func TestFilePacketConn(t *testing.T) {
} }
for _, tt := range filePacketConnTests { for _, tt := range filePacketConnTests {
if !testableNetwork(tt.network) { t.Run(tt.network, func(t *testing.T) {
t.Logf("skipping %s test", tt.network) if !testableNetwork(tt.network) {
continue t.Skipf("skipping %s test", tt.network)
}
c1 := newLocalPacketListener(t, tt.network)
switch tt.network {
case "unixgram":
defer os.Remove(c1.LocalAddr().String())
}
addr := c1.LocalAddr()
var (
f *os.File
err error
)
switch c1 := c1.(type) {
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
}
if err := c1.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
} }
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
c2, err := FilePacketConn(f) c1 := newLocalPacketListener(t, tt.network)
if err := f.Close(); err != nil { switch tt.network {
t.Error(err) case "unixgram":
} defer os.Remove(c1.LocalAddr().String())
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
} }
t.Fatal(err) addr := c1.LocalAddr()
}
defer c2.Close()
if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil { var (
if perr := parseWriteError(err); perr != nil { f *os.File
t.Error(perr) err error
)
switch c1 := c1.(type) {
case *UDPConn:
f, err = c1.File()
case *UnixConn:
f, err = c1.File()
} }
t.Fatal(err) if err := c1.Close(); err != nil {
} if perr := parseCloseError(err, false); perr != nil {
if !reflect.DeepEqual(c2.LocalAddr(), addr) { t.Error(perr)
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr) }
} t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
c2, err := FilePacketConn(f)
if err := f.Close(); err != nil {
t.Error(err)
}
if err != nil {
if perr := parseCommonError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
defer c2.Close()
if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
if !reflect.DeepEqual(c2.LocalAddr(), addr) {
t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
})
} }
} }

View file

@ -22,9 +22,26 @@ func dupSocket(h syscall.Handle) (syscall.Handle, error) {
} }
func dupFileSocket(f *os.File) (syscall.Handle, error) { func dupFileSocket(f *os.File) (syscall.Handle, error) {
// The resulting handle should not be associated to an IOCP, else the IO operations // Call Fd to disassociate the IOCP from the handle,
// will block an OS thread, and that's not what net package users expect. // it is not safe to share a duplicated handle
h, err := dupSocket(syscall.Handle(f.Fd())) // that is associated with IOCP.
// Don't use the returned fd, as it might be closed
// if f happens to be the last reference to the file.
f.Fd()
sc, err := f.SyscallConn()
if err != nil {
return 0, err
}
var h syscall.Handle
var syserr error
err = sc.Control(func(fd uintptr) {
h, syserr = dupSocket(syscall.Handle(fd))
})
if err != nil {
err = syserr
}
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -92,6 +92,18 @@ func newFileFromNewFile(fd uintptr, name string) *File {
return newFile(h, name, "file", nonBlocking) return newFile(h, name, "file", nonBlocking)
} }
// net_newWindowsFile is a hidden entry point called by net.conn.File.
// This is used so that the File.pfd.close method calls [syscall.Closesocket]
// instead of [syscall.CloseHandle].
//
//go:linkname net_newWindowsFile net.newWindowsFile
func net_newWindowsFile(h syscall.Handle, name string) *File {
if h == syscall.InvalidHandle {
panic("invalid FD")
}
return newFile(h, name, "file+net", true)
}
func epipecheck(file *File, e error) { func epipecheck(file *File, e error) {
} }