internal/poll: don't use stack-allocated WSAMsg parameters

WSAMsg parameters should be passed to Windows as heap pointers instead
of stack pointers. This is because Windows might access the memory
after the syscall returned in case of a non-blocking operation (which
is the common case), and if the WSAMsg is on the stack, the Go
runtime might have moved it around.

Use a sync.Pool to cache WSAMsg structures to avoid a heap allocation
every time a WSAMsg is needed.

Fixes #74933

Cq-Include-Trybots: luci.golang.try:x_net-gotip-windows-amd64
Change-Id: I075e2ceb25cd545224ab3a10d404340faf19fc01
Reviewed-on: https://go-review.googlesource.com/c/go/+/698797
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
qmuntal 2025-08-25 09:57:49 +02:00 committed by Quim Muntal
parent dae9e456ae
commit bfd130db02

View file

@ -144,19 +144,47 @@ func (o *operation) ClearBufs() {
o.bufs = o.bufs[:0] o.bufs = o.bufs[:0]
} }
func newWSAMsg(p []byte, oob []byte, flags int) windows.WSAMsg { // wsaMsgPool is a pool of WSAMsg structures that can only hold a single WSABuf.
return windows.WSAMsg{ var wsaMsgPool = sync.Pool{
Buffers: &syscall.WSABuf{ New: func() any {
Len: uint32(len(p)), return &windows.WSAMsg{
Buf: unsafe.SliceData(p), Buffers: &syscall.WSABuf{},
},
BufferCount: 1, BufferCount: 1,
Control: syscall.WSABuf{ }
},
}
// newWSAMsg creates a new WSAMsg with the provided parameters.
// Use [freeWSAMsg] to free it.
func newWSAMsg(p []byte, oob []byte, flags int, rsa *syscall.RawSockaddrAny) *windows.WSAMsg {
// The returned object can't be allocated in the stack because it is accessed asynchronously
// by Windows in between several system calls. If the stack frame is moved while that happens,
// then Windows may access invalid memory.
// TODO(qmuntal): investigate using runtime.Pinner keeping this path allocation-free.
// Use a pool to reuse allocations.
msg := wsaMsgPool.Get().(*windows.WSAMsg)
msg.Buffers.Len = uint32(len(p))
msg.Buffers.Buf = unsafe.SliceData(p)
msg.Control = syscall.WSABuf{
Len: uint32(len(oob)), Len: uint32(len(oob)),
Buf: unsafe.SliceData(oob), Buf: unsafe.SliceData(oob),
},
Flags: uint32(flags),
} }
msg.Flags = uint32(flags)
msg.Name = syscall.Pointer(unsafe.Pointer(rsa))
if rsa != nil {
msg.Namelen = int32(unsafe.Sizeof(*rsa))
} else {
msg.Namelen = 0
}
return msg
}
func freeWSAMsg(msg *windows.WSAMsg) {
// Clear pointers to buffers so they can be released by garbage collector.
msg.Buffers.Len = 0
msg.Buffers.Buf = nil
wsaMsgPool.Put(msg)
} }
// waitIO waits for the IO operation o to complete. // waitIO waits for the IO operation o to complete.
@ -1297,11 +1325,10 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S
if o.rsa == nil { if o.rsa == nil {
o.rsa = new(syscall.RawSockaddrAny) o.rsa = new(syscall.RawSockaddrAny)
} }
msg := newWSAMsg(p, oob, flags) msg := newWSAMsg(p, oob, flags, o.rsa)
msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa)) defer freeWSAMsg(msg)
msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) { n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
err = windows.WSARecvMsg(fd.Sysfd, &msg, &qty, &o.o, nil) err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
return qty, err return qty, err
}) })
err = fd.eofError(n, err) err = fd.eofError(n, err)
@ -1327,11 +1354,10 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd
if o.rsa == nil { if o.rsa == nil {
o.rsa = new(syscall.RawSockaddrAny) o.rsa = new(syscall.RawSockaddrAny)
} }
msg := newWSAMsg(p, oob, flags) msg := newWSAMsg(p, oob, flags, o.rsa)
msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa)) defer freeWSAMsg(msg)
msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) { n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
err = windows.WSARecvMsg(fd.Sysfd, &msg, &qty, &o.o, nil) err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
return qty, err return qty, err
}) })
err = fd.eofError(n, err) err = fd.eofError(n, err)
@ -1356,11 +1382,10 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd
if o.rsa == nil { if o.rsa == nil {
o.rsa = new(syscall.RawSockaddrAny) o.rsa = new(syscall.RawSockaddrAny)
} }
msg := newWSAMsg(p, oob, flags) msg := newWSAMsg(p, oob, flags, o.rsa)
msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa)) defer freeWSAMsg(msg)
msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) { n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
err = windows.WSARecvMsg(fd.Sysfd, &msg, &qty, &o.o, nil) err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
return qty, err return qty, err
}) })
err = fd.eofError(n, err) err = fd.eofError(n, err)
@ -1382,20 +1407,20 @@ func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, err
defer fd.writeUnlock() defer fd.writeUnlock()
o := &fd.wop o := &fd.wop
msg := newWSAMsg(p, oob, 0) if sa != nil && o.rsa == nil {
if sa != nil {
if o.rsa == nil {
o.rsa = new(syscall.RawSockaddrAny) o.rsa = new(syscall.RawSockaddrAny)
} }
len, err := sockaddrToRaw(o.rsa, sa) msg := newWSAMsg(p, oob, 0, o.rsa)
defer freeWSAMsg(msg)
if sa != nil {
var err error
msg.Namelen, err = sockaddrToRaw(o.rsa, sa)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa))
msg.Namelen = len
} }
n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) { n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
err = windows.WSASendMsg(fd.Sysfd, &msg, 0, nil, &o.o, nil) err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
return qty, err return qty, err
}) })
return n, int(msg.Control.Len), err return n, int(msg.Control.Len), err
@ -1413,16 +1438,16 @@ func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (in
defer fd.writeUnlock() defer fd.writeUnlock()
o := &fd.wop o := &fd.wop
msg := newWSAMsg(p, oob, 0) if sa != nil && o.rsa == nil {
if sa != nil {
if o.rsa == nil {
o.rsa = new(syscall.RawSockaddrAny) o.rsa = new(syscall.RawSockaddrAny)
} }
msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa)) msg := newWSAMsg(p, oob, 0, o.rsa)
defer freeWSAMsg(msg)
if sa != nil {
msg.Namelen = sockaddrInet4ToRaw(o.rsa, sa) msg.Namelen = sockaddrInet4ToRaw(o.rsa, sa)
} }
n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) { n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
err = windows.WSASendMsg(fd.Sysfd, &msg, 0, nil, &o.o, nil) err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
return qty, err return qty, err
}) })
return n, int(msg.Control.Len), err return n, int(msg.Control.Len), err
@ -1440,16 +1465,16 @@ func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (in
defer fd.writeUnlock() defer fd.writeUnlock()
o := &fd.wop o := &fd.wop
msg := newWSAMsg(p, oob, 0) if sa != nil && o.rsa == nil {
if sa != nil {
if o.rsa == nil {
o.rsa = new(syscall.RawSockaddrAny) o.rsa = new(syscall.RawSockaddrAny)
} }
msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa)) msg := newWSAMsg(p, oob, 0, o.rsa)
defer freeWSAMsg(msg)
if sa != nil {
msg.Namelen = sockaddrInet6ToRaw(o.rsa, sa) msg.Namelen = sockaddrInet6ToRaw(o.rsa, sa)
} }
n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) { n, err := fd.execIO(o, func(o *operation) (qty uint32, err error) {
err = windows.WSASendMsg(fd.Sysfd, &msg, 0, nil, &o.o, nil) err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
return qty, err return qty, err
}) })
return n, int(msg.Control.Len), err return n, int(msg.Control.Len), err