internal/syscall/windows: avoid uint16 overflow in NewNTUnicodeString

For #78916

Change-Id: I8d97059b66bea8dbcb2e6afbb455cbc76a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/770100
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Quim Muntal <quimmuntal@gmail.com>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Damien Neil 2026-04-22 17:16:31 -07:00 committed by Gopher Robot
parent a804e04b7e
commit a91e9fa1de
2 changed files with 57 additions and 3 deletions

View file

@ -22,11 +22,14 @@ func NewNTUnicodeString(s string) (*NTUnicodeString, error) {
if err != nil {
return nil, err
}
n := uint16(len(s16) * 2)
n := len(s16) * 2
if n > (1<<16)-1 {
return nil, syscall.EINVAL
}
// https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdmsec/nf-wdmsec-wdmlibrtlinitunicodestringex
return &NTUnicodeString{
Length: n - 2, // subtract 2 bytes for the NUL terminator
MaximumLength: n,
Length: uint16(n) - 2, // subtract 2 bytes for the NUL terminator
MaximumLength: uint16(n),
Buffer: &s16[0],
}, nil
}

View file

@ -0,0 +1,51 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package windows_test
import (
"internal/syscall/windows"
"math"
"strings"
"syscall"
"testing"
"unsafe"
)
func TestRoundtripNTUnicodeString(t *testing.T) {
// NTUnicodeString maximum string length must fit in a uint16, less for terminal NUL.
maxString := strings.Repeat("*", (math.MaxUint16/2)-1)
for _, test := range []struct {
s string
wantErr bool
}{{
s: "",
}, {
s: "hello",
}, {
s: "Ƀ",
}, {
s: maxString,
}, {
s: maxString + "*",
wantErr: true,
}, {
s: "a\x00a",
wantErr: true,
}} {
ntus, err := windows.NewNTUnicodeString(test.s)
if (err != nil) != test.wantErr {
t.Errorf("NewNTUnicodeString(%q): %v, wantErr:%v", test.s, err, test.wantErr)
continue
}
if err != nil {
continue
}
u16 := unsafe.Slice(ntus.Buffer, ntus.MaximumLength/2)[:ntus.Length/2]
s2 := syscall.UTF16ToString(u16)
if test.s != s2 {
t.Errorf("round trip of %q = %q, wanted original", test.s, s2)
}
}
}