[dev.simd] simd/_gen: migrate simdgen from x/arch

This moves the simdgen tool and its supporting unify package from
golang.org/x/arch/internal as of CL 695619 to simd/_gen in the main repo.

The simdgen tool was started in x/arch to live next to xeddata and a
few other assembler generators that already lived there. However, as
we've been developing simdgen, we've discovered that there's a
tremendous amount of process friction coordinating commits to x/arch
with the corresponding generated files in the main repo.

Many of the existing generators in x/arch were started before modules
existed. In GOPATH world, it was impractical for them to live in the
main repo because they have dependencies that are not allowed in the
main repo. However, now that we have modules and can use small
submodules in the main repo, we can isolate these dependencies to just
the generators, making it practical for them to live in the main repo.
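The new module is small and self-contained; its go.mod (shown in full in the diff
below) only needs the YAML parser and x/arch itself:

	module simd/_gen

	go 1.24

	require (
		golang.org/x/arch v0.20.0
		gopkg.in/yaml.v3 v3.0.1
	)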

This commit was generated by the following script:

	# Checks
	set -e
	if [[ ! -d src/simd ]]; then
	    echo >&2 "$PWD is not the root of the main repo on dev.simd"
	    exit 1
	fi
	if [[ -z "$XEDDATA" ]]; then
	    echo >&2 "Must set \$XEDDATA"
	    exit 1
	fi
	which go >/dev/null

	# Move simdgen from x/arch
	xarch=$(mktemp -d)
	git clone https://go.googlesource.com/arch $xarch
	xarchCL=$(git -C $xarch log -1 --format=%b | awk -F/ '/^Reviewed-on:/ {print $NF}')
	echo >&2 "x/arch CL: $xarchCL"
	mv $xarch/internal src/simd/_gen
	sed --in-place s,golang.org/x/arch/internal/,simd/_gen/, src/simd/_gen/*/*.go
	# Create self-contained module
	cat > src/simd/_gen/go.mod <<EOF
	module simd/_gen

	go 1.24
	EOF
	cd src/simd/_gen
	go mod tidy
	git add .
	git gofmt
	# Regenerate file
	go run -C simdgen . -xedPath $XEDDATA -o godefs -goroot $(go env GOROOT) go.yaml types.yaml categories.yaml
	go run -C ../../cmd/compile/internal/ssa/_gen .

Change-Id: I56dd8473e913a9eb1978d9b3b3518ed632972f6f
Reviewed-on: https://go-review.googlesource.com/c/go/+/695975
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Author: Austin Clements
Date:   2025-08-13 15:30:27 -04:00
Parent: 257c1356ec
Commit: b7c8698549

60 changed files with 9083 additions and 0 deletions

src/simd/_gen/go.mod
@@ -0,0 +1,8 @@
module simd/_gen
go 1.24
require (
golang.org/x/arch v0.20.0
gopkg.in/yaml.v3 v3.0.1
)

src/simd/_gen/go.sum
@@ -0,0 +1,6 @@
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

src/simd/_gen/simdgen/.gitignore
@@ -0,0 +1,3 @@
testdata/*
.gemini/*
.gemini*

@@ -0,0 +1,107 @@
# Hand-written toy input like what -xedPath would generate.
# This input can be substituted for -xedPath.
!sum
- asm: ADDPS
goarch: amd64
feature: "SSE2"
in:
- asmPos: 0
class: vreg
base: float
elemBits: 32
bits: 128
- asmPos: 1
class: vreg
base: float
elemBits: 32
bits: 128
out:
- asmPos: 0
class: vreg
base: float
elemBits: 32
bits: 128
- asm: ADDPD
goarch: amd64
feature: "SSE2"
in:
- asmPos: 0
class: vreg
base: float
elemBits: 64
bits: 128
- asmPos: 1
class: vreg
base: float
elemBits: 64
bits: 128
out:
- asmPos: 0
class: vreg
base: float
elemBits: 64
bits: 128
- asm: PADDB
goarch: amd64
feature: "SSE2"
in:
- asmPos: 0
class: vreg
base: int|uint
elemBits: 32
bits: 128
- asmPos: 1
class: vreg
base: int|uint
elemBits: 32
bits: 128
out:
- asmPos: 0
class: vreg
base: int|uint
elemBits: 32
bits: 128
- asm: VPADDB
goarch: amd64
feature: "AVX"
in:
- asmPos: 1
class: vreg
base: int|uint
elemBits: 8
bits: 128
- asmPos: 2
class: vreg
base: int|uint
elemBits: 8
bits: 128
out:
- asmPos: 0
class: vreg
base: int|uint
elemBits: 8
bits: 128
- asm: VPADDB
goarch: amd64
feature: "AVX2"
in:
- asmPos: 1
class: vreg
base: int|uint
elemBits: 8
bits: 256
- asmPos: 2
class: vreg
base: int|uint
elemBits: 8
bits: 256
out:
- asmPos: 0
class: vreg
base: int|uint
elemBits: 8
bits: 256

@@ -0,0 +1 @@
!import ops/*/categories.yaml

@@ -0,0 +1,33 @@
#!/bin/bash -x
cat <<\\EOF
This is an end-to-end test of Go SIMD. It checks out a fresh Go
repository from the go.simd branch, then generates the SIMD input
files and runs simdgen writing into the fresh repository.
After that it generates the modified ssa pattern matching files, then
builds the compiler.
\EOF
rm -rf go-test
git clone https://go.googlesource.com/go -b dev.simd go-test
go run . -xedPath xeddata -o godefs -goroot ./go-test go.yaml types.yaml categories.yaml
(cd go-test/src/cmd/compile/internal/ssa/_gen ; go run *.go )
(cd go-test/src ; GOEXPERIMENT=simd ./make.bash )
(cd go-test/bin; b=`pwd` ; cd ../src/simd/testdata; GOARCH=amd64 $b/go run .)
(cd go-test/bin; b=`pwd` ; cd ../src ;
GOEXPERIMENT=simd GOARCH=amd64 $b/go test -v simd
GOEXPERIMENT=simd $b/go test go/doc
GOEXPERIMENT=simd $b/go test go/build
GOEXPERIMENT=simd $b/go test cmd/api -v -check
$b/go test go/doc
$b/go test go/build
$b/go test cmd/api -v -check
$b/go test cmd/compile/internal/ssagen -simd=0
GOEXPERIMENT=simd $b/go test cmd/compile/internal/ssagen -simd=0
)
# next, add some tests of SIMD itself

@@ -0,0 +1,70 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"sort"
)
const simdGenericOpsTmpl = `
package main
func simdGenericOps() []opData {
return []opData{
{{- range .Ops }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, commutative: {{.Comm}}},
{{- end }}
{{- range .OpsImm }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, commutative: {{.Comm}}, aux: "UInt8"},
{{- end }}
}
}
`
// writeSIMDGenericOps generates the generic ops and writes them to simdgenericOps.go
// within the specified directory.
func writeSIMDGenericOps(ops []Operation) *bytes.Buffer {
t := templateOf(simdGenericOpsTmpl, "simdgenericOps")
buffer := new(bytes.Buffer)
buffer.WriteString(generatedHeader)
type genericOpsData struct {
OpName string
OpInLen int
Comm bool
}
type opData struct {
Ops []genericOpsData
OpsImm []genericOpsData
}
var opsData opData
for _, op := range ops {
if op.NoGenericOps != nil && *op.NoGenericOps == "true" {
continue
}
_, _, _, immType, gOp := op.shape()
gOpData := genericOpsData{gOp.GenericName(), len(gOp.In), op.Commutative}
if immType == VarImm || immType == ConstVarImm {
opsData.OpsImm = append(opsData.OpsImm, gOpData)
} else {
opsData.Ops = append(opsData.Ops, gOpData)
}
}
sort.Slice(opsData.Ops, func(i, j int) bool {
return compareNatural(opsData.Ops[i].OpName, opsData.Ops[j].OpName) < 0
})
sort.Slice(opsData.OpsImm, func(i, j int) bool {
return compareNatural(opsData.OpsImm[i].OpName, opsData.OpsImm[j].OpName) < 0
})
err := t.Execute(buffer, opsData)
if err != nil {
panic(fmt.Errorf("failed to execute template: %w", err))
}
return buffer
}

@@ -0,0 +1,151 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"slices"
)
const simdIntrinsicsTmpl = `
{{define "header"}}
package ssagen
import (
"cmd/compile/internal/ir"
"cmd/compile/internal/ssa"
"cmd/compile/internal/types"
"cmd/internal/sys"
)
const simdPackage = "` + simdPackage + `"
func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies ...sys.ArchFamily)) {
{{end}}
{{define "op1"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen1(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op2"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen2(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op2_21"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen2_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op2_21Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op3"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen3(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op3_21"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen3_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op3_21Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op3_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op3_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen3_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op4"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen4(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op4_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen4_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op4_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen4_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
{{end}}
{{define "op1Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen1Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
{{end}}
{{define "op2Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
{{end}}
{{define "op2Imm8_2I"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2Imm8_2I(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
{{end}}
{{define "op3Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
{{end}}
{{define "op3Imm8_2I"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3Imm8_2I(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
{{end}}
{{define "op4Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen4Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
{{end}}
{{define "vectorConversion"}} addF(simdPackage, "{{.Tsrc.Name}}.As{{.Tdst.Name}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
{{end}}
{{define "loadStore"}} addF(simdPackage, "Load{{.Name}}", simdLoad(), sys.AMD64)
addF(simdPackage, "{{.Name}}.Store", simdStore(), sys.AMD64)
{{end}}
{{define "maskedLoadStore"}} addF(simdPackage, "LoadMasked{{.Name}}", simdMaskedLoad(ssa.OpLoadMasked{{.ElemBits}}), sys.AMD64)
addF(simdPackage, "{{.Name}}.StoreMasked", simdMaskedStore(ssa.OpStoreMasked{{.ElemBits}}), sys.AMD64)
{{end}}
{{define "mask"}} addF(simdPackage, "{{.Name}}.As{{.VectorCounterpart}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
addF(simdPackage, "{{.VectorCounterpart}}.As{{.Name}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
addF(simdPackage, "{{.Name}}.And", opLen2(ssa.OpAnd{{.ReshapedVectorWithAndOr}}, types.TypeVec{{.Size}}), sys.AMD64)
addF(simdPackage, "{{.Name}}.Or", opLen2(ssa.OpOr{{.ReshapedVectorWithAndOr}}, types.TypeVec{{.Size}}), sys.AMD64)
addF(simdPackage, "Load{{.Name}}FromBits", simdLoadMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
addF(simdPackage, "{{.Name}}.StoreToBits", simdStoreMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
addF(simdPackage, "{{.Name}}FromBits", simdCvtVToMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
addF(simdPackage, "{{.Name}}.ToBits", simdCvtMaskToV({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
{{end}}
{{define "footer"}}}
{{end}}
`
// writeSIMDIntrinsics generates the intrinsic mappings and writes it to simdintrinsics.go
// within the specified directory.
func writeSIMDIntrinsics(ops []Operation, typeMap simdTypeMap) *bytes.Buffer {
t := templateOf(simdIntrinsicsTmpl, "simdintrinsics")
buffer := new(bytes.Buffer)
buffer.WriteString(generatedHeader)
if err := t.ExecuteTemplate(buffer, "header", nil); err != nil {
panic(fmt.Errorf("failed to execute header template: %w", err))
}
slices.SortFunc(ops, compareOperations)
for _, op := range ops {
if op.NoTypes != nil && *op.NoTypes == "true" {
continue
}
if s, op, err := classifyOp(op); err == nil {
if err := t.ExecuteTemplate(buffer, s, op); err != nil {
panic(fmt.Errorf("failed to execute template %s for op %s: %w", s, op.Go, err))
}
} else {
panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
}
}
for _, conv := range vConvertFromTypeMap(typeMap) {
if err := t.ExecuteTemplate(buffer, "vectorConversion", conv); err != nil {
panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
}
}
for _, typ := range typesFromTypeMap(typeMap) {
if typ.Type != "mask" {
if err := t.ExecuteTemplate(buffer, "loadStore", typ); err != nil {
panic(fmt.Errorf("failed to execute loadStore template: %w", err))
}
}
}
for _, typ := range typesFromTypeMap(typeMap) {
if typ.MaskedLoadStoreFilter() {
if err := t.ExecuteTemplate(buffer, "maskedLoadStore", typ); err != nil {
panic(fmt.Errorf("failed to execute maskedLoadStore template: %w", err))
}
}
}
for _, mask := range masksFromTypeMap(typeMap) {
if err := t.ExecuteTemplate(buffer, "mask", mask); err != nil {
panic(fmt.Errorf("failed to execute mask template: %w", err))
}
}
if err := t.ExecuteTemplate(buffer, "footer", nil); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
return buffer
}

@@ -0,0 +1,122 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"sort"
"strings"
)
const simdMachineOpsTmpl = `
package main
func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw regInfo) []opData {
return []opData{
{{- range .OpsData }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataImm }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
}
}
`
// writeSIMDMachineOps generates the machine ops and writes it to simdAMD64ops.go
// within the specified directory.
func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
buffer := new(bytes.Buffer)
buffer.WriteString(generatedHeader)
type opData struct {
OpName string
Asm string
OpInLen int
RegInfo string
Comm bool
Type string
ResultInArg0 bool
}
type machineOpsData struct {
OpsData []opData
OpsDataImm []opData
}
seen := map[string]struct{}{}
regInfoSet := map[string]bool{
"v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
"w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true}
opsData := make([]opData, 0)
opsDataImm := make([]opData, 0)
for _, op := range ops {
shapeIn, shapeOut, maskType, _, gOp := op.shape()
asm := machineOpName(maskType, gOp)
// TODO: all our masked operations are currently zeroing; we need to generate machine ops with merging
// masks, maybe by copying one here with a name suffix "Merging". The rewrite rules will need them.
if _, ok := seen[asm]; ok {
continue
}
seen[asm] = struct{}{}
regInfo, err := op.regShape()
if err != nil {
panic(err)
}
idx, err := checkVecAsScalar(op)
if err != nil {
panic(err)
}
if idx != -1 {
if regInfo == "v21" {
regInfo = "vfpv"
} else if regInfo == "v2kv" {
regInfo = "vfpkv"
} else {
panic(fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op))
}
}
// Makes AVX512 operations use upper registers
if strings.Contains(op.CPUFeature, "AVX512") {
regInfo = strings.ReplaceAll(regInfo, "v", "w")
}
if _, ok := regInfoSet[regInfo]; !ok {
panic(fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s. Op is %s", regInfo, op))
}
var outType string
if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
// If class overwrite is happening, that's not really a mask but a vreg.
outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
} else if shapeOut == OneGregOut {
outType = gOp.GoType() // this is a straight Go type, not a VecNNN type
} else if shapeOut == OneKmaskOut {
outType = "Mask"
} else {
panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
}
resultInArg0 := false
if shapeOut == OneVregOutAtIn {
resultInArg0 = true
}
if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
} else {
opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
}
}
sort.Slice(opsData, func(i, j int) bool {
return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
})
sort.Slice(opsDataImm, func(i, j int) bool {
return compareNatural(opsDataImm[i].OpName, opsDataImm[j].OpName) < 0
})
err := t.Execute(buffer, machineOpsData{opsData, opsDataImm})
if err != nil {
panic(fmt.Errorf("failed to execute template: %w", err))
}
return buffer
}

@@ -0,0 +1,631 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"cmp"
"fmt"
"maps"
"slices"
"sort"
"strings"
)
type simdType struct {
Name string // The go type name of this simd type, for example Int32x4.
Lanes int // The number of elements in this vector/mask.
Base string // The element's type, like for Int32x4 it will be int32.
Fields string // The struct fields; they should be correctly formatted.
Type string // Either "mask" or "vreg"
VectorCounterpart string // For mask use only: just replacing the "Mask" in [simdType.Name] with "Int"
ReshapedVectorWithAndOr string // For mask use only: vector AND and OR are only available in some shape with element width 32.
Size int // The size of the vector type
}
func (x simdType) ElemBits() int {
return x.Size / x.Lanes
}
// LanesContainer returns the smallest int/uint bit size that is
// large enough to hold one bit for each lane. E.g., Mask32x4
// is 4 lanes, and a uint8 is the smallest uint that has 4 bits.
func (x simdType) LanesContainer() int {
if x.Lanes > 64 {
panic("too many lanes")
}
if x.Lanes > 32 {
return 64
}
if x.Lanes > 16 {
return 32
}
if x.Lanes > 8 {
return 16
}
return 8
}
// MaskedLoadStoreFilter encodes which simd types currently
// get masked loads/stores generated. It is used in two places;
// this forces coordination.
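// For example, Float64x8 (512 bits) and Int32x4 qualify, while Int8x16 does not.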
func (x simdType) MaskedLoadStoreFilter() bool {
return x.Size == 512 || x.ElemBits() >= 32 && x.Type != "mask"
}
func (x simdType) IntelSizeSuffix() string {
switch x.ElemBits() {
case 8:
return "B"
case 16:
return "W"
case 32:
return "D"
case 64:
return "Q"
}
panic("oops")
}
func (x simdType) MaskedLoadDoc() string {
if x.Size == 512 || x.ElemBits() < 32 {
return fmt.Sprintf("// Asm: VMOVDQU%d.Z, CPU Feature: AVX512", x.ElemBits())
} else {
return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
}
}
func (x simdType) MaskedStoreDoc() string {
if x.Size == 512 || x.ElemBits() < 32 {
return fmt.Sprintf("// Asm: VMOVDQU%d, CPU Feature: AVX512", x.ElemBits())
} else {
return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
}
}
func compareSimdTypes(x, y simdType) int {
// "vreg" then "mask"
if c := -compareNatural(x.Type, y.Type); c != 0 {
return c
}
// want "flo" < "int" < "uin" (and then 8 < 16 < 32 < 64),
// not "int16" < "int32" < "int64" < "int8")
// so limit comparison to first 3 bytes in string.
if c := compareNatural(x.Base[:3], y.Base[:3]); c != 0 {
return c
}
// base type size, 8 < 16 < 32 < 64
if c := x.ElemBits() - y.ElemBits(); c != 0 {
return c
}
// vector size last
return x.Size - y.Size
}
type simdTypeMap map[int][]simdType
type simdTypePair struct {
Tsrc simdType
Tdst simdType
}
func compareSimdTypePairs(x, y simdTypePair) int {
c := compareSimdTypes(x.Tsrc, y.Tsrc)
if c != 0 {
return c
}
return compareSimdTypes(x.Tdst, y.Tdst)
}
const simdPackageHeader = generatedHeader + `
//go:build goexperiment.simd
package simd
`
const simdTypesTemplates = `
{{define "sizeTmpl"}}
// v{{.}} is a tag type that tells the compiler that this is really {{.}}-bit SIMD
type v{{.}} struct {
_{{.}} struct{}
}
{{end}}
{{define "typeTmpl"}}
// {{.Name}} is a {{.Size}}-bit SIMD vector of {{.Lanes}} {{.Base}}
type {{.Name}} struct {
{{.Fields}}
}
{{end}}
`
const simdFeaturesTemplate = `
import "internal/cpu"
{{range .}}
{{- if eq .Feature "AVX512"}}
// Has{{.Feature}} returns whether the CPU supports the AVX512F+CD+BW+DQ+VL features.
//
// These five CPU features are bundled together, and no use of AVX-512
// is allowed unless all of these features are supported together.
// Nearly every CPU that has shipped with any support for AVX-512 has
// supported all five of these features.
{{- else -}}
// Has{{.Feature}} returns whether the CPU supports the {{.Feature}} feature.
{{- end}}
//
// Has{{.Feature}} is defined on all GOARCHes, but will only return true on
// GOARCH {{.GoArch}}.
func Has{{.Feature}}() bool {
return cpu.X86.Has{{.Feature}}
}
{{end}}
`
const simdLoadStoreTemplate = `
// Len returns the number of elements in a {{.Name}}
func (x {{.Name}}) Len() int { return {{.Lanes}} }
// Load{{.Name}} loads a {{.Name}} from an array
//
//go:noescape
func Load{{.Name}}(y *[{{.Lanes}}]{{.Base}}) {{.Name}}
// Store stores a {{.Name}} to an array
//
//go:noescape
func (x {{.Name}}) Store(y *[{{.Lanes}}]{{.Base}})
`
const simdMaskFromBitsTemplate = `
// Load{{.Name}}FromBits constructs a {{.Name}} from a bitmap, where 1 means set for the indexed element, 0 means unset.
// Only the lower {{.Lanes}} bits of y are used.
//
// CPU Features: AVX512
//go:noescape
func Load{{.Name}}FromBits(y *uint64) {{.Name}}
// StoreToBits stores a {{.Name}} as a bitmap, where 1 means set for the indexed element, 0 means unset.
// Only the lower {{.Lanes}} bits of y are used.
//
// CPU Features: AVX512
//go:noescape
func (x {{.Name}}) StoreToBits(y *uint64)
`
const simdMaskFromValTemplate = `
// {{.Name}}FromBits constructs a {{.Name}} from a bitmap value, where 1 means set for the indexed element, 0 means unset.
// Only the lower {{.Lanes}} bits of y are used.
//
// Asm: KMOV{{.IntelSizeSuffix}}, CPU Feature: AVX512
func {{.Name}}FromBits(y uint{{.LanesContainer}}) {{.Name}}
// ToBits constructs a bitmap from a {{.Name}}, where 1 means set for the indexed element, 0 means unset.
// Only the lower {{.Lanes}} bits of y are used.
//
// Asm: KMOV{{.IntelSizeSuffix}}, CPU Features: AVX512
func (x {{.Name}}) ToBits() uint{{.LanesContainer}}
`
const simdMaskedLoadStoreTemplate = `
// LoadMasked{{.Name}} loads a {{.Name}} from an array,
// at those elements enabled by mask
//
{{.MaskedLoadDoc}}
//
//go:noescape
func LoadMasked{{.Name}}(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}}) {{.Name}}
// StoreMasked stores a {{.Name}} to an array,
// at those elements enabled by mask
//
{{.MaskedStoreDoc}}
//
//go:noescape
func (x {{.Name}}) StoreMasked(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}})
`
const simdStubsTmpl = `
{{define "op1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}() {{.GoType}}
{{end}}
{{define "op2"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}) {{.GoType}}
{{end}}
{{define "op2_21"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
{{end}}
{{define "op2_21Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
{{end}}
{{define "op3"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op3_31"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op3_21"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op3_21Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op3_231Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op2VecAsScalar"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}) {{(index .Out 0).Go}}
{{end}}
{{define "op3VecAsScalar"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}, {{.Op2NameAndType "z"}}) {{(index .Out 0).Go}}
{{end}}
{{define "op4"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
{{end}}
{{define "op4_231Type1"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
{{end}}
{{define "op4_31"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
{{end}}
{{define "op1Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8) {{.GoType}}
{{end}}
{{define "op2Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}) {{.GoType}}
{{end}}
{{define "op2Imm8_2I"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8) {{.GoType}}
{{end}}
{{define "op3Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op3Imm8_2I"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8, {{.Op3NameAndType "z"}}) {{.GoType}}
{{end}}
{{define "op4Imm8"}}
{{if .Documentation}}{{.Documentation}}
//{{end}}
// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
//
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}, {{.Op4NameAndType "u"}}) {{.GoType}}
{{end}}
{{define "vectorConversion"}}
// {{.Tdst.Name}} converts from {{.Tsrc.Name}} to {{.Tdst.Name}}
func (from {{.Tsrc.Name}}) As{{.Tdst.Name}}() (to {{.Tdst.Name}})
{{end}}
{{define "mask"}}
// converts from {{.Name}} to {{.VectorCounterpart}}
func (from {{.Name}}) As{{.VectorCounterpart}}() (to {{.VectorCounterpart}})
// converts from {{.VectorCounterpart}} to {{.Name}}
func (from {{.VectorCounterpart}}) As{{.Name}}() (to {{.Name}})
func (x {{.Name}}) And(y {{.Name}}) {{.Name}}
func (x {{.Name}}) Or(y {{.Name}}) {{.Name}}
{{end}}
`
// parseSIMDTypes groups Go simd types by their vector sizes, and
// returns a map whose key is the vector size and whose value is the simd types of that size.
func parseSIMDTypes(ops []Operation) simdTypeMap {
// TODO: maybe instead of going over ops, let's try go over types.yaml.
ret := map[int][]simdType{}
seen := map[string]struct{}{}
processArg := func(arg Operand) {
if arg.Class == "immediate" || arg.Class == "greg" {
// Immediates and general registers are not encoded as vector types.
return
}
if _, ok := seen[*arg.Go]; ok {
return
}
seen[*arg.Go] = struct{}{}
lanes := *arg.Lanes
base := fmt.Sprintf("%s%d", *arg.Base, *arg.ElemBits)
tagFieldNameS := fmt.Sprintf("%sx%d", base, lanes)
tagFieldS := fmt.Sprintf("%s v%d", tagFieldNameS, *arg.Bits)
valFieldS := fmt.Sprintf("vals%s[%d]%s", strings.Repeat(" ", len(tagFieldNameS)-3), lanes, base)
fields := fmt.Sprintf("\t%s\n\t%s", tagFieldS, valFieldS)
if arg.Class == "mask" {
vectorCounterpart := strings.ReplaceAll(*arg.Go, "Mask", "Int")
reshapedVectorWithAndOr := fmt.Sprintf("Int32x%d", *arg.Bits/32)
ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, vectorCounterpart, reshapedVectorWithAndOr, *arg.Bits})
// In case the vector counterpart of a mask is not present, put its vector counterpart typedef into the map as well.
if _, ok := seen[vectorCounterpart]; !ok {
seen[vectorCounterpart] = struct{}{}
ret[*arg.Bits] = append(ret[*arg.Bits], simdType{vectorCounterpart, lanes, base, fields, "vreg", "", "", *arg.Bits})
}
} else {
ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, "", "", *arg.Bits})
}
}
for _, op := range ops {
for _, arg := range op.In {
processArg(arg)
}
for _, arg := range op.Out {
processArg(arg)
}
}
return ret
}
func vConvertFromTypeMap(typeMap simdTypeMap) []simdTypePair {
v := []simdTypePair{}
for _, ts := range typeMap {
for i, tsrc := range ts {
for j, tdst := range ts {
if i != j && tsrc.Type == tdst.Type && tsrc.Type == "vreg" &&
tsrc.Lanes > 1 && tdst.Lanes > 1 {
v = append(v, simdTypePair{tsrc, tdst})
}
}
}
}
slices.SortFunc(v, compareSimdTypePairs)
return v
}
func masksFromTypeMap(typeMap simdTypeMap) []simdType {
m := []simdType{}
for _, ts := range typeMap {
for _, tsrc := range ts {
if tsrc.Type == "mask" {
m = append(m, tsrc)
}
}
}
slices.SortFunc(m, compareSimdTypes)
return m
}
func typesFromTypeMap(typeMap simdTypeMap) []simdType {
m := []simdType{}
for _, ts := range typeMap {
for _, tsrc := range ts {
if tsrc.Lanes > 1 {
m = append(m, tsrc)
}
}
}
slices.SortFunc(m, compareSimdTypes)
return m
}
// writeSIMDTypes generates the simd vector types into a bytes.Buffer
func writeSIMDTypes(typeMap simdTypeMap) *bytes.Buffer {
t := templateOf(simdTypesTemplates, "types_amd64")
loadStore := templateOf(simdLoadStoreTemplate, "loadstore_amd64")
maskedLoadStore := templateOf(simdMaskedLoadStoreTemplate, "maskedloadstore_amd64")
maskFromBits := templateOf(simdMaskFromBitsTemplate, "maskFromBits_amd64")
maskFromVal := templateOf(simdMaskFromValTemplate, "maskFromVal_amd64")
buffer := new(bytes.Buffer)
buffer.WriteString(simdPackageHeader)
sizes := make([]int, 0, len(typeMap))
for size, types := range typeMap {
slices.SortFunc(types, compareSimdTypes)
sizes = append(sizes, size)
}
sort.Ints(sizes)
for _, size := range sizes {
if size <= 64 {
// these are scalar
continue
}
if err := t.ExecuteTemplate(buffer, "sizeTmpl", size); err != nil {
panic(fmt.Errorf("failed to execute size template for size %d: %w", size, err))
}
for _, typeDef := range typeMap[size] {
if typeDef.Lanes == 1 {
continue
}
if err := t.ExecuteTemplate(buffer, "typeTmpl", typeDef); err != nil {
panic(fmt.Errorf("failed to execute type template for type %s: %w", typeDef.Name, err))
}
if typeDef.Type != "mask" {
if err := loadStore.ExecuteTemplate(buffer, "loadstore_amd64", typeDef); err != nil {
panic(fmt.Errorf("failed to execute loadstore template for type %s: %w", typeDef.Name, err))
}
// restrict to AVX2 masked loads/stores first.
if typeDef.MaskedLoadStoreFilter() {
if err := maskedLoadStore.ExecuteTemplate(buffer, "maskedloadstore_amd64", typeDef); err != nil {
panic(fmt.Errorf("failed to execute maskedloadstore template for type %s: %w", typeDef.Name, err))
}
}
} else {
if err := maskFromBits.ExecuteTemplate(buffer, "maskFromBits_amd64", typeDef); err != nil {
panic(fmt.Errorf("failed to execute maskFromBits template for type %s: %w", typeDef.Name, err))
}
if err := maskFromVal.ExecuteTemplate(buffer, "maskFromVal_amd64", typeDef); err != nil {
panic(fmt.Errorf("failed to execute maskFromVal template for type %s: %w", typeDef.Name, err))
}
}
}
}
return buffer
}
func writeSIMDFeatures(ops []Operation) *bytes.Buffer {
// Gather all features
type featureKey struct {
GoArch string
Feature string
}
featureSet := make(map[featureKey]struct{})
for _, op := range ops {
featureSet[featureKey{op.GoArch, op.CPUFeature}] = struct{}{}
}
features := slices.SortedFunc(maps.Keys(featureSet), func(a, b featureKey) int {
if c := cmp.Compare(a.GoArch, b.GoArch); c != 0 {
return c
}
return compareNatural(a.Feature, b.Feature)
})
// If we ever have the same feature name on more than one GOARCH, we'll have
// to be more careful about this.
t := templateOf(simdFeaturesTemplate, "features")
buffer := new(bytes.Buffer)
buffer.WriteString(simdPackageHeader)
if err := t.Execute(buffer, features); err != nil {
panic(fmt.Errorf("failed to execute features template: %w", err))
}
return buffer
}
// writeSIMDStubs generates the simd vector intrinsic stubs and writes it to ops_amd64.go and ops_internal_amd64.go
// within the specified directory.
func writeSIMDStubs(ops []Operation, typeMap simdTypeMap) *bytes.Buffer {
t := templateOf(simdStubsTmpl, "simdStubs")
buffer := new(bytes.Buffer)
buffer.WriteString(simdPackageHeader)
slices.SortFunc(ops, compareOperations)
for i, op := range ops {
if op.NoTypes != nil && *op.NoTypes == "true" {
continue
}
idxVecAsScalar, err := checkVecAsScalar(op)
if err != nil {
panic(err)
}
if s, op, err := classifyOp(op); err == nil {
if idxVecAsScalar != -1 {
if s == "op2" || s == "op3" {
s += "VecAsScalar"
} else {
panic(fmt.Errorf("simdgen only supports op2 or op3 with TreatLikeAScalarOfSize"))
}
}
if i == 0 || op.Go != ops[i-1].Go {
fmt.Fprintf(buffer, "\n/* %s */\n", op.Go)
}
if err := t.ExecuteTemplate(buffer, s, op); err != nil {
panic(fmt.Errorf("failed to execute template %s for op %v: %w", s, op, err))
}
} else {
panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
}
}
vectorConversions := vConvertFromTypeMap(typeMap)
for _, conv := range vectorConversions {
if err := t.ExecuteTemplate(buffer, "vectorConversion", conv); err != nil {
panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
}
}
masks := masksFromTypeMap(typeMap)
for _, mask := range masks {
if err := t.ExecuteTemplate(buffer, "mask", mask); err != nil {
panic(fmt.Errorf("failed to execute mask template for mask %s: %w", mask.Name, err))
}
}
return buffer
}

@@ -0,0 +1,211 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"slices"
"text/template"
)
type tplRuleData struct {
tplName string // e.g. "sftimm"
GoOp string // e.g. "ShiftAllLeft"
GoType string // e.g. "Uint32x8"
Args string // e.g. "x y"
Asm string // e.g. "VPSLLD256"
ArgsOut string // e.g. "x y"
MaskInConvert string // e.g. "VPMOVVec32x8ToM"
MaskOutConvert string // e.g. "VPMOVMToVec32x8"
}
var (
ruleTemplates = template.Must(template.New("simdRules").Parse(`
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
{{end}}
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
{{end}}
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
{{end}}
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
{{end}}
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
{{end}}
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
{{end}}
`))
)
// SSA rewrite rules need to appear in a most-to-least-specific order. This works for that.
var tmplOrder = map[string]int{
"masksftimm": 0,
"sftimm": 1,
"maskInMaskOut": 2,
"maskOut": 3,
"maskIn": 4,
"pureVreg": 5,
}
func compareTplRuleData(x, y tplRuleData) int {
if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
return c
}
if c := compareNatural(x.GoType, y.GoType); c != 0 {
return c
}
if c := compareNatural(x.Args, y.Args); c != 0 {
return c
}
if x.tplName == y.tplName {
return 0
}
xo, xok := tmplOrder[x.tplName]
yo, yok := tmplOrder[y.tplName]
if !xok {
panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
}
if !yok {
panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
}
return xo - yo
}
// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
// within the specified directory.
func writeSIMDRules(ops []Operation) *bytes.Buffer {
buffer := new(bytes.Buffer)
buffer.WriteString(generatedHeader + "\n")
var allData []tplRuleData
for _, opr := range ops {
if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" {
continue
}
opInShape, opOutShape, maskType, immType, gOp := opr.shape()
asm := machineOpName(maskType, gOp)
vregInCnt := len(gOp.In)
if maskType == OneMask {
vregInCnt--
}
data := tplRuleData{
GoOp: gOp.Go,
Asm: asm,
}
if vregInCnt == 1 {
data.Args = "x"
data.ArgsOut = data.Args
} else if vregInCnt == 2 {
data.Args = "x y"
data.ArgsOut = data.Args
} else if vregInCnt == 3 {
data.Args = "x y z"
data.ArgsOut = data.Args
} else {
panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
}
if immType == ConstImm {
data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
} else if immType == VarImm {
data.Args = fmt.Sprintf("[a] %s", data.Args)
data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
} else if immType == ConstVarImm {
data.Args = fmt.Sprintf("[a] %s", data.Args)
data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
}
goType := func(op Operation) string {
if op.OperandOrder != nil {
switch *op.OperandOrder {
case "21Type1", "231Type1":
// Permute uses operand[1] for method receiver.
return *op.In[1].Go
}
}
return *op.In[0].Go
}
var tplName string
// If class overwrite is happening, that's not really a mask but a vreg.
if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
switch opInShape {
case OneImmIn:
tplName = "pureVreg"
data.GoType = goType(gOp)
case PureVregIn:
tplName = "pureVreg"
data.GoType = goType(gOp)
case OneKmaskImmIn:
fallthrough
case OneKmaskIn:
tplName = "maskIn"
data.GoType = goType(gOp)
rearIdx := len(gOp.In) - 1
// Mask is at the end.
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
case PureKmaskIn:
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
}
} else if opOutShape == OneGregOut {
tplName = "pureVreg" // TODO this will be wrong
data.GoType = goType(gOp)
} else {
// OneKmaskOut case
data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
switch opInShape {
case OneImmIn:
fallthrough
case PureVregIn:
tplName = "maskOut"
data.GoType = goType(gOp)
case OneKmaskImmIn:
fallthrough
case OneKmaskIn:
tplName = "maskInMaskOut"
data.GoType = goType(gOp)
rearIdx := len(gOp.In) - 1
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
case PureKmaskIn:
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
}
}
if gOp.SpecialLower != nil {
if *gOp.SpecialLower == "sftimm" {
if data.GoType[0] == 'I' {
// only do these for signed types; it is a duplicate rewrite for unsigned
sftImmData := data
if tplName == "maskIn" {
sftImmData.tplName = "masksftimm"
} else {
sftImmData.tplName = "sftimm"
}
allData = append(allData, sftImmData)
}
} else {
panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
}
}
if tplName == "pureVreg" && data.Args == data.ArgsOut {
data.Args = "..."
data.ArgsOut = "..."
}
data.tplName = tplName
allData = append(allData, data)
}
slices.SortFunc(allData, compareTplRuleData)
for _, data := range allData {
if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
}
}
return buffer
}

@@ -0,0 +1,173 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"strings"
"text/template"
)
var (
ssaTemplates = template.Must(template.New("simdSSA").Parse(`
{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
package amd64
import (
"cmd/compile/internal/ssa"
"cmd/compile/internal/ssagen"
"cmd/internal/obj"
"cmd/internal/obj/x86"
)
func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
var p *obj.Prog
switch v.Op {{"{"}}{{end}}
{{define "case"}}
case {{.Cases}}:
p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}
default:
// Unknown reg shape
return false
}
{{end}}
{{define "zeroing"}}
// Masked operations are always compiled with zeroing.
switch v.Op {
case {{.}}:
x86.ParseSuffix(p, "Z")
}
{{end}}
{{define "ending"}}
return true
}
{{end}}`))
)
type tplSSAData struct {
Cases string
Helper string
}
// writeSIMDSSA generates the ssa to prog lowering codes and writes it to simdssa.go
// within the specified directory.
func writeSIMDSSA(ops []Operation) *bytes.Buffer {
var ZeroingMask []string
regInfoKeys := []string{
"v11",
"v21",
"v2k",
"v2kv",
"v2kk",
"vkv",
"v31",
"v3kv",
"v11Imm8",
"vkvImm8",
"v21Imm8",
"v2kImm8",
"v2kkImm8",
"v31ResultInArg0",
"v3kvResultInArg0",
"vfpv",
"vfpkv",
"vgpvImm8",
"vgpImm8",
"v2kvImm8",
}
regInfoSet := map[string][]string{}
for _, key := range regInfoKeys {
regInfoSet[key] = []string{}
}
seen := map[string]struct{}{}
allUnseen := make(map[string][]Operation)
for _, op := range ops {
shapeIn, shapeOut, maskType, _, gOp := op.shape()
asm := machineOpName(maskType, gOp)
if _, ok := seen[asm]; ok {
continue
}
seen[asm] = struct{}{}
caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
if gOp.Zeroing == nil {
ZeroingMask = append(ZeroingMask, caseStr)
}
}
regShape, err := op.regShape()
if err != nil {
panic(err)
}
if shapeOut == OneVregOutAtIn {
regShape += "ResultInArg0"
}
if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
regShape += "Imm8"
}
idx, err := checkVecAsScalar(op)
if err != nil {
panic(err)
}
if idx != -1 {
if regShape == "v21" {
regShape = "vfpv"
} else if regShape == "v2kv" {
regShape = "vfpkv"
} else {
panic(fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regShape, op))
}
}
if _, ok := regInfoSet[regShape]; !ok {
allUnseen[regShape] = append(allUnseen[regShape], op)
}
regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
}
if len(allUnseen) != 0 {
panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v", allUnseen))
}
buffer := new(bytes.Buffer)
if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
panic(fmt.Errorf("failed to execute header template: %w", err))
}
for _, regShape := range regInfoKeys {
// Stable traversal of regInfoSet
cases := regInfoSet[regShape]
if len(cases) == 0 {
continue
}
data := tplSSAData{
Cases: strings.Join(cases, ",\n\t\t"),
Helper: "simd" + capitalizeFirst(regShape),
}
if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
}
}
if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
if len(ZeroingMask) != 0 {
if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
}
if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
return buffer
}

@@ -0,0 +1,729 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bufio"
"bytes"
"fmt"
"go/format"
"log"
"os"
"path/filepath"
"reflect"
"slices"
"sort"
"strings"
"text/template"
"unicode"
)
func templateOf(temp, name string) *template.Template {
t, err := template.New(name).Parse(temp)
if err != nil {
panic(fmt.Errorf("failed to parse template %s: %w", name, err))
}
return t
}
func createPath(goroot string, file string) (*os.File, error) {
fp := filepath.Join(goroot, file)
dir := filepath.Dir(fp)
err := os.MkdirAll(dir, 0755)
if err != nil {
return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
}
f, err := os.Create(fp)
if err != nil {
return nil, fmt.Errorf("failed to create file %s: %w", fp, err)
}
return f, nil
}
func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
b, err := format.Source(out.Bytes())
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
fmt.Fprintf(os.Stderr, "%v\n", err)
panic(err)
} else {
writeAndClose(b, goroot, file)
}
}
func writeAndClose(b []byte, goroot string, file string) {
ofile, err := createPath(goroot, file)
if err != nil {
panic(err)
}
ofile.Write(b)
ofile.Close()
}
// numberLines takes a slice of bytes, and returns a string where each line
// is numbered, starting from 1.
func numberLines(data []byte) string {
var buf bytes.Buffer
r := bytes.NewReader(data)
s := bufio.NewScanner(r)
for i := 1; s.Scan(); i++ {
fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
}
return buf.String()
}
type inShape uint8
type outShape uint8
type maskShape uint8
type immShape uint8
const (
InvalidIn inShape = iota
PureVregIn // vector register input only
OneKmaskIn // vector and kmask input
OneImmIn // vector and immediate input
OneKmaskImmIn // vector, kmask, and immediate inputs
PureKmaskIn // only mask inputs.
)
const (
InvalidOut outShape = iota
NoOut // no output
OneVregOut // (one) vector register output
OneGregOut // (one) general register output
OneKmaskOut // mask output
OneVregOutAtIn // the first input is also the output
)
const (
InvalidMask maskShape = iota
NoMask // no mask
OneMask // with mask (K1 to K7)
AllMasks // a K mask instruction (K0-K7)
)
const (
InvalidImm immShape = iota
NoImm // no immediate
ConstImm // const only immediate
VarImm // pure imm argument provided by the users
ConstVarImm // a combination of user arg and const
)
// opShape returns the several integers describing the shape of the operation,
// and modified versions of the op:
//
// opNoImm is op with its inputs excluding the const imm.
//
// This function does not modify op.
func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
opNoImm Operation) {
if len(op.Out) > 1 {
panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
}
var outputReg int
if len(op.Out) == 1 {
outputReg = op.Out[0].AsmPos
if op.Out[0].Class == "vreg" {
shapeOut = OneVregOut
} else if op.Out[0].Class == "greg" {
shapeOut = OneGregOut
} else if op.Out[0].Class == "mask" {
shapeOut = OneKmaskOut
} else {
panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
}
} else {
shapeOut = NoOut
// TODO: are these only Load/Stores?
// We manually supported two Load and Store, are those enough?
panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
}
hasImm := false
maskCount := 0
hasVreg := false
for _, in := range op.In {
if in.AsmPos == outputReg {
if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
shapeOut = OneVregOutAtIn
} else {
panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
}
}
if in.Class == "immediate" {
// A manual check of the XED data found that AMD64 SIMD instructions have
// at most 1 immediate, so we don't need to check that here.
if *in.Bits != 8 {
panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
}
hasImm = true
} else if in.Class == "mask" {
maskCount++
} else {
hasVreg = true
}
}
opNoImm = *op
removeImm := func(o *Operation) {
o.In = o.In[1:]
}
if hasImm {
removeImm(&opNoImm)
if op.In[0].Const != nil {
if op.In[0].ImmOffset != nil {
immType = ConstVarImm
} else {
immType = ConstImm
}
} else if op.In[0].ImmOffset != nil {
immType = VarImm
} else {
panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
}
} else {
immType = NoImm
}
if maskCount == 0 {
maskType = NoMask
} else {
maskType = OneMask
}
checkPureMask := func() bool {
if hasImm {
panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
}
if hasVreg {
panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
}
return false
}
if !hasImm && maskCount == 0 {
shapeIn = PureVregIn
} else if !hasImm && maskCount > 0 {
if maskCount == 1 {
shapeIn = OneKmaskIn
} else {
if checkPureMask() {
return
}
shapeIn = PureKmaskIn
maskType = AllMasks
}
} else if hasImm && maskCount == 0 {
shapeIn = OneImmIn
} else {
if maskCount == 1 {
shapeIn = OneKmaskImmIn
} else {
checkPureMask()
return
}
}
return
}
// regShape returns a string representation of the register shape.
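// For example, an op with two vreg inputs, one mask input, and one vreg output
// yields "v2kv", while a plain two-vreg-in, one-vreg-out op yields "v21".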
func (op *Operation) regShape() (string, error) {
_, _, _, _, gOp := op.shape()
var regInfo string
var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt int
for _, in := range gOp.In {
if in.Class == "vreg" {
vRegInCnt++
} else if in.Class == "greg" {
gRegInCnt++
} else if in.Class == "mask" {
kMaskInCnt++
}
}
for _, out := range gOp.Out {
// If class overwrite is happening, that's not really a mask but a vreg.
if out.Class == "vreg" || out.OverwriteClass != nil {
vRegOutCnt++
} else if out.Class == "greg" {
gRegOutCnt++
} else if out.Class == "mask" {
kMaskOutCnt++
}
}
var inRegs, inMasks, outRegs, outMasks string
rmAbbrev := func(s string, i int) string {
if i == 0 {
return ""
}
if i == 1 {
return s
}
return fmt.Sprintf("%s%d", s, i)
}
inRegs = rmAbbrev("v", vRegInCnt)
inRegs += rmAbbrev("gp", gRegInCnt)
inMasks = rmAbbrev("k", kMaskInCnt)
outRegs = rmAbbrev("v", vRegOutCnt)
outRegs += rmAbbrev("gp", gRegOutCnt)
outMasks = rmAbbrev("k", kMaskOutCnt)
if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
// For pure v we can abbreviate it as v%d%d.
regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
} else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
} else {
regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
}
return regInfo, nil
}
// sortOperand sorts op.In by putting immediates first, then vregs, and masks last.
// TODO: verify that this is a safe assumption about the prog structure.
// From observation, in asm the immediates are always first and the masks are
// always last, with vregs in between.
func (op *Operation) sortOperand() {
priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
sort.SliceStable(op.In, func(i, j int) bool {
pi := priority[op.In[i].Class]
pj := priority[op.In[j].Class]
if pi != pj {
return pi < pj
}
return op.In[i].AsmPos < op.In[j].AsmPos
})
}
// goNormalType returns the Go type name for the result of an Op that
// does not return a vector, i.e., that returns a result in a general
// register. Currently there's only one family of Ops in Go's simd library
// that does this (GetElem), and so this is specialized to work for that,
// but the problem (mismatch between hardware register width and Go type
// width) seems likely to recur if there are any other cases.
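// For example, GetElem on an Int8x16 yields the Go type int8, even though the
// hardware result occupies a full 32- or 64-bit general register.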
func (op Operation) goNormalType() string {
if op.Go == "GetElem" {
// GetElem returns an element of the vector into a general register
// but as far as the hardware is concerned, that result is either 32
// or 64 bits wide, no matter what the vector element width is.
// This is not "wrong" but it is not the right answer for Go source code.
// To get the Go type right, combine the base type ("int", "uint", "float"),
// with the input vector element width in bits (8,16,32,64).
at := 0 // proper value of at depends on whether immediate was stripped or not
if op.In[at].Class == "immediate" {
at++
}
return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
}
panic(fmt.Errorf("Implement goNormalType for %v", op))
}
// SSAType returns the string for the type reference in SSA generation,
// for example in the intrinsics generating template.
func (op Operation) SSAType() string {
if op.Out[0].Class == "greg" {
return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
}
return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
}
// GoType returns the Go type returned by this operation (relative to the simd package),
// for example "int32" or "Int8x16". This is used in a template.
func (op Operation) GoType() string {
if op.Out[0].Class == "greg" {
return op.goNormalType()
}
return *op.Out[0].Go
}
// ImmName returns the name to use for an operation's immediate operand.
// This can be overridden in the yaml with "name" on an operand;
// otherwise, for now, it is "constant".
func (op Operation) ImmName() string {
return op.Op0Name("constant")
}
func (o Operand) OpName(s string) string {
if n := o.Name; n != nil {
return *n
}
if o.Class == "mask" {
return "mask"
}
return s
}
func (o Operand) OpNameAndType(s string) string {
return o.OpName(s) + " " + *o.Go
}
// GoExported returns [Go] with first character capitalized.
func (op Operation) GoExported() string {
return capitalizeFirst(op.Go)
}
// DocumentationExported returns [Documentation] with method name capitalized.
func (op Operation) DocumentationExported() string {
return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
}
// Op0Name returns the name to use for the 0 operand,
// if any is present, otherwise the parameter is used.
func (op Operation) Op0Name(s string) string {
return op.In[0].OpName(s)
}
// Op1Name returns the name to use for the 1 operand,
// if any is present, otherwise the parameter is used.
func (op Operation) Op1Name(s string) string {
return op.In[1].OpName(s)
}
// Op2Name returns the name to use for the 2 operand,
// if any is present, otherwise the parameter is used.
func (op Operation) Op2Name(s string) string {
return op.In[2].OpName(s)
}
// Op3Name returns the name to use for the 3 operand,
// if any is present, otherwise the parameter is used.
func (op Operation) Op3Name(s string) string {
return op.In[3].OpName(s)
}
// Op0NameAndType returns the name and type to use for
// the 0 operand, if a name is provided, otherwise
// the parameter value is used as the default.
func (op Operation) Op0NameAndType(s string) string {
return op.In[0].OpNameAndType(s)
}
// Op1NameAndType returns the name and type to use for
// the 1 operand, if a name is provided, otherwise
// the parameter value is used as the default.
func (op Operation) Op1NameAndType(s string) string {
return op.In[1].OpNameAndType(s)
}
// Op2NameAndType returns the name and type to use for
// the 2 operand, if a name is provided, otherwise
// the parameter value is used as the default.
func (op Operation) Op2NameAndType(s string) string {
return op.In[2].OpNameAndType(s)
}
// Op3NameAndType returns the name and type to use for
// the 3 operand, if a name is provided, otherwise
// the parameter value is used as the default.
func (op Operation) Op3NameAndType(s string) string {
return op.In[3].OpNameAndType(s)
}
// Op4NameAndType returns the name and type to use for
// the 4 operand, if a name is provided, otherwise
// the parameter value is used as the default.
func (op Operation) Op4NameAndType(s string) string {
return op.In[4].OpNameAndType(s)
}
var immClasses []string = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}
var classes []string = []string{"BAD0", "op1", "op2", "op3", "op4"}
// classifyOp returns a classification string, modified operation, and perhaps error based
// on the stub and intrinsic shape for the operation.
// The classification string is in the regular expression set "op[1234](Imm8)?(_<order>)?"
// where the "<order>" suffix is optionally attached to the Operation in its input yaml.
// The classification string is used to select a template or a clause of a template
// for the intrinsics declarations and the ssagen intrinsics glue code in the compiler.
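// For example (illustrative), an operation with two vector inputs plus an imm8
// (three In entries) classifies as "op2Imm8"; a plain two-input operation
// classifies as "op2"; an operandOrder of "21" turns that into "op2_21".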
func classifyOp(op Operation) (string, Operation, error) {
_, _, _, immType, gOp := op.shape()
var class string
if immType == VarImm || immType == ConstVarImm {
switch l := len(op.In); l {
case 1:
return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
case 2, 3, 4, 5:
class = immClasses[l]
default:
return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
}
if order := op.OperandOrder; order != nil {
class += "_" + *order
}
return class, op, nil
} else {
switch l := len(gOp.In); l {
case 1, 2, 3, 4:
class = classes[l]
default:
return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
}
if order := op.OperandOrder; order != nil {
class += "_" + *order
}
return class, gOp, nil
}
}
func checkVecAsScalar(op Operation) (idx int, err error) {
idx = -1
sSize := 0
for i, o := range op.In {
if o.TreatLikeAScalarOfSize != nil {
if idx == -1 {
idx = i
sSize = *o.TreatLikeAScalarOfSize
} else {
err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
return
}
}
}
if idx >= 0 {
if idx != 1 {
err = fmt.Errorf("simdgen only supports TreatLikeAScalarOfSize at the 2nd arg of the arg list: %s", op)
return
}
if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
return
}
}
return
}
// dedup removes duplicate operations, comparing them at the full-structure level.
func dedup(ops []Operation) (deduped []Operation) {
for _, op := range ops {
seen := false
for _, dop := range deduped {
if reflect.DeepEqual(op, dop) {
seen = true
break
}
}
if !seen {
deduped = append(deduped, op)
}
}
return
}
func (op Operation) GenericName() string {
if op.OperandOrder != nil {
switch *op.OperandOrder {
case "21Type1", "231Type1":
// Permute uses operand[1] for method receiver.
return op.Go + *op.In[1].Go
}
}
if op.In[0].Class == "immediate" {
return op.Go + *op.In[1].Go
}
return op.Go + *op.In[0].Go
}
// dedupGodef deduplicates operations at the [Op.Go]+[*Op.In[0].Go] level.
// Deduplication picks the least advanced architecture that satisfies the requirement;
// AVX512 is least preferred.
// If FlagReportDup is set, it instead reports the duplicates to the console.
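// For example (illustrative), if Add on Int32x8 unifies with both an AVX2
// VPADDD def and an AVX512 def, the non-AVX512 candidate sorts first and is kept.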
func dedupGodef(ops []Operation) ([]Operation, error) {
seen := map[string][]Operation{}
for _, op := range ops {
_, _, _, _, gOp := op.shape()
gN := gOp.GenericName()
seen[gN] = append(seen[gN], op)
}
if *FlagReportDup {
for gName, dup := range seen {
if len(dup) > 1 {
log.Printf("Duplicate for %s:\n", gName)
for _, op := range dup {
log.Printf("%s\n", op)
}
}
}
return ops, nil
}
isAVX512 := func(op Operation) bool {
return strings.Contains(op.CPUFeature, "AVX512")
}
deduped := []Operation{}
for _, dup := range seen {
if len(dup) > 1 {
slices.SortFunc(dup, func(i, j Operation) int {
// Put non-AVX512 candidates at the beginning
if !isAVX512(i) && isAVX512(j) {
return -1
}
if isAVX512(i) && !isAVX512(j) {
return 1
}
return strings.Compare(i.CPUFeature, j.CPUFeature)
})
}
deduped = append(deduped, dup[0])
}
slices.SortFunc(deduped, compareOperations)
return deduped, nil
}
// copyConstImm copies op.ConstImm to op.In[0].Const.
// This is a hack to reduce the size of the defs we need for const-immediate operations.
func copyConstImm(ops []Operation) error {
for _, op := range ops {
if op.ConstImm == nil {
continue
}
_, _, _, immType, _ := op.shape()
if immType == ConstImm || immType == ConstVarImm {
op.In[0].Const = op.ConstImm
}
// Otherwise, just don't port it - e.g. {VPCMP[BWDQ] imm=0} and {VPCMPEQ[BWDQ]} are
// the same operation "Equal"; [dedupGodef] should be able to distinguish them.
}
return nil
}
func capitalizeFirst(s string) string {
if s == "" {
return ""
}
// Convert the string to a slice of runes to handle multi-byte characters correctly.
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}
// overwrite corrects some errors due to:
// - the XED data being wrong, or
// - Go's SIMD API requirements, for example AVX2 compares should also produce masks.
// This rewrite has strict constraints; please see the error messages.
// These constraints are also exploited in [writeSIMDRules], [writeSIMDMachineOps]
// and [writeSIMDSSA], so please be careful when updating them.
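// For example (illustrative), AVX2 compares such as VPCMPGT[BWDQ] report a vreg
// result in the XED data; setting overwriteClass: mask and overwriteBase: int on
// the output rewrites its Go type to the corresponding MaskNxM type.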
func overwrite(ops []Operation) error {
hasClassOverwrite := false
overwrite := func(op []Operand, idx int, o Operation) error {
if op[idx].OverwriteElementBits != nil {
if op[idx].ElemBits == nil {
panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
}
*op[idx].ElemBits = *op[idx].OverwriteElementBits
*op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
*op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
}
if op[idx].OverwriteClass != nil {
if op[idx].OverwriteBase == nil {
panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
}
oBase := *op[idx].OverwriteBase
oClass := *op[idx].OverwriteClass
if oClass != "mask" {
panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
}
if oBase != "int" {
panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
}
if op[idx].Class != "vreg" {
panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
}
hasClassOverwrite = true
*op[idx].Base = oBase
op[idx].Class = oClass
*op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
} else if op[idx].OverwriteBase != nil {
oBase := *op[idx].OverwriteBase
*op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
if op[idx].Class == "greg" {
*op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
}
*op[idx].Base = oBase
}
return nil
}
for i, o := range ops {
hasClassOverwrite = false
for j := range ops[i].In {
if err := overwrite(ops[i].In, j, o); err != nil {
return err
}
if hasClassOverwrite {
return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
}
}
for j := range ops[i].Out {
if err := overwrite(ops[i].Out, j, o); err != nil {
return err
}
}
if hasClassOverwrite {
for _, in := range ops[i].In {
if in.Class == "mask" {
return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
}
}
}
}
return nil
}
// reportXEDInconsistency reports potential XED inconsistencies.
// We can add more fields to [Operation] to enable more checks and implement them here.
// Supported checks:
// [NameAndSizeCheck]: NAME[BWDQ] should set the elemBits accordingly.
// This check is useful for finding inconsistencies; we can then add overwrite fields to
// those defs to correct them manually.
func reportXEDInconsistency(ops []Operation) error {
for _, o := range ops {
if o.NameAndSizeCheck != nil {
suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
checkOperand := func(opr Operand) error {
if opr.ElemBits == nil {
return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
}
if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
} else {
if v != *opr.ElemBits {
return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
}
}
return nil
}
for _, in := range o.In {
if in.Class != "vreg" && in.Class != "mask" {
continue
}
if in.TreatLikeAScalarOfSize != nil {
// This is an irregular operand, don't check it.
continue
}
if err := checkOperand(in); err != nil {
return err
}
}
for _, out := range o.Out {
if err := checkOperand(out); err != nil {
return err
}
}
}
}
return nil
}
func (o Operation) String() string {
return pprints(o)
}
func (op Operand) String() string {
return pprints(op)
}

View file

@ -0,0 +1 @@
!import ops/*/go.yaml

View file

@ -0,0 +1,379 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"regexp"
"slices"
"strconv"
"strings"
"simd/_gen/unify"
)
type Operation struct {
rawOperation
// Go is the Go method name of this operation.
//
// It is derived from the raw Go method name by adding optional suffixes.
// Currently, "Masked" is the only suffix.
Go string
// Documentation is the doc string for this API.
//
// It is computed from the raw documentation:
//
// - "NAME" is replaced by the Go method name.
//
// - For masked operations, a sentence about masking is added.
Documentation string
// In is the sequence of parameters to the Go method.
//
// For masked operations, this will have the mask operand appended.
In []Operand
}
// rawOperation is the unifier representation of an [Operation]. It is
// translated into a more parsed form after unifier decoding.
type rawOperation struct {
Go string // Base Go method name
GoArch string // GOARCH for this definition
Asm string // Assembly mnemonic
OperandOrder *string // optional Operand order for better Go declarations
// Optional tag to indicate this operation is paired with special generic->machine ssa lowering rules.
// Should be paired with special templates in gen_simdrules.go
SpecialLower *string
In []Operand // Parameters
InVariant []Operand // Optional parameters
Out []Operand // Results
Commutative bool // Commutativity
CPUFeature string // CPUID/Has* feature name
Zeroing *bool // nil => use asm suffix ".Z"; false => do not use asm suffix ".Z"
Documentation *string // Documentation will be appended to the stubs comments.
// ConstImm is a hack to reduce the size of the defs the user writes for const-immediate
// operations. If present, it will be copied to [In[0].Const].
ConstImm *string
// NameAndSizeCheck is used to check [BWDQ] maps to (8|16|32|64) elemBits.
NameAndSizeCheck *bool
// If non-nil, all generation in gen_simdTypes.go and gen_intrinsics will be skipped.
NoTypes *string
// If non-nil, all generation in gen_simdGenericOps and gen_simdrules will be skipped.
NoGenericOps *string
// If non-nil, this string will be attached to the machine ssa op name.
SSAVariant *string
}
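// DecodeUnified decodes a unifier value into o and derives the parsed fields.
// For example (illustrative), a def whose inVariant is a single mask operand
// gets "Masked" appended to its Go name, a masking sentence appended to its
// documentation, and the mask operand appended to In.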
func (o *Operation) DecodeUnified(v *unify.Value) error {
if err := v.Decode(&o.rawOperation); err != nil {
return err
}
isMasked := false
if len(o.InVariant) == 0 {
// No variant
} else if len(o.InVariant) == 1 && o.InVariant[0].Class == "mask" {
isMasked = true
} else {
return fmt.Errorf("unknown inVariant")
}
// Compute full Go method name.
o.Go = o.rawOperation.Go
if isMasked {
o.Go += "Masked"
}
// Compute doc string.
if o.rawOperation.Documentation != nil {
o.Documentation = *o.rawOperation.Documentation
} else {
o.Documentation = "// UNDOCUMENTED"
}
o.Documentation = regexp.MustCompile(`\bNAME\b`).ReplaceAllString(o.Documentation, o.Go)
if isMasked {
o.Documentation += "\n//\n// This operation is applied selectively under a write mask."
}
o.In = append(o.rawOperation.In, o.rawOperation.InVariant...)
return nil
}
func (o *Operation) VectorWidth() int {
out := o.Out[0]
if out.Class == "vreg" {
return *out.Bits
} else if out.Class == "greg" || out.Class == "mask" {
for i := range o.In {
if o.In[i].Class == "vreg" {
return *o.In[i].Bits
}
}
}
panic(fmt.Errorf("Figure out what the vector width is for %v and implement it", *o))
}
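// machineOpName returns the machine ssa op name for gOp. For example
// (illustrative), a masked 512-bit VPADDD becomes "VPADDDMasked512", with any
// SSAVariant suffix appended at the end.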
func machineOpName(maskType maskShape, gOp Operation) string {
asm := gOp.Asm
if maskType == 2 {
asm += "Masked"
}
asm = fmt.Sprintf("%s%d", asm, gOp.VectorWidth())
if gOp.SSAVariant != nil {
asm += *gOp.SSAVariant
}
return asm
}
func compareStringPointers(x, y *string) int {
if x != nil && y != nil {
return compareNatural(*x, *y)
}
if x == nil && y == nil {
return 0
}
if x == nil {
return -1
}
return 1
}
func compareIntPointers(x, y *int) int {
if x != nil && y != nil {
return *x - *y
}
if x == nil && y == nil {
return 0
}
if x == nil {
return -1
}
return 1
}
func compareOperations(x, y Operation) int {
if c := compareNatural(x.Go, y.Go); c != 0 {
return c
}
xIn, yIn := x.In, y.In
if len(xIn) > len(yIn) && xIn[len(xIn)-1].Class == "mask" {
xIn = xIn[:len(xIn)-1]
} else if len(xIn) < len(yIn) && yIn[len(yIn)-1].Class == "mask" {
yIn = yIn[:len(yIn)-1]
}
if len(xIn) < len(yIn) {
return -1
}
if len(xIn) > len(yIn) {
return 1
}
if len(x.Out) < len(y.Out) {
return -1
}
if len(x.Out) > len(y.Out) {
return 1
}
for i := range xIn {
ox, oy := &xIn[i], &yIn[i]
if c := compareOperands(ox, oy); c != 0 {
return c
}
}
return 0
}
func compareOperands(x, y *Operand) int {
if c := compareNatural(x.Class, y.Class); c != 0 {
return c
}
if x.Class == "immediate" {
return compareStringPointers(x.ImmOffset, y.ImmOffset)
} else {
if c := compareStringPointers(x.Base, y.Base); c != 0 {
return c
}
if c := compareIntPointers(x.ElemBits, y.ElemBits); c != 0 {
return c
}
if c := compareIntPointers(x.Bits, y.Bits); c != 0 {
return c
}
return 0
}
}
type Operand struct {
Class string // One of "mask", "immediate", "vreg", "greg", and "mem"
Go *string // Go type of this operand
AsmPos int // Position of this operand in the assembly instruction
Base *string // Base Go type ("int", "uint", "float")
ElemBits *int // Element bit width
Bits *int // Total vector bit width
Const *string // Optional constant value for immediates.
// Optional immediate arg offset. If this field is non-nil,
// this operand will be an immediate operand:
// The compiler will right-shift the user-passed value by ImmOffset and set it as the AuxInt
// field of the operation.
ImmOffset *string
Name *string // optional name in the Go intrinsic declaration
Lanes *int // *Lanes equals Bits/ElemBits except for scalars, when *Lanes == 1
// TreatLikeAScalarOfSize means only the lower $TreatLikeAScalarOfSize bits of the vector
// is used, so at the API level we can make it just a scalar value of this size; Then we
// can overwrite it to a vector of the right size during intrinsics stage.
TreatLikeAScalarOfSize *int
// If non-nil, it means the [Class] field is overwritten here, right now this is used to
// overwrite the results of AVX2 compares to masks.
OverwriteClass *string
// If non-nil, it means the [Base] field is overwritten here. This field exists solely
// because Intel's XED data is inconsistent, e.g. VANDNP[SD] marks its operands as int.
OverwriteBase *string
// If non-nil, it means the [ElemBits] field is overwritten. This field exists solely
// because Intel's XED data is inconsistent, e.g. AVX512 VPMADDUBSW marks its operands'
// elemBits as 16, when it should be 8.
OverwriteElementBits *int
}
// isDigit returns true if the byte is an ASCII digit.
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
// compareNatural performs a "natural sort" comparison of two strings.
// It compares non-digit sections lexicographically and digit sections
// numerically. For strings that compare equal naturally but are not
// byte-identical, like "a01b" and "a1b", strings.Compare breaks the tie.
//
// It returns:
//
// -1 if s1 < s2
// 0 if s1 == s2
// +1 if s1 > s2
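//
// For example, compareNatural("Int8x16", "Int8x2") returns +1 because 16 > 2,
// whereas plain lexicographic comparison would order "Int8x16" first;
// compareNatural("a01b", "a1b") returns -1 via the strings.Compare tie-break.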
func compareNatural(s1, s2 string) int {
i, j := 0, 0
len1, len2 := len(s1), len(s2)
for i < len1 && j < len2 {
// Find a non-digit segment or a number segment in both strings.
if isDigit(s1[i]) && isDigit(s2[j]) {
// Number segment comparison.
numStart1 := i
for i < len1 && isDigit(s1[i]) {
i++
}
num1, _ := strconv.Atoi(s1[numStart1:i])
numStart2 := j
for j < len2 && isDigit(s2[j]) {
j++
}
num2, _ := strconv.Atoi(s2[numStart2:j])
if num1 < num2 {
return -1
}
if num1 > num2 {
return 1
}
// If numbers are equal, continue to the next segment.
} else {
// Non-digit comparison.
if s1[i] < s2[j] {
return -1
}
if s1[i] > s2[j] {
return 1
}
i++
j++
}
}
// deal with a01b vs a1b; there needs to be an order.
return strings.Compare(s1, s2)
}
const generatedHeader = `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
`
func writeGoDefs(path string, cl unify.Closure) error {
// TODO: Merge operations with the same signature but multiple
// implementations (e.g., SSE vs AVX)
var ops []Operation
for def := range cl.All() {
var op Operation
if !def.Exact() {
continue
}
if err := def.Decode(&op); err != nil {
log.Println(err.Error())
log.Println(def)
continue
}
// TODO: verify that this is safe.
op.sortOperand()
ops = append(ops, op)
}
slices.SortFunc(ops, compareOperations)
// The parsed XED data might contain duplicates, like
// 512-bit VPADDP.
deduped := dedup(ops)
slices.SortFunc(deduped, compareOperations)
if *Verbose {
log.Printf("dedup len: %d\n", len(ops))
}
var err error
if err = overwrite(deduped); err != nil {
return err
}
if *Verbose {
log.Printf("dedup len: %d\n", len(deduped))
}
if !*FlagNoDedup {
// TODO: This can hide mistakes in the API definitions, especially when
// multiple patterns result in the same API unintentionally. Make it stricter.
if deduped, err = dedupGodef(deduped); err != nil {
return err
}
}
if *Verbose {
log.Printf("dedup len: %d\n", len(deduped))
}
if !*FlagNoConstImmPorting {
if err = copyConstImm(deduped); err != nil {
return err
}
}
if *Verbose {
log.Printf("dedup len: %d\n", len(deduped))
}
reportXEDInconsistency(deduped)
typeMap := parseSIMDTypes(deduped)
formatWriteAndClose(writeSIMDTypes(typeMap), path, "src/"+simdPackage+"/types_amd64.go")
formatWriteAndClose(writeSIMDFeatures(deduped), path, "src/"+simdPackage+"/cpu.go")
formatWriteAndClose(writeSIMDStubs(deduped, typeMap), path, "src/"+simdPackage+"/ops_amd64.go")
formatWriteAndClose(writeSIMDIntrinsics(deduped, typeMap), path, "src/cmd/compile/internal/ssagen/simdintrinsics.go")
formatWriteAndClose(writeSIMDGenericOps(deduped), path, "src/cmd/compile/internal/ssa/_gen/simdgenericOps.go")
formatWriteAndClose(writeSIMDMachineOps(deduped), path, "src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go")
formatWriteAndClose(writeSIMDSSA(deduped), path, "src/cmd/compile/internal/amd64/simdssa.go")
writeAndClose(writeSIMDRules(deduped).Bytes(), path, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules")
return nil
}

View file

@ -0,0 +1,280 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// simdgen is an experiment in generating Go <-> asm SIMD mappings.
//
// Usage: simdgen [-xedPath=path] [-q=query] input.yaml...
//
// If -xedPath is provided, one of the inputs is a sum of op-code definitions
// generated from the Intel XED data at path.
//
// If input YAML files are provided, each file is read as an input value. See
// [unify.Closure.UnmarshalYAML] or "go doc unify.Closure.UnmarshalYAML" for the
// format of these files.
//
// TODO: Example definitions and values.
//
// The command unifies across all of the inputs and prints all possible results
// of this unification.
//
// If the -q flag is provided, its string value is parsed as a value and treated
// as another input to unification. This is intended as a way to "query" the
// result, typically by narrowing it down to a small subset of results.
//
// Typical usage:
//
// go run . -xedPath $XEDPATH *.yaml
//
// To see just the definitions generated from XED, run:
//
// go run . -xedPath $XEDPATH
//
// (This works because if there's only one input, there's nothing to unify it
// with, so the result is simply itself.)
//
// To see just the definitions for VPADDQ:
//
// go run . -xedPath $XEDPATH -q '{asm: VPADDQ}'
//
// simdgen can also generate Go definitions of SIMD mappings:
// To generate go files to the go root, run:
//
// go run . -xedPath $XEDPATH -o godefs -goroot $PATH/TO/go go.yaml categories.yaml types.yaml
//
// types.yaml is already written; it specifies the shapes of vectors.
// categories.yaml and go.yaml contain definitions that unify with types.yaml and the XED
// data; you can find an example in ops/AddSub/.
//
// When generating Go definitions, simdgen performs three pieces of "magic":
// - It splits masked operations (those with the op's [Masked] field set) into const and non-const forms:
// - One is the original, a normal masked operation.
// - The other has its mask operand's [Const] field set to "K0".
// - This way the user does not need to provide a separate "K0"-masked operation def.
//
// - It deduplicates intrinsic names that have duplicates:
// - If two operations share the same signature and one requires AVX512 while the other
// predates AVX512, the pre-AVX512 one is selected.
// - This happens often when operations are defined both before AVX512 and after.
// This way the user does not need to provide a separate "K0" operation for the
// AVX512 counterpart.
//
// - It copies the op's [ConstImm] field to its immediate operand's [Const] field.
// - This way the user does not need to provide verbose op definitions that differ only
// in the const immediate field. This is useful to reduce the verbosity of
// compares with immediate control predicates.
//
// These three magics can be disabled with the -nosplitmask, -nodedup and
// -noconstimmporting flags, respectively.
//
// simdgen currently only supports amd64; -arch=$OTHERARCH will trigger a fatal error.
package main
// Big TODOs:
//
// - This can produce duplicates, which can also lead to less efficient
// environment merging. Add hashing and use it for deduplication. Be careful
// about how this shows up in debug traces, since it could make things
// confusing if we don't show it happening.
//
// - Do I need Closure, Value, and Domain? It feels like I should only need two
// types.
import (
"cmp"
"flag"
"fmt"
"log"
"maps"
"os"
"path/filepath"
"runtime/pprof"
"slices"
"strings"
"gopkg.in/yaml.v3"
"simd/_gen/unify"
)
var (
xedPath = flag.String("xedPath", "", "load XED datafiles from `path`")
flagQ = flag.String("q", "", "query: read `def` as another input (skips final validation)")
flagO = flag.String("o", "yaml", "output type: yaml, godefs (generate definitions into a Go source tree")
flagGoDefRoot = flag.String("goroot", ".", "the path to the Go dev directory that will receive the generated files")
FlagNoDedup = flag.Bool("nodedup", false, "disable deduplicating godefs of 2 qualifying operations from different extensions")
FlagNoConstImmPorting = flag.Bool("noconstimmporting", false, "disable const immediate porting from op to imm operand")
FlagArch = flag.String("arch", "amd64", "the target architecture")
Verbose = flag.Bool("v", false, "verbose")
flagDebugXED = flag.Bool("debug-xed", false, "show XED instructions")
flagDebugUnify = flag.Bool("debug-unify", false, "print unification trace")
flagDebugHTML = flag.String("debug-html", "", "write unification trace to `file.html`")
FlagReportDup = flag.Bool("reportdup", false, "report the duplicate godefs")
flagCPUProfile = flag.String("cpuprofile", "", "write CPU profile to `file`")
flagMemProfile = flag.String("memprofile", "", "write memory profile to `file`")
)
const simdPackage = "simd"
func main() {
flag.Parse()
if *flagCPUProfile != "" {
f, err := os.Create(*flagCPUProfile)
if err != nil {
log.Fatalf("-cpuprofile: %s", err)
}
defer f.Close()
pprof.StartCPUProfile(f)
defer pprof.StopCPUProfile()
}
if *flagMemProfile != "" {
f, err := os.Create(*flagMemProfile)
if err != nil {
log.Fatalf("-memprofile: %s", err)
}
defer func() {
pprof.WriteHeapProfile(f)
f.Close()
}()
}
var inputs []unify.Closure
if *FlagArch != "amd64" {
log.Fatalf("simdgen only supports amd64")
}
// Load XED into a defs set.
if *xedPath != "" {
xedDefs := loadXED(*xedPath)
inputs = append(inputs, unify.NewSum(xedDefs...))
}
// Load query.
if *flagQ != "" {
r := strings.NewReader(*flagQ)
def, err := unify.Read(r, "<query>", unify.ReadOpts{})
if err != nil {
log.Fatalf("parsing -q: %s", err)
}
inputs = append(inputs, def)
}
// Load defs files.
must := make(map[*unify.Value]struct{})
for _, path := range flag.Args() {
defs, err := unify.ReadFile(path, unify.ReadOpts{})
if err != nil {
log.Fatal(err)
}
inputs = append(inputs, defs)
if filepath.Base(path) == "go.yaml" {
// These must all be used in the final result
for def := range defs.Summands() {
must[def] = struct{}{}
}
}
}
// Prepare for unification
if *flagDebugUnify {
unify.Debug.UnifyLog = os.Stderr
}
if *flagDebugHTML != "" {
f, err := os.Create(*flagDebugHTML)
if err != nil {
log.Fatal(err)
}
unify.Debug.HTML = f
defer f.Close()
}
// Unify!
unified, err := unify.Unify(inputs...)
if err != nil {
log.Fatal(err)
}
// Print results.
switch *flagO {
case "yaml":
// Produce a result that looks like encoding a slice, but stream it.
fmt.Println("!sum")
var val1 [1]*unify.Value
for val := range unified.All() {
val1[0] = val
// We have to make a new encoder each time or it'll print a document
// separator between each object.
enc := yaml.NewEncoder(os.Stdout)
if err := enc.Encode(val1); err != nil {
log.Fatal(err)
}
enc.Close()
}
case "godefs":
if err := writeGoDefs(*flagGoDefRoot, unified); err != nil {
log.Fatalf("Failed writing godefs: %+v", err)
}
}
if !*Verbose && *xedPath != "" {
if operandRemarks == 0 {
fmt.Fprintf(os.Stderr, "XED decoding generated no errors, which is unusual.\n")
} else {
fmt.Fprintf(os.Stderr, "XED decoding generated %d \"errors\" which is not cause for alarm, use -v for details.\n", operandRemarks)
}
}
// Validate results.
//
// Don't validate if this is a command-line query because that tends to
// eliminate lots of required defs and is used in cases where maybe defs
// aren't enumerable anyway.
if *flagQ == "" && len(must) > 0 {
validate(unified, must)
}
}
func validate(cl unify.Closure, required map[*unify.Value]struct{}) {
// Validate that:
// 1. All final defs are exact
// 2. All required defs are used
for def := range cl.All() {
if _, ok := def.Domain.(unify.Def); !ok {
fmt.Fprintf(os.Stderr, "%s: expected Def, got %T\n", def.PosString(), def.Domain)
continue
}
if !def.Exact() {
fmt.Fprintf(os.Stderr, "%s: def not reduced to an exact value, why is %s:\n", def.PosString(), def.WhyNotExact())
fmt.Fprintf(os.Stderr, "\t%s\n", strings.ReplaceAll(def.String(), "\n", "\n\t"))
}
for root := range def.Provenance() {
delete(required, root)
}
}
// Report unused defs
unused := slices.SortedFunc(maps.Keys(required),
func(a, b *unify.Value) int {
return cmp.Or(
cmp.Compare(a.Pos().Path, b.Pos().Path),
cmp.Compare(a.Pos().Line, b.Pos().Line),
)
})
for _, def := range unused {
// TODO: Can we say anything more actionable? This is always a problem
// with unification: if it fails, it's very hard to point a finger at
// any particular reason. We could go back and try unifying this again
// with each subset of the inputs (starting with individual inputs) to
// at least say "it doesn't unify with anything in x.yaml". That's a lot
// of work, but if we have trouble debugging unification failure it may
// be worth it.
fmt.Fprintf(os.Stderr, "%s: def required, but did not unify (%v)\n",
def.PosString(), def)
}
}

View file

@ -0,0 +1,37 @@
!sum
- go: Add
commutative: true
documentation: !string |-
// NAME adds corresponding elements of two vectors.
- go: AddSaturated
commutative: true
documentation: !string |-
// NAME adds corresponding elements of two vectors with saturation.
- go: Sub
commutative: false
documentation: !string |-
// NAME subtracts corresponding elements of two vectors.
- go: SubSaturated
commutative: false
documentation: !string |-
// NAME subtracts corresponding elements of two vectors with saturation.
- go: AddPairs
commutative: false
documentation: !string |-
// NAME horizontally adds adjacent pairs of elements.
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0+y1, y2+y3, ..., x0+x1, x2+x3, ...].
- go: SubPairs
commutative: false
documentation: !string |-
// NAME horizontally subtracts adjacent pairs of elements.
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0-y1, y2-y3, ..., x0-x1, x2-x3, ...].
- go: AddPairsSaturated
commutative: false
documentation: !string |-
// NAME horizontally adds adjacent pairs of elements with saturation.
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0+y1, y2+y3, ..., x0+x1, x2+x3, ...].
- go: SubPairsSaturated
commutative: false
documentation: !string |-
// NAME horizontally subtracts adjacent pairs of elements with saturation.
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0-y1, y2-y3, ..., x0-x1, x2-x3, ...].

View file

@ -0,0 +1,77 @@
!sum
# Add
- go: Add
asm: "VPADD[BWDQ]|VADDP[SD]"
in:
- &any
go: $t
- *any
out:
- *any
# Add Saturated
- go: AddSaturated
asm: "VPADDS[BWDQ]"
in:
- &int
go: $t
base: int
- *int
out:
- *int
- go: AddSaturated
asm: "VPADDUS[BWDQ]"
in:
- &uint
go: $t
base: uint
- *uint
out:
- *uint
# Sub
- go: Sub
asm: "VPSUB[BWDQ]|VSUBP[SD]"
in: &2any
- *any
- *any
out: &1any
- *any
# Sub Saturated
- go: SubSaturated
asm: "VPSUBS[BWDQ]"
in: &2int
- *int
- *int
out: &1int
- *int
- go: SubSaturated
asm: "VPSUBUS[BWDQ]"
in:
- *uint
- *uint
out:
- *uint
- go: AddPairs
asm: "VPHADD[DW]"
in: *2any
out: *1any
- go: SubPairs
asm: "VPHSUB[DW]"
in: *2any
out: *1any
- go: AddPairs
asm: "VHADDP[SD]" # floats
in: *2any
out: *1any
- go: SubPairs
asm: "VHSUBP[SD]" # floats
in: *2any
out: *1any
- go: AddPairsSaturated
asm: "VPHADDS[DW]"
in: *2int
out: *1int
- go: SubPairsSaturated
asm: "VPHSUBS[DW]"
in: *2int
out: *1int

View file

@ -0,0 +1,20 @@
!sum
- go: And
commutative: true
documentation: !string |-
// NAME performs a bitwise AND operation between two vectors.
- go: Or
commutative: true
documentation: !string |-
// NAME performs a bitwise OR operation between two vectors.
- go: AndNot
commutative: false
documentation: !string |-
// NAME performs a bitwise x &^ y.
- go: Xor
commutative: true
documentation: !string |-
// NAME performs a bitwise XOR operation between two vectors.
# We also have PTEST and VPTERNLOG, those should be hidden from the users
# and only appear in rewrite rules.

View file

@ -0,0 +1,128 @@
!sum
# In the XED data, *all* floating point bitwise logic operations have their
# operand types marked as uint. We are not trying to understand why Intel
# decided that they want FP bit-wise logic operations, but this irregularity
# has to be dealt with in separate rules with some overwrites.
# For many bit-wise operations, we have the following non-orthogonal
# choices:
#
# - Non-masked AVX operations have no element width (because it
# doesn't matter), but only cover 128 and 256 bit vectors.
#
# - Masked AVX-512 operations have an element width (because it needs
# to know how to interpret the mask), and cover 128, 256, and 512 bit
# vectors. These only cover 32- and 64-bit element widths.
#
# - Non-masked AVX-512 operations still have an element width (because
# they're just the masked operations with an implicit K0 mask) but it
# doesn't matter! This is the only option for non-masked 512 bit
# operations, and we can pick any of the element widths.
#
# We unify with ALL of these operations and the compiler generator
# picks when there are multiple options.
# TODO: We don't currently generate unmasked bit-wise operations on 512 bit
# vectors of 8- or 16-bit elements. AVX-512 only has *masked* bit-wise
# operations for 32- and 64-bit elements; while the element width doesn't matter
# for unmasked operations, right now we don't realize that we can just use the
# 32- or 64-bit version for the unmasked form. Maybe in the XED decoder we
# should recognize bit-wise operations when generating unmasked versions and
# omit the element width.
# For binary operations, we constrain their two inputs and one output to the
# same Go type using a variable.
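# For example, in the And def below, `go: $t` on both inputs and the output
# means unification only succeeds when all three resolve to the same Go vector
# type (e.g. Int32x8), yielding one concrete def per element type and width.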
- go: And
asm: "VPAND[DQ]?"
in:
- &any
go: $t
- *any
out:
- *any
- go: And
asm: "VPANDD" # Fill in the gap, And is missing for Uint8x64 and Int8x64
inVariant: []
in: &twoI8x64
- &i8x64
go: $t
overwriteElementBits: 8
- *i8x64
out: &oneI8x64
- *i8x64
- go: And
asm: "VPANDD" # Fill in the gap, And is missing for Uint16x32 and Int16x32
inVariant: []
in: &twoI16x32
- &i16x32
go: $t
overwriteElementBits: 16
- *i16x32
out: &oneI16x32
- *i16x32
- go: AndNot
asm: "VPANDN[DQ]?"
operandOrder: "21" # switch the arg order
in:
- *any
- *any
out:
- *any
- go: AndNot
asm: "VPANDND" # Fill in the gap, AndNot is missing for Uint8x64 and Int8x64
operandOrder: "21" # switch the arg order
inVariant: []
in: *twoI8x64
out: *oneI8x64
- go: AndNot
asm: "VPANDND" # Fill in the gap, AndNot is missing for Uint16x32 and Int16x32
operandOrder: "21" # switch the arg order
inVariant: []
in: *twoI16x32
out: *oneI16x32
- go: Or
asm: "VPOR[DQ]?"
in:
- *any
- *any
out:
- *any
- go: Or
asm: "VPORD" # Fill in the gap, Or is missing for Uint8x64 and Int8x64
inVariant: []
in: *twoI8x64
out: *oneI8x64
- go: Or
asm: "VPORD" # Fill in the gap, Or is missing for Uint16x32 and Int16x32
inVariant: []
in: *twoI16x32
out: *oneI16x32
- go: Xor
asm: "VPXOR[DQ]?"
in:
- *any
- *any
out:
- *any
- go: Xor
asm: "VPXORD" # Fill in the gap, Or is missing for Uint8x64 and Int8x64
inVariant: []
in: *twoI8x64
out: *oneI8x64
- go: Xor
asm: "VPXORD" # Fill in the gap, Or is missing for Uint16x32 and Int16x32
inVariant: []
in: *twoI16x32
out: *oneI16x32

View file

@ -0,0 +1,43 @@
!sum
# const imm predicates (hold for both float and int|uint):
# 0: Equal
# 1: Less
# 2: LessEqual
# 3: IsNan (unordered; float only)
# 4: NotEqual
# 5: GreaterEqual
# 6: Greater
- go: Equal
constImm: 0
commutative: true
documentation: !string |-
// NAME compares for equality.
- go: Less
constImm: 1
commutative: false
documentation: !string |-
// NAME compares for less than.
- go: LessEqual
constImm: 2
commutative: false
documentation: !string |-
// NAME compares for less than or equal.
- go: IsNan # For float only.
constImm: 3
commutative: true
documentation: !string |-
// NAME checks if elements are NaN. Use as x.IsNan(x).
- go: NotEqual
constImm: 4
commutative: true
documentation: !string |-
// NAME compares for inequality.
- go: GreaterEqual
constImm: 13
commutative: false
documentation: !string |-
// NAME compares for greater than or equal.
- go: Greater
constImm: 14
commutative: false
documentation: !string |-
// NAME compares for greater than.

View file

@ -0,0 +1,141 @@
!sum
# Ints
- go: Equal
asm: "V?PCMPEQ[BWDQ]"
in:
- &any
go: $t
- *any
out:
- &anyvregToMask
go: $t
overwriteBase: int
overwriteClass: mask
- go: Greater
asm: "V?PCMPGT[BWDQ]"
in:
- &int
go: $t
base: int
- *int
out:
- *anyvregToMask
# 256-bit VPCMPGTQ's output elemBits is marked as 32-bit in the XED data; we
# believe this is an error, so add this definition to overwrite it.
- go: Greater
asm: "VPCMPGTQ"
in:
- &int64
go: $t
base: int
elemBits: 64
- *int64
out:
- base: int
elemBits: 32
overwriteElementBits: 64
overwriteClass: mask
overwriteBase: int
# TODO these are redundant with VPCMP operations.
# AVX-512 compares produce masks.
- go: Equal
asm: "V?PCMPEQ[BWDQ]"
in:
- *any
- *any
out:
- class: mask
- go: Greater
asm: "V?PCMPGT[BWDQ]"
in:
- *int
- *int
out:
- class: mask
# MASKED signed comparisons for X/Y registers
# unmasked would clash with emulations on AVX2
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
asm: "VPCMP[BWDQ]"
in:
- &int
bits: (128|256)
go: $t
base: int
- *int
- class: immediate
const: 0 # Just a placeholder, will be overwritten by const imm porting.
inVariant:
- class: mask
out:
- class: mask
# MASKED unsigned comparisons for X/Y registers
# unmasked would clash with emulations on AVX2
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
asm: "VPCMPU[BWDQ]"
in:
- &uint
bits: (128|256)
go: $t
base: uint
- *uint
- class: immediate
const: 0
inVariant:
- class: mask
out:
- class: mask
# masked/unmasked signed comparisons for Z registers
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
asm: "VPCMP[BWDQ]"
in:
- &int
bits: 512
go: $t
base: int
- *int
- class: immediate
const: 0 # Just a placeholder, will be overwritten by const imm porting.
out:
- class: mask
# masked/unmasked unsigned comparisons for Z registers
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
asm: "VPCMPU[BWDQ]"
in:
- &uint
bits: 512
go: $t
base: uint
- *uint
- class: immediate
const: 0
out:
- class: mask
# Floats
- go: Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual|IsNan
asm: "VCMPP[SD]"
in:
- &float
go: $t
base: float
- *float
- class: immediate
const: 0
out:
- go: $t
overwriteBase: int
overwriteClass: mask
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual|IsNan)
asm: "VCMPP[SD]"
in:
- *float
- *float
- class: immediate
const: 0
out:
- class: mask

View file

@ -0,0 +1,10 @@
!sum
- go: ConvertToInt32
commutative: false
documentation: !string |-
// ConvertToInt32 converts element values to int32.
- go: ConvertToUint32
commutative: false
documentation: !string |-
// ConvertToUint32 converts element values to uint32.

View file

@ -0,0 +1,21 @@
!sum
- go: ConvertToInt32
asm: "VCVTTPS2DQ"
in:
- &fp
go: $t
base: float
out:
- &i32
go: $u
base: int
elemBits: 32
- go: ConvertToUint32
asm: "VCVTPS2UDQ"
in:
- *fp
out:
- &u32
go: $u
base: uint
elemBits: 32

View file

@ -0,0 +1,85 @@
!sum
- go: Div
commutative: false
documentation: !string |-
// NAME divides elements of two vectors.
- go: Sqrt
commutative: false
documentation: !string |-
// NAME computes the square root of each element.
- go: Reciprocal
commutative: false
documentation: !string |-
// NAME computes an approximate reciprocal of each element.
- go: ReciprocalSqrt
commutative: false
documentation: !string |-
// NAME computes an approximate reciprocal of the square root of each element.
- go: Scale
commutative: false
documentation: !string |-
// NAME multiplies elements by a power of 2.
- go: RoundToEven
commutative: false
constImm: 0
documentation: !string |-
// NAME rounds elements to the nearest integer.
- go: RoundToEvenScaled
commutative: false
constImm: 0
documentation: !string |-
// NAME rounds elements with specified precision.
- go: RoundToEvenScaledResidue
commutative: false
constImm: 0
documentation: !string |-
// NAME computes the difference after rounding with specified precision.
- go: Floor
commutative: false
constImm: 1
documentation: !string |-
// NAME rounds elements down to the nearest integer.
- go: FloorScaled
commutative: false
constImm: 1
documentation: !string |-
// NAME rounds elements down with specified precision.
- go: FloorScaledResidue
commutative: false
constImm: 1
documentation: !string |-
// NAME computes the difference after flooring with specified precision.
- go: Ceil
commutative: false
constImm: 2
documentation: !string |-
// NAME rounds elements up to the nearest integer.
- go: CeilScaled
commutative: false
constImm: 2
documentation: !string |-
// NAME rounds elements up with specified precision.
- go: CeilScaledResidue
commutative: false
constImm: 2
documentation: !string |-
// NAME computes the difference after ceiling with specified precision.
- go: Trunc
commutative: false
constImm: 3
documentation: !string |-
// NAME truncates elements towards zero.
- go: TruncScaled
commutative: false
constImm: 3
documentation: !string |-
// NAME truncates elements with specified precision.
- go: TruncScaledResidue
commutative: false
constImm: 3
documentation: !string |-
// NAME computes the difference after truncating with specified precision.
- go: AddSub
commutative: false
documentation: !string |-
// NAME subtracts even elements and adds odd elements of two vectors.

View file

@ -0,0 +1,62 @@
!sum
- go: Div
asm: "V?DIVP[SD]"
in: &2fp
- &fp
go: $t
base: float
- *fp
out: &1fp
- *fp
- go: Sqrt
asm: "V?SQRTP[SD]"
in: *1fp
out: *1fp
# TODO: Provide separate methods for 12-bit precision and 14-bit precision?
- go: Reciprocal
asm: "VRCP(14)?P[SD]"
in: *1fp
out: *1fp
- go: ReciprocalSqrt
asm: "V?RSQRT(14)?P[SD]"
in: *1fp
out: *1fp
- go: Scale
asm: "VSCALEFP[SD]"
in: *2fp
out: *1fp
- go: "RoundToEven|Ceil|Floor|Trunc"
asm: "VROUNDP[SD]"
in:
- *fp
- class: immediate
const: 0 # place holder
out: *1fp
- go: "(RoundToEven|Ceil|Floor|Trunc)Scaled"
asm: "VRNDSCALEP[SD]"
in:
- *fp
- class: immediate
const: 0 # place holder
immOffset: 4 # "M", round to numbers with M digits after dot(by means of binary number).
name: prec
out: *1fp
- go: "(RoundToEven|Ceil|Floor|Trunc)ScaledResidue"
asm: "VREDUCEP[SD]"
in:
- *fp
- class: immediate
const: 0 # place holder
immOffset: 4 # "M", round to numbers with M digits after dot(by means of binary number).
name: prec
out: *1fp
- go: "AddSub"
asm: "VADDSUBP[SD]"
in:
- *fp
- *fp
out:
- *fp

View file

@ -0,0 +1,21 @@
!sum
- go: GaloisFieldAffineTransform
commutative: false
documentation: !string |-
// NAME computes an affine transformation in GF(2^8):
// x is a vector of 8-bit vectors, with each adjacent 8 as a group; y is a vector of 8x8 1-bit matrices;
// b is an 8-bit vector. The affine transformation is y * x + b, with each element of y
// corresponding to a group of 8 elements in x.
- go: GaloisFieldAffineTransformInverse
commutative: false
documentation: !string |-
// NAME computes an affine transformation in GF(2^8),
// with x inverted with respect to reduction polynomial x^8 + x^4 + x^3 + x + 1:
// x is a vector of 8-bit vectors, with each adjacent 8 as a group; y is a vector of 8x8 1-bit matrices;
// b is an 8-bit vector. The affine transformation is y * x + b, with each element of y
// corresponding to a group of 8 elements in x.
- go: GaloisFieldMul
commutative: false
documentation: !string |-
// NAME computes element-wise GF(2^8) multiplication with
// reduction polynomial x^8 + x^4 + x^3 + x + 1.

View file

@ -0,0 +1,32 @@
!sum
- go: GaloisFieldAffineTransform
asm: VGF2P8AFFINEQB
operandOrder: 2I # 2nd operand, then immediate
in: &AffineArgs
- &uint8
go: $t
base: uint
- &uint8x8
go: $t2
base: uint
- &pureImmVar
class: immediate
immOffset: 0
name: b
out:
- *uint8
- go: GaloisFieldAffineTransformInverse
asm: VGF2P8AFFINEINVQB
operandOrder: 2I # 2nd operand, then immediate
in: *AffineArgs
out:
- *uint8
- go: GaloisFieldMul
asm: VGF2P8MULB
in:
- *uint8
- *uint8
out:
- *uint8

View file

@ -0,0 +1,21 @@
!sum
- go: Average
commutative: true
documentation: !string |-
// NAME computes the rounded average of corresponding elements.
- go: Abs
commutative: false
# Unary operation, not commutative
documentation: !string |-
// NAME computes the absolute value of each element.
- go: CopySign
# Applies sign of second operand to first: sign(val, sign_src)
commutative: false
documentation: !string |-
// NAME returns the product of the first operand with -1, 0, or 1,
// whichever constant is nearest to the value of the second operand.
# Sign does not have masked version
- go: OnesCount
commutative: false
documentation: !string |-
// NAME counts the number of set bits in each element.

View file

@ -0,0 +1,45 @@
!sum
# Average (unsigned byte, unsigned word)
# Instructions: VPAVGB, VPAVGW
- go: Average
asm: "VPAVG[BW]" # Matches VPAVGB (byte) and VPAVGW (word)
in:
- &uint_t # $t will be Uint8xN for VPAVGB, Uint16xN for VPAVGW
go: $t
base: uint
- *uint_t
out:
- *uint_t
# Absolute Value (signed byte, word, dword, qword)
# Instructions: VPABSB, VPABSW, VPABSD, VPABSQ
- go: Abs
asm: "VPABS[BWDQ]" # Matches VPABSB, VPABSW, VPABSD, VPABSQ
in:
- &int_t # $t will be Int8xN, Int16xN, Int32xN, Int64xN
go: $t
base: int
out:
- *int_t # Output is magnitude, fits in the same signed type
# Sign Operation (signed byte, word, dword)
# Applies sign of second operand to the first.
# Instructions: VPSIGNB, VPSIGNW, VPSIGND
- go: CopySign
asm: "VPSIGN[BWD]" # Matches VPSIGNB, VPSIGNW, VPSIGND
in:
- *int_t # value to apply sign to
- *int_t # value from which to take the sign
out:
- *int_t
# Population Count (count set bits in each element)
# Instructions: VPOPCNTB, VPOPCNTW (AVX512_BITALG)
# VPOPCNTD, VPOPCNTQ (AVX512_VPOPCNTDQ)
- go: OnesCount
asm: "VPOPCNT[BWDQ]"
in:
- &any
go: $t
out:
- *any

View file

@ -0,0 +1,47 @@
!sum
- go: DotProdPairs
commutative: false
documentation: !string |-
// NAME multiplies the elements and adds the pairs together,
// yielding a vector of half as many elements with twice the input element size.
# TODO: maybe simplify this name within the receiver-type + method-naming scheme we use.
- go: DotProdPairsSaturated
commutative: false
documentation: !string |-
// NAME multiplies the elements and adds the pairs together with saturation,
// yielding a vector of half as many elements with twice the input element size.
# QuadDotProd, i.e. VPDPBUSD(S), are operations with src/dst on the same register; we are not supporting this as of now.
# - go: DotProdBroadcast
# commutative: true
# # documentation: !string |-
# // NAME multiplies all elements and broadcasts the sum.
- go: AddDotProdQuadruple
commutative: false
documentation: !string |-
// NAME performs dot products on groups of 4 elements of x and y and then adds z.
- go: AddDotProdQuadrupleSaturated
commutative: false
documentation: !string |-
// NAME performs dot products on groups of 4 elements of x and y and then adds z with saturation.
- go: AddDotProdPairs
commutative: false
noTypes: "true"
noGenericOps: "true"
documentation: !string |-
// NAME performs dot products on pairs of elements of y and z and then adds x.
- go: AddDotProdPairsSaturated
commutative: false
documentation: !string |-
// NAME performs dot products on pairs of elements of y and z and then adds x with saturation.
- go: MulAdd
commutative: false
documentation: !string |-
// NAME performs a fused (x * y) + z.
- go: MulAddSub
commutative: false
documentation: !string |-
// NAME performs a fused (x * y) - z for odd-indexed elements, and (x * y) + z for even-indexed elements.
- go: MulSubAdd
commutative: false
documentation: !string |-
// NAME performs a fused (x * y) + z for odd-indexed elements, and (x * y) - z for even-indexed elements.

View file

@ -0,0 +1,113 @@
!sum
- go: DotProdPairs
asm: VPMADDWD
in:
- &int
go: $t
base: int
- *int
out:
- &int2 # The elemBits are different
go: $t2
base: int
- go: DotProdPairsSaturated
asm: VPMADDUBSW
in:
- &uint
go: $t
base: uint
overwriteElementBits: 8
- &int3
go: $t3
base: int
overwriteElementBits: 8
out:
- *int2
# - go: DotProdBroadcast
# asm: VDPP[SD]
# in:
# - &dpb_src
# go: $t
# - *dpb_src
# - class: immediate
# const: 127
# out:
# - *dpb_src
- go: AddDotProdQuadruple
asm: "VPDPBUSD"
operandOrder: "31" # switch operand 3 and 1
in:
- &qdpa_acc
go: $t_acc
base: int
elemBits: 32
- &qdpa_src1
go: $t_src1
base: uint
overwriteElementBits: 8
- &qdpa_src2
go: $t_src2
base: int
overwriteElementBits: 8
out:
- *qdpa_acc
- go: AddDotProdQuadrupleSaturated
asm: "VPDPBUSDS"
operandOrder: "31" # switch operand 3 and 1
in:
- *qdpa_acc
- *qdpa_src1
- *qdpa_src2
out:
- *qdpa_acc
- go: AddDotProdPairs
asm: "VPDPWSSD"
in:
- &pdpa_acc
go: $t_acc
base: int
elemBits: 32
- &pdpa_src1
go: $t_src1
base: int
overwriteElementBits: 16
- &pdpa_src2
go: $t_src2
base: int
overwriteElementBits: 16
out:
- *pdpa_acc
- go: AddDotProdPairsSaturated
asm: "VPDPWSSDS"
in:
- *pdpa_acc
- *pdpa_src1
- *pdpa_src2
out:
- *pdpa_acc
- go: MulAdd
asm: "VFMADD213PS|VFMADD213PD"
in:
- &fma_op
go: $t
base: float
- *fma_op
- *fma_op
out:
- *fma_op
- go: MulAddSub
asm: "VFMADDSUB213PS|VFMADDSUB213PD"
in:
- *fma_op
- *fma_op
- *fma_op
out:
- *fma_op
- go: MulSubAdd
asm: "VFMSUBADD213PS|VFMSUBADD213PD"
in:
- *fma_op
- *fma_op
- *fma_op
out:
- *fma_op

View file

@ -0,0 +1,9 @@
!sum
- go: Max
commutative: true
documentation: !string |-
// NAME computes the maximum of corresponding elements.
- go: Min
commutative: true
documentation: !string |-
// NAME computes the minimum of corresponding elements.

View file

@ -0,0 +1,42 @@
!sum
- go: Max
asm: "V?PMAXS[BWDQ]"
in: &2int
- &int
go: $t
base: int
- *int
out: &1int
- *int
- go: Max
asm: "V?PMAXU[BWDQ]"
in: &2uint
- &uint
go: $t
base: uint
- *uint
out: &1uint
- *uint
- go: Min
asm: "V?PMINS[BWDQ]"
in: *2int
out: *1int
- go: Min
asm: "V?PMINU[BWDQ]"
in: *2uint
out: *1uint
- go: Max
asm: "V?MAXP[SD]"
in: &2float
- &float
go: $t
base: float
- *float
out: &1float
- *float
- go: Min
asm: "V?MINP[SD]"
in: *2float
out: *1float

View file

@ -0,0 +1,72 @@
!sum
- go: SetElem
commutative: false
documentation: !string |-
// NAME sets a single constant-indexed element's value.
- go: GetElem
commutative: false
documentation: !string |-
// NAME retrieves a single constant-indexed element's value.
- go: SetLo
commutative: false
constImm: 0
documentation: !string |-
// NAME returns x with its lower half set to y.
- go: GetLo
commutative: false
constImm: 0
documentation: !string |-
// NAME returns the lower half of x.
- go: SetHi
commutative: false
constImm: 1
documentation: !string |-
// NAME returns x with its upper half set to y.
- go: GetHi
commutative: false
constImm: 1
documentation: !string |-
// NAME returns the upper half of x.
- go: Permute
commutative: false
documentation: !string |-
// NAME performs a full permutation of vector x using indices:
// result := {x[indices[0]], x[indices[1]], ..., x[indices[n]]}
// Only the needed bits to represent x's index are used in indices' elements.
- go: Permute2 # Permute2 is only available on or after AVX512
commutative: false
documentation: !string |-
// NAME performs a full permutation of vector x, y using indices:
// result := {xy[indices[0]], xy[indices[1]], ..., xy[indices[n]]}
// where xy is x appending y.
// Only the needed bits to represent xy's index are used in indices' elements.
- go: Compress
commutative: false
documentation: !string |-
// NAME performs a compression on vector x using mask by
// selecting elements as indicated by mask, and packs them into the lower-indexed elements.
- go: blend
commutative: false
documentation: !string |-
// NAME blends two vectors based on mask values, choosing either
// the first or the second based on whether the third is false or true.
- go: Expand
commutative: false
documentation: !string |-
// NAME performs an expansion on a vector x whose elements are packed into its lower part.
// The expansion distributes these elements to the positions indicated by mask, from lower mask elements to upper, in order.
- go: Broadcast128
commutative: false
documentation: !string |-
// NAME copies element zero of its (128-bit) input to all elements of
// the 128-bit output vector.
- go: Broadcast256
commutative: false
documentation: !string |-
// NAME copies element zero of its (128-bit) input to all elements of
// the 256-bit output vector.
- go: Broadcast512
commutative: false
documentation: !string |-
// NAME copies element zero of its (128-bit) input to all elements of
// the 512-bit output vector.

View file

@ -0,0 +1,372 @@
!sum
- go: SetElem
asm: "VPINSR[BWDQ]"
in:
- &t
class: vreg
base: $b
- class: greg
base: $b
lanes: 1 # Scalar, darn it!
- &imm
class: immediate
immOffset: 0
name: index
out:
- *t
- go: SetElem
asm: "VPINSR[DQ]"
in:
- &t
class: vreg
base: int
OverwriteBase: float
- class: greg
base: int
OverwriteBase: float
lanes: 1 # Scalar, darn it!
- &imm
class: immediate
immOffset: 0
name: index
out:
- *t
- go: GetElem
asm: "VPEXTR[BWDQ]"
in:
- class: vreg
base: $b
elemBits: $e
- *imm
out:
- class: greg
base: $b
bits: $e
- go: "SetHi|SetLo"
asm: "VINSERTI128|VINSERTI64X4"
inVariant: []
in:
- &i8x2N
class: vreg
base: $t
OverwriteElementBits: 8
- &i8xN
class: vreg
base: $t
OverwriteElementBits: 8
- &imm01 # This immediate should be only 0 or 1
class: immediate
const: 0 # place holder
name: index
out:
- *i8x2N
- go: "GetHi|GetLo"
asm: "VEXTRACTI128|VEXTRACTI64X4"
inVariant: []
in:
- *i8x2N
- *imm01
out:
- *i8xN
- go: "SetHi|SetLo"
asm: "VINSERTI128|VINSERTI64X4"
inVariant: []
in:
- &i16x2N
class: vreg
base: $t
OverwriteElementBits: 16
- &i16xN
class: vreg
base: $t
OverwriteElementBits: 16
- *imm01
out:
- *i16x2N
- go: "GetHi|GetLo"
asm: "VEXTRACTI128|VEXTRACTI64X4"
inVariant: []
in:
- *i16x2N
- *imm01
out:
- *i16xN
- go: "SetHi|SetLo"
asm: "VINSERTI128|VINSERTI64X4"
inVariant: []
in:
- &i32x2N
class: vreg
base: $t
OverwriteElementBits: 32
- &i32xN
class: vreg
base: $t
OverwriteElementBits: 32
- *imm01
out:
- *i32x2N
- go: "GetHi|GetLo"
asm: "VEXTRACTI128|VEXTRACTI64X4"
inVariant: []
in:
- *i32x2N
- *imm01
out:
- *i32xN
- go: "SetHi|SetLo"
asm: "VINSERTI128|VINSERTI64X4"
inVariant: []
in:
- &i64x2N
class: vreg
base: $t
OverwriteElementBits: 64
- &i64xN
class: vreg
base: $t
OverwriteElementBits: 64
- *imm01
out:
- *i64x2N
- go: "GetHi|GetLo"
asm: "VEXTRACTI128|VEXTRACTI64X4"
inVariant: []
in:
- *i64x2N
- *imm01
out:
- *i64xN
- go: "SetHi|SetLo"
asm: "VINSERTF128|VINSERTF64X4"
inVariant: []
in:
- &f32x2N
class: vreg
base: $t
OverwriteElementBits: 32
- &f32xN
class: vreg
base: $t
OverwriteElementBits: 32
- *imm01
out:
- *f32x2N
- go: "GetHi|GetLo"
asm: "VEXTRACTF128|VEXTRACTF64X4"
inVariant: []
in:
- *f32x2N
- *imm01
out:
- *f32xN
- go: "SetHi|SetLo"
asm: "VINSERTF128|VINSERTF64X4"
inVariant: []
in:
- &f64x2N
class: vreg
base: $t
OverwriteElementBits: 64
- &f64xN
class: vreg
base: $t
OverwriteElementBits: 64
- *imm01
out:
- *f64x2N
- go: "GetHi|GetLo"
asm: "VEXTRACTF128|VEXTRACTF64X4"
inVariant: []
in:
- *f64x2N
- *imm01
out:
- *f64xN
- go: Permute
asm: "VPERM[BWDQ]|VPERMP[SD]"
operandOrder: "21Type1"
in:
- &anyindices
go: $t
name: indices
overwriteBase: uint
- &any
go: $t
out:
- *any
- go: Permute2
asm: "VPERMI2[BWDQ]|VPERMI2P[SD]"
# Because we are overwriting the receiver's type, we
# have to move the receiver to be a parameter so that
# we avoid duplication.
operandOrder: "231Type1"
in:
- *anyindices # result in arg 0
- *any
- *any
out:
- *any
- go: Compress
asm: "VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]"
in:
# The mask in Compress is a control mask rather than a write mask, so it's not optional.
- class: mask
- *any
out:
- *any
# For now a non-public method because
# (1) [OverwriteClass] must be set together with [OverwriteBase]
# (2) "simdgen does not support [OverwriteClass] in inputs".
# That means the signature is wrong.
- go: blend
asm: VPBLENDVB
in:
- &v
go: $t
class: vreg
base: int
- *v
-
class: vreg
base: int
name: mask
out:
- *v
# For AVX512
- go: blend
asm: VPBLENDM[BWDQ]
in:
- &v
go: $t
bits: 512
class: vreg
base: int
- *v
inVariant:
-
class: mask
out:
- *v
- go: Expand
asm: "VPEXPAND[BWDQ]|VEXPANDP[SD]"
in:
# The mask in Expand is a control mask rather than a write mask, so it's not optional.
- class: mask
- *any
out:
- *any
- go: Broadcast128
asm: VPBROADCAST[BWDQ]
in:
- class: vreg
bits: 128
elemBits: $e
base: $b
out:
- class: vreg
bits: 128
elemBits: $e
base: $b
# weirdly, this one case on AVX2 is memory-operand-only
- go: Broadcast128
asm: VPBROADCASTQ
in:
- class: vreg
bits: 128
elemBits: 64
base: int
OverwriteBase: float
out:
- class: vreg
bits: 128
elemBits: 64
base: int
OverwriteBase: float
- go: Broadcast256
asm: VPBROADCAST[BWDQ]
in:
- class: vreg
bits: 128
elemBits: $e
base: $b
out:
- class: vreg
bits: 256
elemBits: $e
base: $b
- go: Broadcast512
asm: VPBROADCAST[BWDQ]
in:
- class: vreg
bits: 128
elemBits: $e
base: $b
out:
- class: vreg
bits: 512
elemBits: $e
base: $b
- go: Broadcast128
asm: VBROADCASTS[SD]
in:
- class: vreg
bits: 128
elemBits: $e
base: $b
out:
- class: vreg
bits: 128
elemBits: $e
base: $b
- go: Broadcast256
asm: VBROADCASTS[SD]
in:
- class: vreg
bits: 128
elemBits: $e
base: $b
out:
- class: vreg
bits: 256
elemBits: $e
base: $b
- go: Broadcast512
asm: VBROADCASTS[SD]
in:
- class: vreg
bits: 128
elemBits: $e
base: $b
out:
- class: vreg
bits: 512
elemBits: $e
base: $b

View file

@ -0,0 +1,14 @@
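# Go-level multiply operation definitions: names, attributes, and doc comments.
# The machine (asm) mappings are defined separately.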
!sum
- go: Mul
commutative: true
documentation: !string |-
// NAME multiplies corresponding elements of two vectors.
- go: MulEvenWiden
commutative: true
documentation: !string |-
// NAME multiplies even-indexed elements, widening the result.
// Result[i] = v1.Even[i] * v2.Even[i].
- go: MulHigh
commutative: true
documentation: !string |-
// NAME multiplies elements and stores the high part of the result.

View file

@ -0,0 +1,73 @@
!sum
# "Normal" multiplication is only available for floats.
# This covers only single and double precision.
- go: Mul
asm: "VMULP[SD]"
in:
- &fp
go: $t
base: float
- *fp
out:
- *fp
# Integer multiplications.
# MulEvenWiden
# Dword only.
- go: MulEvenWiden
asm: "VPMULDQ"
in:
- &intNot64
go: $t
elemBits: 8|16|32
base: int
- *intNot64
out:
- &int2
go: $t2
base: int
- go: MulEvenWiden
asm: "VPMULUDQ"
in:
- &uintNot64
go: $t
elemBits: 8|16|32
base: uint
- *uintNot64
out:
- &uint2
go: $t2
base: uint
# MulHigh
# Word only.
- go: MulHigh
asm: "VPMULHW"
in:
- &int
go: $t
base: int
- *int
out:
- *int
- go: MulHigh
asm: "VPMULHUW"
in:
- &uint
go: $t
base: uint
- *uint
out:
- *uint
# MulLow
# signed and unsigned are the same for lower bits.
- go: Mul
asm: "VPMULL[WDQ]"
in:
- &any
go: $t
- *any
out:
- *any

View file

@ -0,0 +1,103 @@
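# Go-level shift and rotate operation definitions: names, attributes, and doc comments.
# The machine (asm) mappings are defined separately.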
!sum
- go: ShiftAllLeft
nameAndSizeCheck: true
specialLower: sftimm
commutative: false
documentation: !string |-
// NAME shifts each element to the left by the specified number of bits. Emptied lower bits are zeroed.
- go: ShiftAllRight
signed: false
nameAndSizeCheck: true
specialLower: sftimm
commutative: false
documentation: !string |-
// NAME shifts each element to the right by the specified number of bits. Emptied upper bits are zeroed.
- go: ShiftAllRight
signed: true
specialLower: sftimm
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element to the right by the specified number of bits. Emptied upper bits are filled with the sign bit.
- go: shiftAllLeftConst # no APIs, only ssa ops.
noTypes: "true"
noGenericOps: "true"
SSAVariant: "const" # to avoid its name colliding with reg version of this instruction, amend this to its ssa op name.
nameAndSizeCheck: true
commutative: false
- go: shiftAllRightConst # no APIs, only ssa ops.
noTypes: "true"
noGenericOps: "true"
SSAVariant: "const"
signed: false
nameAndSizeCheck: true
commutative: false
- go: shiftAllRightConst # no APIs, only ssa ops.
noTypes: "true"
noGenericOps: "true"
SSAVariant: "const"
signed: true
nameAndSizeCheck: true
commutative: false
- go: ShiftLeft
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element in x to the left by the number of bits specified in y's corresponding elements. Emptied lower bits are zeroed.
- go: ShiftRight
signed: false
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element in x to the right by the number of bits specified in y's corresponding elements. Emptied upper bits are zeroed.
- go: ShiftRight
signed: true
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element in x to the right by the number of bits specified in y's corresponding elements. Emptied upper bits are filled with the sign bit.
- go: RotateAllLeft
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME rotates each element to the left by the number of bits specified by the immediate.
- go: RotateLeft
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME rotates each element in x to the left by the number of bits specified by y's corresponding elements.
- go: RotateAllRight
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME rotates each element to the right by the number of bits specified by the immediate.
- go: RotateRight
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME rotates each element in x to the right by the number of bits specified by y's corresponding elements.
- go: ShiftAllLeftConcat
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element of x to the left by the number of bits specified by the
// immediate (only the lower 5 bits are used), and then copies the upper bits of y to the emptied lower bits of the shifted x.
- go: ShiftAllRightConcat
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element of x to the right by the number of bits specified by the
// immediate (only the lower 5 bits are used), and then copies the lower bits of y to the emptied upper bits of the shifted x.
- go: ShiftLeftConcat
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element of x to the left by the number of bits specified by the
// corresponding elements in y (only the lower 5 bits are used), and then copies the upper bits of z to the emptied lower bits of the shifted x.
- go: ShiftRightConcat
nameAndSizeCheck: true
commutative: false
documentation: !string |-
// NAME shifts each element of x to the right by the number of bits specified by the
// corresponding elements in y (only the lower 5 bits are used), and then copies the lower bits of z to the emptied upper bits of the shifted x.

View file

@ -0,0 +1,172 @@
!sum
# Integers
# ShiftAll*
- go: ShiftAllLeft
asm: "VPSLL[WDQ]"
in:
- &any
go: $t
- &vecAsScalar64
go: "Uint.*"
treatLikeAScalarOfSize: 64
out:
- *any
- go: ShiftAllRight
signed: false
asm: "VPSRL[WDQ]"
in:
- &uint
go: $t
base: uint
- *vecAsScalar64
out:
- *uint
- go: ShiftAllRight
signed: true
asm: "VPSRA[WDQ]"
in:
- &int
go: $t
base: int
- *vecAsScalar64
out:
- *int
- go: shiftAllLeftConst
asm: "VPSLL[WDQ]"
in:
- *any
- &imm
class: immediate
immOffset: 0
out:
- *any
- go: shiftAllRightConst
asm: "VPSRL[WDQ]"
in:
- *int
- *imm
out:
- *int
- go: shiftAllRightConst
asm: "VPSRA[WDQ]"
in:
- *uint
- *imm
out:
- *uint
# Shift* (variable)
- go: ShiftLeft
asm: "VPSLLV[WD]"
in:
- *any
- *any
out:
- *any
# The XED data for VPSLLVQ marks the element bits as 32, which does not match the actual
# semantics; we need to overwrite it to 64.
- go: ShiftLeft
asm: "VPSLLVQ"
in:
- &anyOverwriteElemBits
go: $t
overwriteElementBits: 64
- *anyOverwriteElemBits
out:
- *anyOverwriteElemBits
- go: ShiftRight
signed: false
asm: "VPSRLV[WD]"
in:
- *uint
- *uint
out:
- *uint
# XED data of VPSRLVQ needs the same overwrite as VPSLLVQ.
- go: ShiftRight
signed: false
asm: "VPSRLVQ"
in:
- &uintOverwriteElemBits
go: $t
base: uint
overwriteElementBits: 64
- *uintOverwriteElemBits
out:
- *uintOverwriteElemBits
- go: ShiftRight
signed: true
asm: "VPSRAV[WDQ]"
in:
- *int
- *int
out:
- *int
# Rotate
- go: RotateAllLeft
asm: "VPROL[DQ]"
in:
- *any
- &pureImm
class: immediate
immOffset: 0
name: shift
out:
- *any
- go: RotateAllRight
asm: "VPROR[DQ]"
in:
- *any
- *pureImm
out:
- *any
- go: RotateLeft
asm: "VPROLV[DQ]"
in:
- *any
- *any
out:
- *any
- go: RotateRight
asm: "VPRORV[DQ]"
in:
- *any
- *any
out:
- *any
# Bizarre shifts.
- go: ShiftAllLeftConcat
asm: "VPSHLD[WDQ]"
in:
- *any
- *any
- *pureImm
out:
- *any
- go: ShiftAllRightConcat
asm: "VPSHRD[WDQ]"
in:
- *any
- *any
- *pureImm
out:
- *any
- go: ShiftLeftConcat
asm: "VPSHLDV[WDQ]"
in:
- *any
- *any
- *any
out:
- *any
- go: ShiftRightConcat
asm: "VPSHRDV[WDQ]"
in:
- *any
- *any
- *any
out:
- *any

View file

@ -0,0 +1,73 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"reflect"
"strconv"
)
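// pprints formats v as an indented, human-readable dump using reflection.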
func pprints(v any) string {
var pp pprinter
pp.val(reflect.ValueOf(v), 0)
return string(pp.buf)
}
type pprinter struct {
buf []byte
}
func (p *pprinter) indent(by int) {
for range by {
p.buf = append(p.buf, '\t')
}
}
func (p *pprinter) val(v reflect.Value, indent int) {
switch v.Kind() {
default:
p.buf = fmt.Appendf(p.buf, "unsupported kind %v", v.Kind())
case reflect.Bool:
p.buf = strconv.AppendBool(p.buf, v.Bool())
case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
p.buf = strconv.AppendInt(p.buf, v.Int(), 10)
case reflect.String:
p.buf = strconv.AppendQuote(p.buf, v.String())
case reflect.Pointer:
if v.IsNil() {
p.buf = append(p.buf, "nil"...)
} else {
p.buf = append(p.buf, "&"...)
p.val(v.Elem(), indent)
}
case reflect.Slice, reflect.Array:
p.buf = append(p.buf, "[\n"...)
for i := range v.Len() {
p.indent(indent + 1)
p.val(v.Index(i), indent+1)
p.buf = append(p.buf, ",\n"...)
}
p.indent(indent)
p.buf = append(p.buf, ']')
case reflect.Struct:
vt := v.Type()
p.buf = append(append(p.buf, vt.String()...), "{\n"...)
for f := range v.NumField() {
p.indent(indent + 1)
p.buf = append(append(p.buf, vt.Field(f).Name...), ": "...)
p.val(v.Field(f), indent+1)
p.buf = append(p.buf, ",\n"...)
}
p.indent(indent)
p.buf = append(p.buf, '}')
}
}

View file

@ -0,0 +1,41 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "testing"
func TestSort(t *testing.T) {
testCases := []struct {
s1, s2 string
want int
}{
{"a1", "a2", -1},
{"a11a", "a11b", -1},
{"a01a1", "a1a01", -1},
{"a2", "a1", 1},
{"a10", "a2", 1},
{"a1", "a10", -1},
{"z11", "z2", 1},
{"z2", "z11", -1},
{"abc", "abd", -1},
{"123", "45", 1},
{"file1", "file1", 0},
{"file", "file1", -1},
{"file1", "file", 1},
{"a01", "a1", -1},
{"a1a", "a1b", -1},
}
for _, tc := range testCases {
got := compareNatural(tc.s1, tc.s2)
result := "✅"
if got != tc.want {
result = "❌"
t.Errorf("%s CompareNatural(\"%s\", \"%s\") -> got %2d, want %2d\n", result, tc.s1, tc.s2, got, tc.want)
} else {
t.Logf("%s CompareNatural(\"%s\", \"%s\") -> got %2d, want %2d\n", result, tc.s1, tc.s2, got, tc.want)
}
}
}

View file

@ -0,0 +1,90 @@
# This file defines the possible types of each operand and result.
#
# In general, we're able to narrow this down on some attributes directly from
# the machine instruction descriptions, but the Go mappings need to further
# constrain them and how they relate. For example, on x86 we can't distinguish
# int and uint, though we can distinguish these from float.
in: !repeat
- !sum &types
- {class: vreg, go: Int8x16, base: "int", elemBits: 8, bits: 128, lanes: 16}
- {class: vreg, go: Uint8x16, base: "uint", elemBits: 8, bits: 128, lanes: 16}
- {class: vreg, go: Int16x8, base: "int", elemBits: 16, bits: 128, lanes: 8}
- {class: vreg, go: Uint16x8, base: "uint", elemBits: 16, bits: 128, lanes: 8}
- {class: vreg, go: Int32x4, base: "int", elemBits: 32, bits: 128, lanes: 4}
- {class: vreg, go: Uint32x4, base: "uint", elemBits: 32, bits: 128, lanes: 4}
- {class: vreg, go: Int64x2, base: "int", elemBits: 64, bits: 128, lanes: 2}
- {class: vreg, go: Uint64x2, base: "uint", elemBits: 64, bits: 128, lanes: 2}
- {class: vreg, go: Float32x4, base: "float", elemBits: 32, bits: 128, lanes: 4}
- {class: vreg, go: Float64x2, base: "float", elemBits: 64, bits: 128, lanes: 2}
- {class: vreg, go: Int8x32, base: "int", elemBits: 8, bits: 256, lanes: 32}
- {class: vreg, go: Uint8x32, base: "uint", elemBits: 8, bits: 256, lanes: 32}
- {class: vreg, go: Int16x16, base: "int", elemBits: 16, bits: 256, lanes: 16}
- {class: vreg, go: Uint16x16, base: "uint", elemBits: 16, bits: 256, lanes: 16}
- {class: vreg, go: Int32x8, base: "int", elemBits: 32, bits: 256, lanes: 8}
- {class: vreg, go: Uint32x8, base: "uint", elemBits: 32, bits: 256, lanes: 8}
- {class: vreg, go: Int64x4, base: "int", elemBits: 64, bits: 256, lanes: 4}
- {class: vreg, go: Uint64x4, base: "uint", elemBits: 64, bits: 256, lanes: 4}
- {class: vreg, go: Float32x8, base: "float", elemBits: 32, bits: 256, lanes: 8}
- {class: vreg, go: Float64x4, base: "float", elemBits: 64, bits: 256, lanes: 4}
- {class: vreg, go: Int8x64, base: "int", elemBits: 8, bits: 512, lanes: 64}
- {class: vreg, go: Uint8x64, base: "uint", elemBits: 8, bits: 512, lanes: 64}
- {class: vreg, go: Int16x32, base: "int", elemBits: 16, bits: 512, lanes: 32}
- {class: vreg, go: Uint16x32, base: "uint", elemBits: 16, bits: 512, lanes: 32}
- {class: vreg, go: Int32x16, base: "int", elemBits: 32, bits: 512, lanes: 16}
- {class: vreg, go: Uint32x16, base: "uint", elemBits: 32, bits: 512, lanes: 16}
- {class: vreg, go: Int64x8, base: "int", elemBits: 64, bits: 512, lanes: 8}
- {class: vreg, go: Uint64x8, base: "uint", elemBits: 64, bits: 512, lanes: 8}
- {class: vreg, go: Float32x16, base: "float", elemBits: 32, bits: 512, lanes: 16}
- {class: vreg, go: Float64x8, base: "float", elemBits: 64, bits: 512, lanes: 8}
- {class: mask, go: Mask8x16, base: "int", elemBits: 8, bits: 128, lanes: 16}
- {class: mask, go: Mask16x8, base: "int", elemBits: 16, bits: 128, lanes: 8}
- {class: mask, go: Mask32x4, base: "int", elemBits: 32, bits: 128, lanes: 4}
- {class: mask, go: Mask64x2, base: "int", elemBits: 64, bits: 128, lanes: 2}
- {class: mask, go: Mask8x32, base: "int", elemBits: 8, bits: 256, lanes: 32}
- {class: mask, go: Mask16x16, base: "int", elemBits: 16, bits: 256, lanes: 16}
- {class: mask, go: Mask32x8, base: "int", elemBits: 32, bits: 256, lanes: 8}
- {class: mask, go: Mask64x4, base: "int", elemBits: 64, bits: 256, lanes: 4}
- {class: mask, go: Mask8x64, base: "int", elemBits: 8, bits: 512, lanes: 64}
- {class: mask, go: Mask16x32, base: "int", elemBits: 16, bits: 512, lanes: 32}
- {class: mask, go: Mask32x16, base: "int", elemBits: 32, bits: 512, lanes: 16}
- {class: mask, go: Mask64x8, base: "int", elemBits: 64, bits: 512, lanes: 8}
- {class: greg, go: float64, base: "float", bits: 64, lanes: 1}
- {class: greg, go: float32, base: "float", bits: 32, lanes: 1}
- {class: greg, go: int64, base: "int", bits: 64, lanes: 1}
- {class: greg, go: int32, base: "int", bits: 32, lanes: 1}
- {class: greg, go: int16, base: "int", bits: 16, lanes: 1}
- {class: greg, go: int8, base: "int", bits: 8, lanes: 1}
- {class: greg, go: uint64, base: "uint", bits: 64, lanes: 1}
- {class: greg, go: uint32, base: "uint", bits: 32, lanes: 1}
- {class: greg, go: uint16, base: "uint", bits: 16, lanes: 1}
- {class: greg, go: uint8, base: "uint", bits: 8, lanes: 1}
# Special shapes just to make INSERT[IF]128 work.
# The elemBits field of these shapes is wrong; it is overwritten by overwriteElementBits.
- {class: vreg, go: Int8x16, base: "int", elemBits: 128, bits: 128, lanes: 16}
- {class: vreg, go: Uint8x16, base: "uint", elemBits: 128, bits: 128, lanes: 16}
- {class: vreg, go: Int16x8, base: "int", elemBits: 128, bits: 128, lanes: 8}
- {class: vreg, go: Uint16x8, base: "uint", elemBits: 128, bits: 128, lanes: 8}
- {class: vreg, go: Int32x4, base: "int", elemBits: 128, bits: 128, lanes: 4}
- {class: vreg, go: Uint32x4, base: "uint", elemBits: 128, bits: 128, lanes: 4}
- {class: vreg, go: Int64x2, base: "int", elemBits: 128, bits: 128, lanes: 2}
- {class: vreg, go: Uint64x2, base: "uint", elemBits: 128, bits: 128, lanes: 2}
- {class: vreg, go: Int8x32, base: "int", elemBits: 128, bits: 256, lanes: 32}
- {class: vreg, go: Uint8x32, base: "uint", elemBits: 128, bits: 256, lanes: 32}
- {class: vreg, go: Int16x16, base: "int", elemBits: 128, bits: 256, lanes: 16}
- {class: vreg, go: Uint16x16, base: "uint", elemBits: 128, bits: 256, lanes: 16}
- {class: vreg, go: Int32x8, base: "int", elemBits: 128, bits: 256, lanes: 8}
- {class: vreg, go: Uint32x8, base: "uint", elemBits: 128, bits: 256, lanes: 8}
- {class: vreg, go: Int64x4, base: "int", elemBits: 128, bits: 256, lanes: 4}
- {class: vreg, go: Uint64x4, base: "uint", elemBits: 128, bits: 256, lanes: 4}
- {class: immediate, go: Immediate} # TODO: we only support immediates that are not used as values -- usually as an instruction semantic predicate, as in VPCMP, for now.
inVariant: !repeat
- *types
out: !repeat
- *types

View file

@ -0,0 +1,780 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cmp"
"fmt"
"log"
"maps"
"regexp"
"slices"
"strconv"
"strings"
"golang.org/x/arch/x86/xeddata"
"gopkg.in/yaml.v3"
"simd/_gen/unify"
)
const (
NOT_REG_CLASS = 0 // not a register
VREG_CLASS = 1 // classify as a vector register
GREG_CLASS = 2 // classify as a general register
)
// instVariant is a bitmap indicating a variant of an instruction that has
// optional parameters.
type instVariant uint8
const (
instVariantNone instVariant = 0
// instVariantMasked indicates that this is the masked variant of an
// optionally-masked instruction.
instVariantMasked instVariant = 1 << iota
)
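// operandRemarks counts instructions that were skipped because one of their
// operands could not be decoded.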
var operandRemarks int
// TODO: Doc. Returns Values with Def domains.
func loadXED(xedPath string) []*unify.Value {
// TODO: Obviously a bunch more to do here.
db, err := xeddata.NewDatabase(xedPath)
if err != nil {
log.Fatalf("open database: %v", err)
}
var defs []*unify.Value
err = xeddata.WalkInsts(xedPath, func(inst *xeddata.Inst) {
inst.Pattern = xeddata.ExpandStates(db, inst.Pattern)
switch {
case inst.RealOpcode == "N":
return // Skip unstable instructions
case !strings.HasPrefix(inst.Extension, "AVX"):
// We're only interested in AVX instructions.
return
}
if *flagDebugXED {
fmt.Printf("%s:\n%+v\n", inst.Pos, inst)
}
ops, err := decodeOperands(db, strings.Fields(inst.Operands))
if err != nil {
operandRemarks++
if *Verbose {
log.Printf("%s: [%s] %s", inst.Pos, inst.Opcode(), err)
}
return
}
applyQuirks(inst, ops)
defsPos := len(defs)
defs = append(defs, instToUVal(inst, ops)...)
if *flagDebugXED {
for i := defsPos; i < len(defs); i++ {
y, _ := yaml.Marshal(defs[i])
fmt.Printf("==>\n%s\n", y)
}
}
})
if err != nil {
log.Fatalf("walk insts: %v", err)
}
if len(unknownFeatures) > 0 {
if !*Verbose {
nInst := 0
for _, insts := range unknownFeatures {
nInst += len(insts)
}
log.Printf("%d unhandled CPU features for %d instructions (use -v for details)", len(unknownFeatures), nInst)
} else {
keys := slices.SortedFunc(maps.Keys(unknownFeatures), func(a, b cpuFeatureKey) int {
return cmp.Or(cmp.Compare(a.Extension, b.Extension),
cmp.Compare(a.ISASet, b.ISASet))
})
for _, key := range keys {
if key.ISASet == "" || key.ISASet == key.Extension {
log.Printf("unhandled Extension %s", key.Extension)
} else {
log.Printf("unhandled Extension %s and ISASet %s", key.Extension, key.ISASet)
}
log.Printf(" opcodes: %s", slices.Sorted(maps.Keys(unknownFeatures[key])))
}
}
}
return defs
}
var (
maskRequiredRe = regexp.MustCompile(`VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]|VPEXPAND[BWDQ]|VEXPANDP[SD]`)
maskOptionalRe = regexp.MustCompile(`VPCMP(EQ|GT|U)?[BWDQ]|VCMPP[SD]`)
)
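// applyQuirks fixes up decoded operands for instructions whose XED
// descriptions mark mask operands with the wrong optionality.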
func applyQuirks(inst *xeddata.Inst, ops []operand) {
opc := inst.Opcode()
switch {
case maskRequiredRe.MatchString(opc):
// The mask on these instructions is marked optional, but the
// instruction is pointless without the mask.
for i, op := range ops {
if op, ok := op.(operandMask); ok {
op.optional = false
ops[i] = op
}
}
case maskOptionalRe.MatchString(opc):
// Conversely, these masks should be marked optional and aren't.
for i, op := range ops {
if op, ok := op.(operandMask); ok && op.action.r {
op.optional = true
ops[i] = op
}
}
}
}
type operandCommon struct {
action operandAction
}
// operandAction defines whether this operand is read and/or written.
//
// TODO: Should this live in [xeddata.Operand]?
type operandAction struct {
r bool // Read
w bool // Written
cr bool // Read is conditional (implies r==true)
cw bool // Write is conditional (implies w==true)
}
type operandMem struct {
operandCommon
// TODO
}
type vecShape struct {
elemBits int // Element size in bits
bits int // Register width in bits (total vector bits)
}
type operandVReg struct { // Vector register
operandCommon
vecShape
elemBaseType scalarBaseType
}
type operandGReg struct { // General-purpose register
operandCommon
vecShape
elemBaseType scalarBaseType
}
// operandMask is a vector mask.
//
// Regardless of the actual mask representation, the [vecShape] of this operand
// corresponds to the "bit for bit" type of mask. That is, elemBits gives the
// element width covered by each mask element, and bits/elemBits gives the total
// number of mask elements. (bits gives the total number of bits as if this were
// a bit-for-bit mask, which may be meaningless on its own.)
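// For example, a mask covering eight 32-bit elements has elemBits 32 and bits 256,
// giving bits/elemBits = 8 mask elements.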
type operandMask struct {
operandCommon
vecShape
// Bits in the mask is w/bits.
allMasks bool // If set, size cannot be inferred because all operands are masks.
// Mask can be omitted, in which case it defaults to K0/"no mask"
optional bool
}
type operandImm struct {
operandCommon
bits int // Immediate size in bits
}
type operand interface {
common() operandCommon
addToDef(b *unify.DefBuilder)
}
func strVal(s any) *unify.Value {
return unify.NewValue(unify.NewStringExact(fmt.Sprint(s)))
}
func (o operandCommon) common() operandCommon {
return o
}
func (o operandMem) addToDef(b *unify.DefBuilder) {
// TODO: w, base
b.Add("class", strVal("memory"))
}
func (o operandVReg) addToDef(b *unify.DefBuilder) {
baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
if err != nil {
panic("parsing baseRe: " + err.Error())
}
b.Add("class", strVal("vreg"))
b.Add("bits", strVal(o.bits))
b.Add("base", unify.NewValue(baseDomain))
// If elemBits == bits, then the vector can be ANY shape. This happens with,
// for example, logical ops.
if o.elemBits != o.bits {
b.Add("elemBits", strVal(o.elemBits))
}
}
func (o operandGReg) addToDef(b *unify.DefBuilder) {
baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
if err != nil {
panic("parsing baseRe: " + err.Error())
}
b.Add("class", strVal("greg"))
b.Add("bits", strVal(o.bits))
b.Add("base", unify.NewValue(baseDomain))
if o.elemBits != o.bits {
b.Add("elemBits", strVal(o.elemBits))
}
}
func (o operandMask) addToDef(b *unify.DefBuilder) {
b.Add("class", strVal("mask"))
if o.allMasks {
// If all operands are masks, omit sizes and let unification determine mask sizes.
return
}
b.Add("elemBits", strVal(o.elemBits))
b.Add("bits", strVal(o.bits))
}
func (o operandImm) addToDef(b *unify.DefBuilder) {
b.Add("class", strVal("immediate"))
b.Add("bits", strVal(o.bits))
}
var actionEncoding = map[string]operandAction{
"r": {r: true},
"cr": {r: true, cr: true},
"w": {w: true},
"cw": {w: true, cw: true},
"rw": {r: true, w: true},
"crw": {r: true, w: true, cr: true},
"rcw": {r: true, w: true, cw: true},
}
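// decodeOperand decodes a single XED operand description into an operand value.
// It returns nil (with no error) for operands that are deliberately ignored.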
func decodeOperand(db *xeddata.Database, operand string) (operand, error) {
op, err := xeddata.NewOperand(db, operand)
if err != nil {
log.Fatalf("parsing operand %q: %v", operand, err)
}
if *flagDebugXED {
fmt.Printf(" %+v\n", op)
}
if strings.HasPrefix(op.Name, "EMX_BROADCAST") {
// This refers to a set of macros defined in all-state.txt that set a
// BCAST operand to various fixed values. But the BCAST operand is
// itself suppressed and "internal", so I think we can just ignore this
// operand.
return nil, nil
}
// TODO: See xed_decoded_inst_operand_action. This might need to be more
// complicated.
action, ok := actionEncoding[op.Action]
if !ok {
return nil, fmt.Errorf("unknown action %q", op.Action)
}
common := operandCommon{action: action}
lhs := op.NameLHS()
if strings.HasPrefix(lhs, "MEM") {
// TODO: Width, base type
return operandMem{
operandCommon: common,
}, nil
} else if strings.HasPrefix(lhs, "REG") {
if op.Width == "mskw" {
// The mask operand doesn't specify a width. We have to infer it.
//
// XED uses the marker ZEROSTR to indicate that a mask operand is
// optional and, if omitted, implies K0, aka "no mask".
return operandMask{
operandCommon: common,
optional: op.Attributes["TXT=ZEROSTR"],
}, nil
} else {
class, regBits := decodeReg(op)
if class == NOT_REG_CLASS {
return nil, fmt.Errorf("failed to decode register %q", operand)
}
baseType, elemBits, ok := decodeType(op)
if !ok {
return nil, fmt.Errorf("failed to decode register width %q", operand)
}
shape := vecShape{elemBits: elemBits, bits: regBits}
if class == VREG_CLASS {
return operandVReg{
operandCommon: common,
vecShape: shape,
elemBaseType: baseType,
}, nil
}
// general register
m := min(shape.bits, shape.elemBits)
shape.bits, shape.elemBits = m, m
return operandGReg{
operandCommon: common,
vecShape: shape,
elemBaseType: baseType,
}, nil
}
} else if strings.HasPrefix(lhs, "IMM") {
_, bits, ok := decodeType(op)
if !ok {
return nil, fmt.Errorf("failed to decode register width %q", operand)
}
return operandImm{
operandCommon: common,
bits: bits,
}, nil
}
// TODO: BASE and SEG
return nil, fmt.Errorf("unknown operand LHS %q in %q", lhs, operand)
}
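// decodeOperands decodes all XED operand descriptions of an instruction and
// then infers the sizes of any mask operands.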
func decodeOperands(db *xeddata.Database, operands []string) (ops []operand, err error) {
// Decode the XED operand descriptions.
for _, o := range operands {
op, err := decodeOperand(db, o)
if err != nil {
return nil, err
}
if op != nil {
ops = append(ops, op)
}
}
// XED doesn't encode the size of mask operands. If there are mask operands,
// try to infer their sizes from other operands.
if err := inferMaskSizes(ops); err != nil {
return nil, fmt.Errorf("%w in operands %+v", err, operands)
}
return ops, nil
}
func inferMaskSizes(ops []operand) error {
// This is a heuristic and it falls apart in some cases:
//
// - Mask operations like KAND[BWDQ] have *nothing* in the XED to indicate
// mask size.
//
// - VINSERT*, VPSLL*, VPSRA*, and VPSRL* and some others naturally have
// mixed input sizes and the XED doesn't indicate which operands the mask
// applies to.
//
// - VPDP* and VP4DP* have really complex mixed operand patterns.
//
// I think for these we may just have to hand-write a table of which
// operands each mask applies to.
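// inferMask infers the shape of the masks on one side (read or write) from the
// non-mask register operands on that side, falling back to the other side when
// that side has no register operands.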
inferMask := func(r, w bool) error {
var masks []int
var rSizes, wSizes, sizes []vecShape
allMasks := true
hasWMask := false
for i, op := range ops {
action := op.common().action
if _, ok := op.(operandMask); ok {
if action.r && action.w {
return fmt.Errorf("unexpected rw mask")
}
if action.r == r || action.w == w {
masks = append(masks, i)
}
if action.w {
hasWMask = true
}
} else {
allMasks = false
if reg, ok := op.(operandVReg); ok {
if action.r {
rSizes = append(rSizes, reg.vecShape)
}
if action.w {
wSizes = append(wSizes, reg.vecShape)
}
}
}
}
if len(masks) == 0 {
return nil
}
if r {
sizes = rSizes
if len(sizes) == 0 {
sizes = wSizes
}
}
if w {
sizes = wSizes
if len(sizes) == 0 {
sizes = rSizes
}
}
if len(sizes) == 0 {
// If all operands are masks, leave the mask inference to the users.
if allMasks {
for _, i := range masks {
m := ops[i].(operandMask)
m.allMasks = true
ops[i] = m
}
return nil
}
return fmt.Errorf("cannot infer mask size: no register operands")
}
shape, ok := singular(sizes)
if !ok {
if !hasWMask && len(wSizes) == 1 && len(masks) == 1 {
// This pattern looks like a predicate mask, so its shape should align with the
// output. TODO: verify this is a safe assumption.
shape = wSizes[0]
} else {
return fmt.Errorf("cannot infer mask size: multiple register sizes %v", sizes)
}
}
for _, i := range masks {
m := ops[i].(operandMask)
m.vecShape = shape
ops[i] = m
}
return nil
}
if err := inferMask(true, false); err != nil {
return err
}
if err := inferMask(false, true); err != nil {
return err
}
return nil
}
// addOperandsToDef adds "in", "inVariant", and "out" to an instruction Def.
//
// Optional mask input operands are added to the inVariant field if
// variant&instVariantMasked, and omitted otherwise.
func addOperandsToDef(ops []operand, instDB *unify.DefBuilder, variant instVariant) {
var inVals, inVar, outVals []*unify.Value
asmPos := 0
for _, op := range ops {
var db unify.DefBuilder
op.addToDef(&db)
db.Add("asmPos", unify.NewValue(unify.NewStringExact(fmt.Sprint(asmPos))))
action := op.common().action
asmCount := 1 // # of assembly operands; 0 or 1
if action.r {
inVal := unify.NewValue(db.Build())
// If this is an optional mask, put it in the input variant tuple.
if mask, ok := op.(operandMask); ok && mask.optional {
if variant&instVariantMasked != 0 {
inVar = append(inVar, inVal)
} else {
// This operand doesn't appear in the assembly at all.
asmCount = 0
}
} else {
// Just a regular input operand.
inVals = append(inVals, inVal)
}
}
if action.w {
outVal := unify.NewValue(db.Build())
outVals = append(outVals, outVal)
}
asmPos += asmCount
}
instDB.Add("in", unify.NewValue(unify.NewTuple(inVals...)))
instDB.Add("inVariant", unify.NewValue(unify.NewTuple(inVar...)))
instDB.Add("out", unify.NewValue(unify.NewTuple(outVals...)))
}
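// instToUVal converts an XED instruction and its decoded operands into
// unification values: one unmasked variant, plus a masked variant if the
// instruction has an optional mask operand.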
func instToUVal(inst *xeddata.Inst, ops []operand) []*unify.Value {
feature, ok := decodeCPUFeature(inst)
if !ok {
return nil
}
var vals []*unify.Value
vals = append(vals, instToUVal1(inst, ops, feature, instVariantNone))
if hasOptionalMask(ops) {
vals = append(vals, instToUVal1(inst, ops, feature, instVariantMasked))
}
return vals
}
func instToUVal1(inst *xeddata.Inst, ops []operand, feature string, variant instVariant) *unify.Value {
var db unify.DefBuilder
db.Add("goarch", unify.NewValue(unify.NewStringExact("amd64")))
db.Add("asm", unify.NewValue(unify.NewStringExact(inst.Opcode())))
addOperandsToDef(ops, &db, variant)
db.Add("cpuFeature", unify.NewValue(unify.NewStringExact(feature)))
if strings.Contains(inst.Pattern, "ZEROING=0") {
// This is an EVEX instruction, but the ".Z" (zero-merging)
// instruction flag is NOT valid. EVEX.z must be zero.
//
// This can mean a few things:
//
// - The output of an instruction is a mask, so merging modes don't
// make any sense. E.g., VCMPPS.
//
// - There are no masks involved anywhere. (Maybe MASK=0 is also set
// in this case?) E.g., VINSERTPS.
//
// - The operation inherently performs merging. E.g., VCOMPRESSPS
// with a mem operand.
//
// There may be other reasons.
db.Add("zeroing", unify.NewValue(unify.NewStringExact("false")))
}
pos := unify.Pos{Path: inst.Pos.Path, Line: inst.Pos.Line}
return unify.NewValuePos(db.Build(), pos)
}
// decodeCPUFeature returns the CPU feature name required by inst. These match
// the names of the "Has*" feature checks in the simd package.
func decodeCPUFeature(inst *xeddata.Inst) (string, bool) {
key := cpuFeatureKey{
Extension: inst.Extension,
ISASet: isaSetStrip.ReplaceAllLiteralString(inst.ISASet, ""),
}
feat, ok := cpuFeatureMap[key]
if !ok {
imap := unknownFeatures[key]
if imap == nil {
imap = make(map[string]struct{})
unknownFeatures[key] = imap
}
imap[inst.Opcode()] = struct{}{}
return "", false
}
if feat == "ignore" {
return "", false
}
return feat, true
}
var isaSetStrip = regexp.MustCompile("_(128N?|256N?|512)$")
type cpuFeatureKey struct {
Extension, ISASet string
}
// cpuFeatureMap maps from XED's "EXTENSION" and "ISA_SET" to a CPU feature name
// that can be used in the SIMD API.
var cpuFeatureMap = map[cpuFeatureKey]string{
{"AVX", ""}: "AVX",
{"AVX_VNNI", "AVX_VNNI"}: "AVXVNNI",
{"AVX2", ""}: "AVX2",
// AVX-512 foundational features. We combine all of these into one "AVX512" feature.
{"AVX512EVEX", "AVX512F"}: "AVX512",
{"AVX512EVEX", "AVX512CD"}: "AVX512",
{"AVX512EVEX", "AVX512BW"}: "AVX512",
{"AVX512EVEX", "AVX512DQ"}: "AVX512",
// AVX512VL doesn't appear explicitly in the ISASet. I guess it's implied by
// the vector length suffix.
// AVX-512 extension features
{"AVX512EVEX", "AVX512_BITALG"}: "AVX512BITALG",
{"AVX512EVEX", "AVX512_GFNI"}: "AVX512GFNI",
{"AVX512EVEX", "AVX512_VBMI2"}: "AVX512VBMI2",
{"AVX512EVEX", "AVX512_VBMI"}: "AVX512VBMI",
{"AVX512EVEX", "AVX512_VNNI"}: "AVX512VNNI",
{"AVX512EVEX", "AVX512_VPOPCNTDQ"}: "AVX512VPOPCNTDQ",
// AVX 10.2 (not yet supported)
{"AVX512EVEX", "AVX10_2_RC"}: "ignore",
}
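// unknownFeatures records, for each unhandled Extension/ISASet pair, the
// opcodes that required it; loadXED reports these.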
var unknownFeatures = map[cpuFeatureKey]map[string]struct{}{}
// hasOptionalMask returns whether there is an optional mask operand in ops.
func hasOptionalMask(ops []operand) bool {
for _, op := range ops {
if op, ok := op.(operandMask); ok && op.optional {
return true
}
}
return false
}
func singular[T comparable](xs []T) (T, bool) {
if len(xs) == 0 {
return *new(T), false
}
for _, x := range xs[1:] {
if x != xs[0] {
return *new(T), false
}
}
return xs[0], true
}
// decodeReg returns class (NOT_REG_CLASS, VREG_CLASS, GREG_CLASS),
// and width in bits. If the operand cannot be decoded as a register,
// then the class is NOT_REG_CLASS.
func decodeReg(op *xeddata.Operand) (class, width int) {
// op.Width tells us the total width, e.g.,:
//
// dq => 128 bits (XMM)
// qq => 256 bits (YMM)
// mskw => K
// z[iuf?](8|16|32|...) => 512 bits (ZMM)
//
// But the encoding is really weird and it's not clear if these *always*
// mean XMM/YMM/ZMM or if other irregular things can use these large widths.
// Hence, we dig into the register sets themselves.
if !strings.HasPrefix(op.NameLHS(), "REG") {
return NOT_REG_CLASS, 0
}
// TODO: We shouldn't be relying on the macro naming conventions. We should
// use all-dec-patterns.txt, but xeddata doesn't support that table right now.
rhs := op.NameRHS()
if !strings.HasSuffix(rhs, "()") {
return NOT_REG_CLASS, 0
}
switch {
case strings.HasPrefix(rhs, "XMM_"):
return VREG_CLASS, 128
case strings.HasPrefix(rhs, "YMM_"):
return VREG_CLASS, 256
case strings.HasPrefix(rhs, "ZMM_"):
return VREG_CLASS, 512
case strings.HasPrefix(rhs, "GPR64_"), strings.HasPrefix(rhs, "VGPR64_"):
return GREG_CLASS, 64
case strings.HasPrefix(rhs, "GPR32_"), strings.HasPrefix(rhs, "VGPR32_"):
return GREG_CLASS, 32
}
return NOT_REG_CLASS, 0
}
var xtypeRe = regexp.MustCompile(`^([iuf])([0-9]+)$`)
// scalarBaseType describes the base type of a scalar element. This is a Go
// type, but without the bit width suffix (with the exception of
// scalarBaseIntOrUint).
type scalarBaseType int
const (
scalarBaseInt scalarBaseType = iota
scalarBaseUint
scalarBaseIntOrUint // Signed or unsigned is unspecified
scalarBaseFloat
scalarBaseComplex
scalarBaseBFloat
scalarBaseHFloat
)
func (s scalarBaseType) regex() string {
switch s {
case scalarBaseInt:
return "int"
case scalarBaseUint:
return "uint"
case scalarBaseIntOrUint:
return "int|uint"
case scalarBaseFloat:
return "float"
case scalarBaseComplex:
return "complex"
case scalarBaseBFloat:
return "BFloat"
case scalarBaseHFloat:
return "HFloat"
}
panic(fmt.Sprintf("unknown scalar base type %d", s))
}
func decodeType(op *xeddata.Operand) (base scalarBaseType, bits int, ok bool) {
// The xtype tells you the element type. i8, i16, i32, i64, f32, etc.
//
// TODO: Things like AVX2 VPAND have an xtype of u256 because they're
// element-width agnostic. Do I map that to all widths, or just omit the
// element width and let unification flesh it out? There's no u512
// (presumably those are all masked, so elem width matters). These are all
// Category: LOGICAL, so maybe we could use that info?
// Handle some weird ones.
switch op.Xtype {
// 8-bit float formats as defined by Open Compute Project "OCP 8-bit
// Floating Point Specification (OFP8)".
case "bf8": // E5M2 float
return scalarBaseBFloat, 8, true
case "hf8": // E4M3 float
return scalarBaseHFloat, 8, true
case "bf16": // bfloat16 float
return scalarBaseBFloat, 16, true
case "2f16":
// Complex consisting of 2 float16s. Doesn't exist in Go, but we can say
// what it would be.
return scalarBaseComplex, 32, true
case "2i8", "2I8":
// These just use the lower INT8 in each 16 bit field.
// As far as I can tell, "2I8" is a typo.
return scalarBaseInt, 8, true
case "2u16", "2U16":
// some VPDP* has it
// TODO: does "z" means it has zeroing?
return scalarBaseUint, 16, true
case "2i16", "2I16":
// some VPDP* has it
return scalarBaseInt, 16, true
case "4u8", "4U8":
// some VPDP* has it
return scalarBaseUint, 8, true
case "4i8", "4I8":
// some VPDP* has it
return scalarBaseInt, 8, true
}
// The rest follow a simple pattern.
m := xtypeRe.FindStringSubmatch(op.Xtype)
if m == nil {
// TODO: Report unrecognized xtype
return 0, 0, false
}
bits, _ = strconv.Atoi(m[2])
switch m[1] {
case "i", "u":
// XED is rather inconsistent about what's signed, unsigned, or doesn't
// matter, so merge them together and let the Go definitions narrow as
// appropriate. Maybe there's a better way to do this.
return scalarBaseIntOrUint, bits, true
case "f":
return scalarBaseFloat, bits, true
default:
panic("unreachable")
}
}

View file

@ -0,0 +1,154 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"iter"
"maps"
"slices"
)
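// A Closure pairs a Value with the environment set in which its variables are bound.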
type Closure struct {
val *Value
env envSet
}
func NewSum(vs ...*Value) Closure {
id := &ident{name: "sum"}
return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)}
}
// IsBottom returns whether c consists of no values.
func (c Closure) IsBottom() bool {
return c.val.Domain == nil
}
// Summands returns the top-level Values of c. This assumes the top-level of c
// was constructed as a sum, and is mostly useful for debugging.
func (c Closure) Summands() iter.Seq[*Value] {
return func(yield func(*Value) bool) {
var rec func(v *Value, env envSet) bool
rec = func(v *Value, env envSet) bool {
switch d := v.Domain.(type) {
case Var:
parts := env.partitionBy(d.id)
for _, part := range parts {
// It may be a sum of sums. Walk into this value.
if !rec(part.value, part.env) {
return false
}
}
return true
default:
return yield(v)
}
}
rec(c.val, c.env)
}
}
// All enumerates all possible concrete values of c by substituting variables
// from the environment.
//
// E.g., enumerating this Value
//
// a: !sum [1, 2]
// b: !sum [3, 4]
//
// results in
//
// - {a: 1, b: 3}
// - {a: 1, b: 4}
// - {a: 2, b: 3}
// - {a: 2, b: 4}
func (c Closure) All() iter.Seq[*Value] {
// In order to enumerate all concrete values under all possible variable
// bindings, we use a "non-deterministic continuation passing style" to
// implement this. We use CPS to traverse the Value tree, threading the
// (possibly narrowing) environment through that CPS following an Euler
// tour. Where the environment permits multiple choices, we invoke the same
// continuation for each choice. Similar to a yield function, the
// continuation can return false to stop the non-deterministic walk.
return func(yield func(*Value) bool) {
c.val.all1(c.env, func(v *Value, e envSet) bool {
return yield(v)
})
}
}
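// all1 walks v, resolving variables against the environment e, and invokes cont
// with each concrete value and the (possibly narrowed) environment it arose
// under; cont returns false to stop the enumeration.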
func (v *Value) all1(e envSet, cont func(*Value, envSet) bool) bool {
switch d := v.Domain.(type) {
default:
panic(fmt.Sprintf("unknown domain type %T", d))
case nil:
return true
case Top, String:
return cont(v, e)
case Def:
fields := d.keys()
// We can reuse this parts slice because we're doing a DFS through the
// state space. (Otherwise, we'd have to do some messy threading of an
// immutable slice-like value through allElt.)
parts := make(map[string]*Value, len(fields))
// TODO: If there are no Vars or Sums under this Def, then nothing can
// change the Value or env, so we could just cont(v, e).
var allElt func(elt int, e envSet) bool
allElt = func(elt int, e envSet) bool {
if elt == len(fields) {
// Build a new Def from the concrete parts. Clone parts because
// we may reuse it on other non-deterministic branches.
nVal := newValueFrom(Def{maps.Clone(parts)}, v)
return cont(nVal, e)
}
return d.fields[fields[elt]].all1(e, func(v *Value, e envSet) bool {
parts[fields[elt]] = v
return allElt(elt+1, e)
})
}
return allElt(0, e)
case Tuple:
// Essentially the same as Def.
if d.repeat != nil {
// There's nothing we can do with this.
return cont(v, e)
}
parts := make([]*Value, len(d.vs))
var allElt func(elt int, e envSet) bool
allElt = func(elt int, e envSet) bool {
if elt == len(d.vs) {
// Build a new tuple from the concrete parts. Clone parts because
// we may reuse it on other non-deterministic branches.
nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v)
return cont(nVal, e)
}
return d.vs[elt].all1(e, func(v *Value, e envSet) bool {
parts[elt] = v
return allElt(elt+1, e)
})
}
return allElt(0, e)
case Var:
// Go each way this variable can be bound.
for _, ePart := range e.partitionBy(d.id) {
// d.id is no longer bound in this environment partition. We may
// need it later in the Euler tour, so bind it back to this single
// value.
env := ePart.env.bind(d.id, ePart.value)
if !ePart.value.all1(env, cont) {
return false
}
}
return true
}
}

View file

@ -0,0 +1,359 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"iter"
"maps"
"reflect"
"regexp"
"slices"
"strconv"
"strings"
)
// A Domain is a non-empty set of values, all of the same kind.
//
// Domain may be a scalar:
//
// - [String] - Represents string-typed values.
//
// Or a composite:
//
// - [Def] - A mapping from fixed keys to [Domain]s.
//
// - [Tuple] - A fixed-length sequence of [Domain]s or
// all possible lengths repeating a [Domain].
//
// Or top or bottom:
//
// - [Top] - Represents all possible values of all kinds.
//
// - nil - Represents no values.
//
// Or a variable:
//
// - [Var] - A value captured in the environment.
type Domain interface {
Exact() bool
WhyNotExact() string
// decode stores this value in a Go value. If this value is not exact, this
// returns a potentially wrapped *inexactError.
decode(reflect.Value) error
}
type inexactError struct {
valueType string
goType string
}
func (e *inexactError) Error() string {
return fmt.Sprintf("cannot store inexact %s value in %s", e.valueType, e.goType)
}
type decodeError struct {
path string
err error
}
func newDecodeError(path string, err error) *decodeError {
if err, ok := err.(*decodeError); ok {
return &decodeError{path: path + "." + err.path, err: err.err}
}
return &decodeError{path: path, err: err}
}
func (e *decodeError) Unwrap() error {
return e.err
}
func (e *decodeError) Error() string {
return fmt.Sprintf("%s: %s", e.path, e.err)
}
// Top represents all possible values of all possible types.
type Top struct{}
func (t Top) Exact() bool { return false }
func (t Top) WhyNotExact() string { return "is top" }
func (t Top) decode(rv reflect.Value) error {
// We can decode Top into a pointer-typed value as nil.
if rv.Kind() != reflect.Pointer {
return &inexactError{"top", rv.Type().String()}
}
rv.SetZero()
return nil
}
// A Def is a mapping from field names to [Value]s. Any fields not explicitly
// listed have [Value] [Top].
type Def struct {
fields map[string]*Value
}
// A DefBuilder builds a [Def] one field at a time. The zero value is an empty
// [Def].
type DefBuilder struct {
fields map[string]*Value
}
func (b *DefBuilder) Add(name string, v *Value) {
if b.fields == nil {
b.fields = make(map[string]*Value)
}
if _, ok := b.fields[name]; ok {
panic(fmt.Sprintf("duplicate field %q", name))
}
b.fields[name] = v
}
// Build constructs a [Def] from the fields added to this builder.
func (b *DefBuilder) Build() Def {
return Def{maps.Clone(b.fields)}
}
// Exact returns true if all field Values are exact.
func (d Def) Exact() bool {
for _, v := range d.fields {
if !v.Exact() {
return false
}
}
return true
}
// WhyNotExact returns why the value is not exact
func (d Def) WhyNotExact() string {
for s, v := range d.fields {
if !v.Exact() {
w := v.WhyNotExact()
return "field " + s + ": " + w
}
}
return ""
}
func (d Def) decode(rv reflect.Value) error {
if rv.Kind() != reflect.Struct {
return fmt.Errorf("cannot decode Def into %s", rv.Type())
}
var lowered map[string]string // Lower case -> canonical for d.fields.
rt := rv.Type()
for fi := range rv.NumField() {
fType := rt.Field(fi)
if fType.PkgPath != "" {
continue
}
v := d.fields[fType.Name]
if v == nil {
v = topValue
// Try a case-insensitive match
canon, ok := d.fields[strings.ToLower(fType.Name)]
if ok {
v = canon
} else {
if lowered == nil {
lowered = make(map[string]string, len(d.fields))
for k := range d.fields {
l := strings.ToLower(k)
if k != l {
lowered[l] = k
}
}
}
canon, ok := lowered[strings.ToLower(fType.Name)]
if ok {
v = d.fields[canon]
}
}
}
if err := decodeReflect(v, rv.Field(fi)); err != nil {
return newDecodeError(fType.Name, err)
}
}
return nil
}
func (d Def) keys() []string {
return slices.Sorted(maps.Keys(d.fields))
}
func (d Def) All() iter.Seq2[string, *Value] {
// TODO: We call All fairly often. It's probably bad to sort this every
// time.
keys := slices.Sorted(maps.Keys(d.fields))
return func(yield func(string, *Value) bool) {
for _, k := range keys {
if !yield(k, d.fields[k]) {
return
}
}
}
}
// A Tuple is a sequence of Values in one of two forms: 1. a fixed-length tuple,
// where each Value can be different or 2. a "repeated tuple", which is a Value
// repeated 0 or more times.
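// For example, the !repeat tag in the YAML inputs produces a repeated Tuple.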
type Tuple struct {
vs []*Value
// repeat, if non-nil, means this Tuple consists of an element repeated 0 or
// more times. If repeat is non-nil, vs must be nil. This is a generator
// function because we don't necessarily want *exactly* the same Value
// repeated. For example, in YAML encoding, a !sum in a repeated tuple needs
// a fresh variable in each instance.
repeat []func(envSet) (*Value, envSet)
}
func NewTuple(vs ...*Value) Tuple {
return Tuple{vs: vs}
}
func NewRepeat(gens ...func(envSet) (*Value, envSet)) Tuple {
return Tuple{repeat: gens}
}
func (d Tuple) Exact() bool {
if d.repeat != nil {
return false
}
for _, v := range d.vs {
if !v.Exact() {
return false
}
}
return true
}
func (d Tuple) WhyNotExact() string {
if d.repeat != nil {
return "d.repeat is not nil"
}
for i, v := range d.vs {
if !v.Exact() {
w := v.WhyNotExact()
return "index " + strconv.FormatInt(int64(i), 10) + ": " + w
}
}
return ""
}
func (d Tuple) decode(rv reflect.Value) error {
if d.repeat != nil {
return &inexactError{"repeated tuple", rv.Type().String()}
}
// TODO: We could also do arrays.
if rv.Kind() != reflect.Slice {
return fmt.Errorf("cannot decode Tuple into %s", rv.Type())
}
if rv.IsNil() || rv.Cap() < len(d.vs) {
rv.Set(reflect.MakeSlice(rv.Type(), len(d.vs), len(d.vs)))
} else {
rv.SetLen(len(d.vs))
}
for i, v := range d.vs {
if err := decodeReflect(v, rv.Index(i)); err != nil {
return newDecodeError(fmt.Sprintf("%d", i), err)
}
}
return nil
}
// A String represents a set of strings. It can represent the intersection of a
// set of regexps, or a single exact string. In general, the domain of a String
// is non-empty, but we do not attempt to prove emptiness of a regexp value.
type String struct {
kind stringKind
re []*regexp.Regexp // Intersection of regexps
exact string
}
type stringKind int
const (
stringRegex stringKind = iota
stringExact
)
func NewStringRegex(exprs ...string) (String, error) {
if len(exprs) == 0 {
exprs = []string{""}
}
v := String{kind: -1}
for _, expr := range exprs {
if expr == "" {
// Skip constructing the regexp. It won't have a "literal prefix"
// and so we wind up thinking this is a regexp instead of an exact
// (empty) string.
v = String{kind: stringExact, exact: ""}
continue
}
re, err := regexp.Compile(`\A(?:` + expr + `)\z`)
if err != nil {
return String{}, fmt.Errorf("parsing value: %s", err)
}
// An exact value narrows the whole domain to exact, so we're done, but
// should keep parsing.
if v.kind == stringExact {
continue
}
if exact, complete := re.LiteralPrefix(); complete {
v = String{kind: stringExact, exact: exact}
} else {
v.kind = stringRegex
v.re = append(v.re, re)
}
}
return v, nil
}
func NewStringExact(s string) String {
return String{kind: stringExact, exact: s}
}
// Exact returns whether this Value is known to consist of a single string.
func (d String) Exact() bool {
return d.kind == stringExact
}
func (d String) WhyNotExact() string {
if d.kind == stringExact {
return ""
}
return "string is not exact"
}
func (d String) decode(rv reflect.Value) error {
if d.kind != stringExact {
return &inexactError{"regex", rv.Type().String()}
}
switch rv.Kind() {
default:
return fmt.Errorf("cannot decode String into %s", rv.Type())
case reflect.String:
rv.SetString(d.exact)
case reflect.Int:
i, err := strconv.Atoi(d.exact)
if err != nil {
return fmt.Errorf("cannot decode String into %s: %s", rv.Type(), err)
}
rv.SetInt(int64(i))
case reflect.Bool:
b, err := strconv.ParseBool(d.exact)
if err != nil {
return fmt.Errorf("cannot decode String into %s: %s", rv.Type(), err)
}
rv.SetBool(b)
}
return nil
}

221
src/simd/_gen/unify/dot.go Normal file
View file

@ -0,0 +1,221 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"bytes"
"fmt"
"html"
"io"
"os"
"os/exec"
"strings"
)
const maxNodes = 30
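// A dotEncoder renders Values and environment sets as Graphviz dot graphs,
// limiting each rendered subgraph to maxNodes nodes.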
type dotEncoder struct {
w *bytes.Buffer
idGen int // Node name generation
valLimit int // Limit the number of Values in a subgraph
idp identPrinter
}
func newDotEncoder() *dotEncoder {
return &dotEncoder{
w: new(bytes.Buffer),
}
}
func (enc *dotEncoder) clear() {
enc.w.Reset()
enc.idGen = 0
}
func (enc *dotEncoder) writeTo(w io.Writer) {
fmt.Fprintln(w, "digraph {")
// Use the "new" ranking algorithm, which lets us put nodes from different
// clusters in the same rank.
fmt.Fprintln(w, "newrank=true;")
fmt.Fprintln(w, "node [shape=box, ordering=out];")
w.Write(enc.w.Bytes())
fmt.Fprintln(w, "}")
}
func (enc *dotEncoder) writeSvg(w io.Writer) error {
cmd := exec.Command("dot", "-Tsvg")
in, err := cmd.StdinPipe()
if err != nil {
return err
}
var out bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return err
}
enc.writeTo(in)
in.Close()
if err := cmd.Wait(); err != nil {
return err
}
// Trim SVG header so the result can be embedded
//
// TODO: In Graphviz 10.0.1, we could use -Tsvg_inline.
svg := out.Bytes()
if i := bytes.Index(svg, []byte("<svg ")); i >= 0 {
svg = svg[i:]
}
_, err = w.Write(svg)
return err
}
func (enc *dotEncoder) newID(f string) string {
id := fmt.Sprintf(f, enc.idGen)
enc.idGen++
return id
}
func (enc *dotEncoder) node(label, sublabel string) string {
id := enc.newID("n%d")
l := html.EscapeString(label)
if sublabel != "" {
l += fmt.Sprintf("<BR ALIGN=\"CENTER\"/><FONT POINT-SIZE=\"10\">%s</FONT>", html.EscapeString(sublabel))
}
fmt.Fprintf(enc.w, "%s [label=<%s>];\n", id, l)
return id
}
func (enc *dotEncoder) edge(from, to string, label string, args ...any) {
l := fmt.Sprintf(label, args...)
fmt.Fprintf(enc.w, "%s -> %s [label=%q];\n", from, to, l)
}
func (enc *dotEncoder) valueSubgraph(v *Value) {
enc.valLimit = maxNodes
cID := enc.newID("cluster_%d")
fmt.Fprintf(enc.w, "subgraph %s {\n", cID)
fmt.Fprintf(enc.w, "style=invis;")
vID := enc.value(v)
fmt.Fprintf(enc.w, "}\n")
// We don't need the IDs right now.
_, _ = cID, vID
}
func (enc *dotEncoder) value(v *Value) string {
if enc.valLimit <= 0 {
id := enc.newID("n%d")
fmt.Fprintf(enc.w, "%s [label=\"...\", shape=triangle];\n", id)
return id
}
enc.valLimit--
switch vd := v.Domain.(type) {
default:
panic(fmt.Sprintf("unknown domain type %T", vd))
case nil:
return enc.node("_|_", "")
case Top:
return enc.node("_", "")
// TODO: Like in YAML, figure out if this is just a sum. In dot, we
// could say any unentangled variable is a sum, and if it has more than
// one reference just share the node.
// case Sum:
// node := enc.node("Sum", "")
// for i, elt := range vd.vs {
// enc.edge(node, enc.value(elt), "%d", i)
// if enc.valLimit <= 0 {
// break
// }
// }
// return node
case Def:
node := enc.node("Def", "")
for k, v := range vd.All() {
enc.edge(node, enc.value(v), "%s", k)
if enc.valLimit <= 0 {
break
}
}
return node
case Tuple:
if vd.repeat == nil {
label := "Tuple"
node := enc.node(label, "")
for i, elt := range vd.vs {
enc.edge(node, enc.value(elt), "%d", i)
if enc.valLimit <= 0 {
break
}
}
return node
} else {
// TODO
return enc.node("TODO: Repeat", "")
}
case String:
switch vd.kind {
case stringExact:
return enc.node(fmt.Sprintf("%q", vd.exact), "")
case stringRegex:
var parts []string
for _, re := range vd.re {
parts = append(parts, fmt.Sprintf("%q", re))
}
return enc.node(strings.Join(parts, "&"), "")
}
panic("bad String kind")
case Var:
return enc.node(fmt.Sprintf("Var %s", enc.idp.unique(vd.id)), "")
}
}
func (enc *dotEncoder) envSubgraph(e envSet) {
enc.valLimit = maxNodes
cID := enc.newID("cluster_%d")
fmt.Fprintf(enc.w, "subgraph %s {\n", cID)
fmt.Fprintf(enc.w, "style=invis;")
vID := enc.env(e.root)
fmt.Fprintf(enc.w, "}\n")
_, _ = cID, vID
}
func (enc *dotEncoder) env(e *envExpr) string {
switch e.kind {
default:
panic("bad kind")
case envZero:
return enc.node("0", "")
case envUnit:
return enc.node("1", "")
case envBinding:
node := enc.node(fmt.Sprintf("%q :", enc.idp.unique(e.id)), "")
enc.edge(node, enc.value(e.val), "")
return node
case envProduct:
node := enc.node("", "")
for _, op := range e.operands {
enc.edge(node, enc.env(op), "")
}
return node
case envSum:
node := enc.node("+", "")
for _, op := range e.operands {
enc.edge(node, enc.env(op), "")
}
return node
}
}

480
src/simd/_gen/unify/env.go Normal file
View file

@ -0,0 +1,480 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"iter"
"reflect"
"strings"
)
// An envSet is an immutable set of environments, where each environment is a
// mapping from [ident]s to [Value]s.
//
// To keep this compact, we use an algebraic representation similar to
// relational algebra. The atoms are zero, unit, or a singular binding:
//
// - A singular binding is an environment set consisting of a single environment
// that binds a single ident to a single value.
//
// - Zero is the empty set.
//
// - Unit is an environment set consisting of a single, empty environment (no
// bindings).
//
// From these, we build up more complex sets of environments using sums and
// cross products:
//
// - A sum is simply the union of the two environment sets.
//
// - A cross product is the Cartesian product of the two environment sets,
// followed by combining each pair of environments. Combining simply merges the
// two mappings, but fails if the mappings overlap.
//
// For example, to represent {{x: 1, y: 1}, {x: 2, y: 2}}, we build the two
// environments and sum them:
//
// ({x: 1} {y: 1}) + ({x: 2} {y: 2})
//
// If we add a third variable z that can be 1 or 2, independent of x and y, we
// get four logical environments:
//
// {x: 1, y: 1, z: 1}
// {x: 2, y: 2, z: 1}
// {x: 1, y: 1, z: 2}
// {x: 2, y: 2, z: 2}
//
// This could be represented as a sum of all four environments, but because z is
// independent, we can use a more compact representation:
//
// (({x: 1} {y: 1}) + ({x: 2} {y: 2})) ({z: 1} + {z: 2})
//
// Environment sets obey commutative algebra rules:
//
// e + 0 = e
// e 0 = 0
// e 1 = e
// e + f = f + e
// e f = f e
type envSet struct {
root *envExpr
}
type envExpr struct {
// TODO: A tree-based data structure for this may not be ideal, since it
// involves a lot of walking to find things and we often have to do deep
// rewrites anyway for partitioning. Would some flattened array-style
// representation be better, possibly combined with an index of ident uses?
// We could even combine that with an immutable array abstraction (ala
// Clojure) that could enable more efficient construction operations.
kind envExprKind
// For envBinding
id *ident
val *Value
// For sum or product. Len must be >= 2 and none of the elements can have
// the same kind as this node.
operands []*envExpr
}
type envExprKind byte
const (
envZero envExprKind = iota
envUnit
envProduct
envSum
envBinding
)
var (
// topEnv is the unit value (multiplicative identity) of a [envSet].
topEnv = envSet{envExprUnit}
// bottomEnv is the zero value (additive identity) of a [envSet].
bottomEnv = envSet{envExprZero}
envExprZero = &envExpr{kind: envZero}
envExprUnit = &envExpr{kind: envUnit}
)
// bind binds id to each of vals in e.
//
// It panics if id is already bound in e.
//
// Environments are typically initially constructed by starting with [topEnv]
// and calling bind one or more times.
func (e envSet) bind(id *ident, vals ...*Value) envSet {
if e.isEmpty() {
return bottomEnv
}
// TODO: If any of vals are _, should we just drop that val? We're kind of
// inconsistent about whether an id missing from e means id is invalid or
// means id is _.
// Check that id isn't present in e.
for range e.root.bindings(id) {
panic("id " + id.name + " already present in environment")
}
// Create a sum of all the values.
bindings := make([]*envExpr, 0, 1)
for _, val := range vals {
bindings = append(bindings, &envExpr{kind: envBinding, id: id, val: val})
}
// Multiply it in.
return envSet{newEnvExprProduct(e.root, newEnvExprSum(bindings...))}
}
func (e envSet) isEmpty() bool {
return e.root.kind == envZero
}
// bindings yields all [envBinding] nodes in e with the given id. If id is nil,
// it yields all binding nodes.
func (e *envExpr) bindings(id *ident) iter.Seq[*envExpr] {
// This is just a pre-order walk and it happens this is the only thing we
// need a pre-order walk for.
return func(yield func(*envExpr) bool) {
var rec func(e *envExpr) bool
rec = func(e *envExpr) bool {
if e.kind == envBinding && (id == nil || e.id == id) {
if !yield(e) {
return false
}
}
for _, o := range e.operands {
if !rec(o) {
return false
}
}
return true
}
rec(e)
}
}
// newEnvExprProduct constructs a product node from exprs, performing
// simplifications. It does NOT check that bindings are disjoint.
func newEnvExprProduct(exprs ...*envExpr) *envExpr {
factors := make([]*envExpr, 0, 2)
for _, expr := range exprs {
switch expr.kind {
case envZero:
return envExprZero
case envUnit:
// No effect on product
case envProduct:
factors = append(factors, expr.operands...)
default:
factors = append(factors, expr)
}
}
if len(factors) == 0 {
return envExprUnit
} else if len(factors) == 1 {
return factors[0]
}
return &envExpr{kind: envProduct, operands: factors}
}
// newEnvExprSum constructs a sum node from exprs, performing simplifications.
func newEnvExprSum(exprs ...*envExpr) *envExpr {
// TODO: If all of envs are products (or bindings), factor any common terms.
// E.g., x * y + x * z ==> x * (y + z). This is easy to do for binding
// terms, but harder to do for more general terms.
var have smallSet[*envExpr]
terms := make([]*envExpr, 0, 2)
for _, expr := range exprs {
switch expr.kind {
case envZero:
// No effect on sum
case envSum:
for _, expr1 := range expr.operands {
if have.Add(expr1) {
terms = append(terms, expr1)
}
}
default:
if have.Add(expr) {
terms = append(terms, expr)
}
}
}
if len(terms) == 0 {
return envExprZero
} else if len(terms) == 1 {
return terms[0]
}
return &envExpr{kind: envSum, operands: terms}
}
func crossEnvs(env1, env2 envSet) envSet {
// Confirm that envs have disjoint idents.
var ids1 smallSet[*ident]
for e := range env1.root.bindings(nil) {
ids1.Add(e.id)
}
for e := range env2.root.bindings(nil) {
if ids1.Has(e.id) {
panic(fmt.Sprintf("%s bound on both sides of cross-product", e.id.name))
}
}
return envSet{newEnvExprProduct(env1.root, env2.root)}
}
func unionEnvs(envs ...envSet) envSet {
exprs := make([]*envExpr, len(envs))
for i := range envs {
exprs[i] = envs[i].root
}
return envSet{newEnvExprSum(exprs...)}
}
// envPartition is a subset of an env where id is bound to value in all
// deterministic environments.
type envPartition struct {
id *ident
value *Value
env envSet
}
// partitionBy splits e by distinct bindings of id and removes id from each
// partition.
//
// If there are environments in e where id is not bound, they will not be
// reflected in any partition.
//
// It panics if e is bottom, since attempting to partition an empty environment
// set almost certainly indicates a bug.
func (e envSet) partitionBy(id *ident) []envPartition {
if e.isEmpty() {
// We could return zero partitions, but getting here at all almost
// certainly indicates a bug.
panic("cannot partition empty environment set")
}
// Emit a partition for each value of id.
var seen smallSet[*Value]
var parts []envPartition
for n := range e.root.bindings(id) {
if !seen.Add(n.val) {
// Already emitted a partition for this value.
continue
}
parts = append(parts, envPartition{
id: id,
value: n.val,
env: envSet{e.root.substitute(id, n.val)},
})
}
return parts
}
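// partitionByExample is an illustrative sketch (hand-written, not referenced
// elsewhere): starting from ({x: 1} + {x: 2}) × ({z: 1} + {z: 2}),
// partitioning by x yields two partitions, one for x=1 and one for x=2, and
// each partition's environment still binds z to both values.
func partitionByExample() []envPartition {
	x := &ident{name: "x"}
	z := &ident{name: "z"}
	one := NewValue(NewStringExact("1"))
	two := NewValue(NewStringExact("2"))
	e := topEnv.bind(x, one, two).bind(z, one, two)
	return e.partitionBy(x) // two partitions, each with env ({z: 1} + {z: 2})
}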
// substitute replaces bindings of id to val with 1 and bindings of id to any
// other value with 0 and simplifies the result.
func (e *envExpr) substitute(id *ident, val *Value) *envExpr {
switch e.kind {
default:
panic("bad kind")
case envZero, envUnit:
return e
case envBinding:
if e.id != id {
return e
} else if e.val != val {
return envExprZero
} else {
return envExprUnit
}
case envProduct, envSum:
// Substitute each operand. Sometimes, this won't change anything, so we
// build the new operands list lazily.
var nOperands []*envExpr
for i, op := range e.operands {
nOp := op.substitute(id, val)
if nOperands == nil && op != nOp {
// Operand diverged; initialize nOperands.
nOperands = make([]*envExpr, 0, len(e.operands))
nOperands = append(nOperands, e.operands[:i]...)
}
if nOperands != nil {
nOperands = append(nOperands, nOp)
}
}
if nOperands == nil {
// Nothing changed.
return e
}
if e.kind == envProduct {
return newEnvExprProduct(nOperands...)
} else {
return newEnvExprSum(nOperands...)
}
}
}
// A smallSet is a set optimized for stack allocation when small.
type smallSet[T comparable] struct {
array [32]T
n int
m map[T]struct{}
}
// Has returns whether val is in set.
func (s *smallSet[T]) Has(val T) bool {
arr := s.array[:s.n]
for i := range arr {
if arr[i] == val {
return true
}
}
_, ok := s.m[val]
return ok
}
// Add adds val to the set and returns true if it was added (not already
// present).
func (s *smallSet[T]) Add(val T) bool {
// Test for presence.
if s.Has(val) {
return false
}
// Add it
if s.n < len(s.array) {
s.array[s.n] = val
s.n++
} else {
if s.m == nil {
s.m = make(map[T]struct{})
}
s.m[val] = struct{}{}
}
return true
}
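// smallSetExample is an illustrative sketch (hand-written, not referenced
// elsewhere): the first len(array) values stay in the fixed-size array and
// only later values spill into the map.
func smallSetExample() bool {
	var s smallSet[int]
	for i := range 40 { // more values than the array holds, forcing a spill
		s.Add(i)
	}
	return s.Has(39) // true: found in the overflow map
}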
type ident struct {
_ [0]func() // Not comparable (only compare *ident)
name string
}
type Var struct {
id *ident
}
func (d Var) Exact() bool {
// These can't appear in concrete Values.
panic("Exact called on non-concrete Value")
}
func (d Var) WhyNotExact() string {
// These can't appear in concrete Values.
return "WhyNotExact called on non-concrete Value"
}
func (d Var) decode(rv reflect.Value) error {
return &inexactError{"var", rv.Type().String()}
}
func (d Var) unify(w *Value, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
// TODO: Vars from !sums in the input can have a huge number of values.
// Unifying these could be way more efficient with some indexes over any
// exact values we can pull out, like Def fields that are exact Strings.
// Maybe we try to produce an array of yes/no/maybe matches and then we only
// have to do deeper evaluation of the maybes. We could probably cache this
// on an envExpr. It may also help to special-case Var/Var unification to
// pick which one to index versus enumerate.
if vd, ok := w.Domain.(Var); ok && d.id == vd.id {
// Unifying $x with $x results in $x. If we descend into this we'll have
// problems because we strip $x out of the environment to keep ourselves
// honest and then can't find it on the other side.
//
// TODO: I'm not positive this is the right fix.
return vd, e, nil
}
// We need to unify w with the value of d in each possible environment. We
// can save some work by grouping environments by the value of d, since
// there will be a lot of redundancy here.
var nEnvs []envSet
envParts := e.partitionBy(d.id)
for i, envPart := range envParts {
exit := uf.enterVar(d.id, i)
// Each branch logically gets its own copy of the initial environment
// (narrowed down to just this binding of the variable), and each branch
// may result in different changes to that starting environment.
res, e2, err := w.unify(envPart.value, envPart.env, swap, uf)
exit.exit()
if err != nil {
return nil, envSet{}, err
}
if res.Domain == nil {
// This branch entirely failed to unify, so it's gone.
continue
}
nEnv := e2.bind(d.id, res)
nEnvs = append(nEnvs, nEnv)
}
if len(nEnvs) == 0 {
// All branches failed
return nil, bottomEnv, nil
}
// The effect of this is entirely captured in the environment. We can return
// the same Var node.
return d, unionEnvs(nEnvs...), nil
}
// An identPrinter maps [ident]s to unique string names.
type identPrinter struct {
ids map[*ident]string
idGen map[string]int
}
func (p *identPrinter) unique(id *ident) string {
if p.ids == nil {
p.ids = make(map[*ident]string)
p.idGen = make(map[string]int)
}
name, ok := p.ids[id]
if !ok {
gen := p.idGen[id.name]
p.idGen[id.name]++
if gen == 0 {
name = id.name
} else {
name = fmt.Sprintf("%s#%d", id.name, gen)
}
p.ids[id] = name
}
return name
}
func (p *identPrinter) slice(ids []*ident) string {
var strs []string
for _, id := range ids {
strs = append(strs, p.unique(id))
}
return fmt.Sprintf("[%s]", strings.Join(strs, ", "))
}
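// identPrinterExample is an illustrative sketch (hand-written, not referenced
// elsewhere): distinct idents that share a name are distinguished with #N
// suffixes.
func identPrinterExample() string {
	var p identPrinter
	a := &ident{name: "x"}
	b := &ident{name: "x"}
	return p.slice([]*ident{a, b}) // "[x, x#1]"
}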

123
src/simd/_gen/unify/html.go Normal file
View file

@ -0,0 +1,123 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"html"
"io"
"strings"
)
func (t *tracer) writeHTML(w io.Writer) {
if !t.saveTree {
panic("writeHTML called without tracer.saveTree")
}
fmt.Fprintf(w, "<html><head><style>%s</style></head>", htmlCSS)
for _, root := range t.trees {
dot := newDotEncoder()
html := htmlTracer{w: w, dot: dot}
html.writeTree(root)
}
fmt.Fprintf(w, "</html>\n")
}
const htmlCSS = `
.unify {
display: grid;
grid-auto-columns: min-content;
text-align: center;
}
.header {
grid-row: 1;
font-weight: bold;
padding: 0.25em;
position: sticky;
top: 0;
background: white;
}
.envFactor {
display: grid;
grid-auto-rows: min-content;
grid-template-columns: subgrid;
text-align: center;
}
`
type htmlTracer struct {
w io.Writer
dot *dotEncoder
svgs map[any]string
}
func (t *htmlTracer) writeTree(node *traceTree) {
// TODO: This could be really nice.
//
// - Put nodes that were unified on the same rank with {rank=same; a; b}
//
// - On hover, highlight nodes that node was unified with and the result. If
// it's a variable, highlight it in the environment, too.
//
// - On click, show the details of unifying that node.
//
// This could be the only way to navigate, without necessarily needing the
// whole nest of <detail> nodes.
// TODO: It might be possible to write this out on the fly.
t.emit([]*Value{node.v, node.w}, []string{"v", "w"}, node.envIn)
// Render children.
for i, child := range node.children {
if i >= 10 {
fmt.Fprintf(t.w, `<div style="margin-left: 4em">...</div>`)
break
}
fmt.Fprintf(t.w, `<details style="margin-left: 4em"><summary>%s</summary>`, html.EscapeString(child.label))
t.writeTree(child)
fmt.Fprintf(t.w, "</details>\n")
}
// Render result.
if node.err != nil {
fmt.Fprintf(t.w, "Error: %s\n", html.EscapeString(node.err.Error()))
} else {
t.emit([]*Value{node.res}, []string{"res"}, node.env)
}
}
func htmlSVG[Key comparable](t *htmlTracer, f func(Key), arg Key) string {
if s, ok := t.svgs[arg]; ok {
return s
}
var buf strings.Builder
f(arg)
t.dot.writeSvg(&buf)
t.dot.clear()
svg := buf.String()
if t.svgs == nil {
t.svgs = make(map[any]string)
}
t.svgs[arg] = svg
buf.Reset()
return svg
}
func (t *htmlTracer) emit(vs []*Value, labels []string, env envSet) {
fmt.Fprintf(t.w, `<div class="unify">`)
for i, v := range vs {
fmt.Fprintf(t.w, `<div class="header" style="grid-column: %d">%s</div>`, i+1, html.EscapeString(labels[i]))
fmt.Fprintf(t.w, `<div style="grid-area: 2 / %d">%s</div>`, i+1, htmlSVG(t, t.dot.valueSubgraph, v))
}
col := len(vs)
fmt.Fprintf(t.w, `<div class="header" style="grid-column: %d">in</div>`, col+1)
fmt.Fprintf(t.w, `<div style="grid-area: 2 / %d">%s</div>`, col+1, htmlSVG(t, t.dot.envSubgraph, env))
fmt.Fprintf(t.w, `</div>`)
}

View file

@ -0,0 +1,33 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
)
type Pos struct {
Path string
Line int
}
func (p Pos) String() string {
var b []byte
b, _ = p.AppendText(b)
return string(b)
}
func (p Pos) AppendText(b []byte) ([]byte, error) {
if p.Line == 0 {
if p.Path == "" {
return append(b, "?:?"...), nil
} else {
return append(b, p.Path...), nil
}
} else if p.Path == "" {
return fmt.Appendf(b, "?:%d", p.Line), nil
}
return fmt.Appendf(b, "%s:%d", p.Path, p.Line), nil
}
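// posExamples is an illustrative sketch (hand-written, not referenced
// elsewhere) of the formatting cases above.
var posExamples = []string{
	Pos{}.String(),                        // "?:?"
	Pos{Path: "a.yaml"}.String(),          // "a.yaml"
	Pos{Line: 3}.String(),                 // "?:3"
	Pos{Path: "a.yaml", Line: 3}.String(), // "a.yaml:3"
}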

View file

@ -0,0 +1,33 @@
# In the original representation of environments, this caused an exponential
# blowup in time and allocation. With that representation, this took about 20
# seconds on my laptop and had a max RSS of ~12 GB. Big enough to be really
# noticeable, but not so big it's likely to crash a developer machine. With the
# better environment representation, it runs almost instantly and has an RSS of
# ~90 MB.
unify:
- !sum
- !sum [1, 2]
- !sum [3, 4]
- !sum [5, 6]
- !sum [7, 8]
- !sum [9, 10]
- !sum [11, 12]
- !sum [13, 14]
- !sum [15, 16]
- !sum [17, 18]
- !sum [19, 20]
- !sum [21, 22]
- !sum
- !sum [1, 2]
- !sum [3, 4]
- !sum [5, 6]
- !sum [7, 8]
- !sum [9, 10]
- !sum [11, 12]
- !sum [13, 14]
- !sum [15, 16]
- !sum [17, 18]
- !sum [19, 20]
- !sum [21, 22]
all:
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]

174
src/simd/_gen/unify/testdata/unify.yaml vendored Normal file
View file

@ -0,0 +1,174 @@
# Basic tests of unification
#
# Terminals
#
unify:
- _
- _
want:
_
---
unify:
- _
- test
want:
test
---
unify:
- test
- t?est
want:
test
---
unify:
- 1
- 1
want:
1
---
unify:
- test
- foo
want:
_|_
#
# Tuple
#
---
unify:
- [a, b]
- [a, b]
want:
[a, b]
---
unify:
- [a, _]
- [_, b]
want:
[a, b]
---
unify:
- ["ab?c", "de?f"]
- [ac, def]
want:
[ac, def]
#
# Repeats
#
---
unify:
- !repeat [a]
- [_]
want:
[a]
---
unify:
- !repeat [a]
- [_, _]
want:
[a, a]
---
unify:
- !repeat [a]
- [b]
want:
_|_
---
unify:
- !repeat [xy*]
- [x, xy, xyy]
want:
[x, xy, xyy]
---
unify:
- !repeat [xy*]
- !repeat ["xz?y*"]
- [x, xy, xyy]
want:
[x, xy, xyy]
---
unify:
- !repeat [!sum [a, b]]
- [a, b, a]
all:
- [a, b, a]
---
unify:
- !repeat [!sum [a, b]]
- !repeat [!sum [b, c]]
- [b, b, b]
all:
- [b, b, b]
---
unify:
- !repeat [!sum [a, b]]
- !repeat [!sum [b, c]]
- [a]
all: []
#
# Def
#
---
unify:
- {a: a, b: b}
- {a: a, b: b}
want:
{a: a, b: b}
---
unify:
- {a: a}
- {b: b}
want:
{a: a, b: b}
#
# Sum
#
---
unify:
- !sum [1, 2]
- !sum [2, 3]
all:
- 2
---
unify:
- !sum [{label: a, value: abc}, {label: b, value: def}]
- !sum [{value: "ab?c", extra: d}, {value: "def?", extra: g}]
all:
- {extra: d, label: a, value: abc}
- {extra: g, label: b, value: def}
---
# A sum of repeats must deal with different dynamically-created variables in
# each branch.
unify:
- !sum [!repeat [a], !repeat [b]]
- [a, a, a]
all:
- [a, a, a]
---
unify:
- !sum [!repeat [a], !repeat [b]]
- [a, a, b]
all: []
---
# Exercise sumEnvs with more than one result
unify:
- !sum
- [a|b, c|d]
- [e, g]
- [!sum [a, b, e, f], !sum [c, d, g, h]]
all:
- [a, c]
- [a, d]
- [b, c]
- [b, d]
- [e, g]

175
src/simd/_gen/unify/testdata/vars.yaml vendored Normal file
View file

@ -0,0 +1,175 @@
#
# Basic tests
#
name: "basic string"
unify:
- $x
- test
all:
- test
---
name: "basic tuple"
unify:
- [$x, $x]
- [test, test]
all:
- [test, test]
---
name: "three tuples"
unify:
- [$x, $x]
- [test, _]
- [_, test]
all:
- [test, test]
---
name: "basic def"
unify:
- {a: $x, b: $x}
- {a: test, b: test}
all:
- {a: test, b: test}
---
name: "three defs"
unify:
- {a: $x, b: $x}
- {a: test}
- {b: test}
all:
- {a: test, b: test}
#
# Bottom tests
#
---
name: "basic bottom"
unify:
- [$x, $x]
- [test, foo]
all: []
---
name: "three-way bottom"
unify:
- [$x, $x]
- [test, _]
- [_, foo]
all: []
#
# Basic sum tests
#
---
name: "basic sum"
unify:
- $x
- !sum [a, b]
all:
- a
- b
---
name: "sum of tuples"
unify:
- [$x]
- !sum [[a], [b]]
all:
- [a]
- [b]
---
name: "acausal sum"
unify:
- [_, !sum [a, b]]
- [$x, $x]
all:
- [a, a]
- [b, b]
#
# Transitivity tests
#
---
name: "transitivity"
unify:
- [_, _, _, test]
- [$x, $x, _, _]
- [ _, $x, $x, _]
- [ _, _, $x, $x]
all:
- [test, test, test, test]
#
# Multiple vars
#
---
name: "basic uncorrelated vars"
unify:
- - !sum [1, 2]
- !sum [3, 4]
- - $a
- $b
all:
- [1, 3]
- [1, 4]
- [2, 3]
- [2, 4]
---
name: "uncorrelated vars"
unify:
- - !sum [1, 2]
- !sum [3, 4]
- !sum [1, 2]
- - $a
- $b
- $a
all:
- [1, 3, 1]
- [1, 4, 1]
- [2, 3, 2]
- [2, 4, 2]
---
name: "entangled vars"
unify:
- - !sum [[1,2],[3,4]]
- !sum [[2,1],[3,4],[4,3]]
- - [$a, $b]
- [$b, $a]
all:
- - [1, 2]
- [2, 1]
- - [3, 4]
- [4, 3]
#
# End-to-end examples
#
---
name: "end-to-end"
unify:
- go: Add
in:
- go: $t
- go: $t
- in: !repeat
- !sum
- go: Int32x4
base: int
- go: Uint32x4
base: uint
all:
- go: Add
in:
- base: int
go: Int32x4
- base: int
go: Int32x4
- go: Add
in:
- base: uint
go: Uint32x4
- base: uint
go: Uint32x4

View file

@ -0,0 +1,168 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"io"
"strings"
"gopkg.in/yaml.v3"
)
// debugDotInHTML, if true, includes dot code for all graphs in the HTML. Useful
// for debugging the dot output itself.
const debugDotInHTML = false
var Debug struct {
// UnifyLog, if non-nil, receives a streaming text trace of unification.
UnifyLog io.Writer
// HTML, if non-nil, writes an HTML trace of unification to HTML.
HTML io.Writer
}
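// traceUnifyExample is an illustrative sketch (hand-written, not referenced
// elsewhere), mirroring what unify_test.go does on failure: point
// Debug.UnifyLog at a buffer around a call to Unify to capture a streaming
// text trace.
func traceUnifyExample(c1, c2 Closure) (string, error) {
	var buf strings.Builder
	Debug.UnifyLog = &buf
	defer func() { Debug.UnifyLog = nil }()
	_, err := Unify(c1, c2)
	return buf.String(), err
}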
type tracer struct {
logw io.Writer
enc yamlEncoder // Print consistent idents throughout
saveTree bool // if set, record tree; required for HTML output
path []string
node *traceTree
trees []*traceTree
}
type traceTree struct {
label string // Identifies this node as a child of parent
v, w *Value // Unification inputs
envIn envSet
res *Value // Unification result
env envSet
err error // or error
parent *traceTree
children []*traceTree
}
type tracerExit struct {
t *tracer
len int
node *traceTree
}
func (t *tracer) enter(pat string, vals ...any) tracerExit {
if t == nil {
return tracerExit{}
}
label := fmt.Sprintf(pat, vals...)
var p *traceTree
if t.saveTree {
p = t.node
if p != nil {
t.node = &traceTree{label: label, parent: p}
p.children = append(p.children, t.node)
}
}
t.path = append(t.path, label)
return tracerExit{t, len(t.path) - 1, p}
}
func (t *tracer) enterVar(id *ident, branch int) tracerExit {
if t == nil {
return tracerExit{}
}
// Use the tracer's ident printer
return t.enter("Var %s br %d", t.enc.idp.unique(id), branch)
}
func (te tracerExit) exit() {
if te.t == nil {
return
}
te.t.path = te.t.path[:te.len]
te.t.node = te.node
}
func indentf(prefix string, pat string, vals ...any) string {
s := fmt.Sprintf(pat, vals...)
if len(prefix) == 0 {
return s
}
if !strings.Contains(s, "\n") {
return prefix + s
}
indent := prefix
if strings.TrimLeft(prefix, " ") != "" {
// Prefix has non-space characters in it. Construct an all space-indent.
indent = strings.Repeat(" ", len(prefix))
}
return prefix + strings.ReplaceAll(s, "\n", "\n"+indent)
}
func yamlf(prefix string, node *yaml.Node) string {
b, err := yaml.Marshal(node)
if err != nil {
return fmt.Sprintf("<marshal failed: %s>", err)
}
return strings.TrimRight(indentf(prefix, "%s", b), " \n")
}
func (t *tracer) logf(pat string, vals ...any) {
if t == nil || t.logw == nil {
return
}
prefix := fmt.Sprintf("[%s] ", strings.Join(t.path, "/"))
s := indentf(prefix, pat, vals...)
s = strings.TrimRight(s, " \n")
fmt.Fprintf(t.logw, "%s\n", s)
}
func (t *tracer) traceUnify(v, w *Value, e envSet) {
if t == nil {
return
}
t.logf("Unify\n%s\nwith\n%s\nin\n%s",
yamlf(" ", t.enc.value(v)),
yamlf(" ", t.enc.value(w)),
yamlf(" ", t.enc.env(e)))
if t.saveTree {
if t.node == nil {
t.node = &traceTree{}
t.trees = append(t.trees, t.node)
}
t.node.v, t.node.w, t.node.envIn = v, w, e
}
}
func (t *tracer) traceDone(res *Value, e envSet, err error) {
if t == nil {
return
}
if err != nil {
t.logf("==> %s", err)
} else {
t.logf("==>\n%s", yamlf(" ", t.enc.closure(Closure{res, e})))
}
if t.saveTree {
node := t.node
if node == nil {
panic("popped top of trace stack")
}
node.res, node.err = res, err
node.env = e
}
}

View file

@ -0,0 +1,322 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package unify implements unification of structured values.
//
// A [Value] represents a possibly infinite set of concrete values, where a
// value is either a string ([String]), a tuple of values ([Tuple]), or a
// string-keyed map of values called a "def" ([Def]). These sets can be further
// constrained by variables ([Var]). A [Value] combined with bindings of
// variables is a [Closure].
//
// [Unify] finds a [Closure] that satisfies two or more other [Closure]s. This
// can be thought of as intersecting the sets represented by these Closures'
// values, or as the greatest lower bound/infimum of these Closures. If no such
// Closure exists, the result of unification is "bottom", or the empty set.
//
// # Examples
//
// The regular expression "a*" is the infinite set of strings of zero or more
// "a"s. "a*" can be unified with "a" or "aa" or "aaa", and the result is just
// "a", "aa", or "aaa", respectively. However, unifying "a*" with "b" fails
// because there are no values that satisfy both.
//
// Sums express sets directly. For example, !sum [a, b] is the set consisting of
// "a" and "b". Unifying this with !sum [b, c] results in just "b". This also
// makes it easy to demonstrate that unification isn't necessarily a single
// concrete value. For example, unifying !sum [a, b, c] with !sum [b, c, d]
// results in two concrete values: "b" and "c".
//
// The special value _ or "top" represents all possible values. Unifying _ with
// any value x results in x.
//
// Unifying composite values—tuples and defs—unifies their elements.
//
// The value [a*, aa] is an infinite set of tuples. If we unify that with the
// value [aaa, a*], the only possible value that satisfies both is [aaa, aa].
// Likewise, this is the intersection of the sets described by these two values.
//
// Defs are similar to tuples, but they are indexed by strings and don't have a
// fixed length. For example, {x: a, y: b} is a def with two fields. Any field
// not mentioned in a def is implicitly top. Thus, unifying this with {y: b, z:
// c} results in {x: a, y: b, z: c}.
//
// Variables constrain values. For example, the value [$x, $x] represents all
// tuples whose first and second values are the same, but doesn't otherwise
// constrain that value. Thus, this set includes [a, a] as well as [[b, c, d],
// [b, c, d]], but it doesn't include [a, b].
//
// Sums are internally implemented as fresh variables that are simultaneously
// bound to all values of the sum. That is !sum [a, b] is actually $var (where
// var is some fresh name), closed under the environment $var=a | $var=b.
package unify
import (
"errors"
"fmt"
"slices"
)
// Unify computes a Closure that satisfies each input Closure. If no such
// Closure exists, it returns bottom.
func Unify(closures ...Closure) (Closure, error) {
if len(closures) == 0 {
return Closure{topValue, topEnv}, nil
}
var trace *tracer
if Debug.UnifyLog != nil || Debug.HTML != nil {
trace = &tracer{
logw: Debug.UnifyLog,
saveTree: Debug.HTML != nil,
}
}
unified := closures[0]
for _, c := range closures[1:] {
var err error
uf := newUnifier()
uf.tracer = trace
e := crossEnvs(unified.env, c.env)
unified.val, unified.env, err = unified.val.unify(c.val, e, false, uf)
if Debug.HTML != nil {
uf.writeHTML(Debug.HTML)
}
if err != nil {
return Closure{}, err
}
}
return unified, nil
}
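// unifyExample is an illustrative sketch (hand-written, not referenced
// elsewhere) of the package comment above: unifying closures for
// !sum [a, b, c] and !sum [b, c, d] (for example, as read by Read or
// ReadFile) leaves exactly the concrete values "b" and "c".
func unifyExample(c1, c2 Closure) ([]*Value, error) {
	u, err := Unify(c1, c2)
	if err != nil {
		return nil, err
	}
	return slices.Collect(u.All()), nil
}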
type unifier struct {
*tracer
}
func newUnifier() *unifier {
return &unifier{}
}
// errDomains is a sentinel error used between unify and unify1 to indicate that
// unify1 could not unify the domains of the two values.
var errDomains = errors.New("cannot unify domains")
func (v *Value) unify(w *Value, e envSet, swap bool, uf *unifier) (*Value, envSet, error) {
if swap {
// Put the values in order. This just happens to be a handy choke-point
// to do this at.
v, w = w, v
}
uf.traceUnify(v, w, e)
d, e2, err := v.unify1(w, e, false, uf)
if err == errDomains {
// Try the other order.
d, e2, err = w.unify1(v, e, true, uf)
if err == errDomains {
// Okay, we really can't unify these.
err = fmt.Errorf("cannot unify %T (%s) and %T (%s): kind mismatch", v.Domain, v.PosString(), w.Domain, w.PosString())
}
}
if err != nil {
uf.traceDone(nil, envSet{}, err)
return nil, envSet{}, err
}
res := unified(d, v, w)
uf.traceDone(res, e2, nil)
if d == nil {
// Double check that a bottom Value also has a bottom env.
if !e2.isEmpty() {
panic("bottom Value has non-bottom environment")
}
}
return res, e2, nil
}
func (v *Value) unify1(w *Value, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
// TODO: If there's an error, attach position information to it.
vd, wd := v.Domain, w.Domain
// Bottom returns bottom, and eliminates all possible environments.
if vd == nil || wd == nil {
return nil, bottomEnv, nil
}
// Top always returns the other.
if _, ok := vd.(Top); ok {
return wd, e, nil
}
// Variables
if vd, ok := vd.(Var); ok {
return vd.unify(w, e, swap, uf)
}
// Composite values
if vd, ok := vd.(Def); ok {
if wd, ok := wd.(Def); ok {
return vd.unify(wd, e, swap, uf)
}
}
if vd, ok := vd.(Tuple); ok {
if wd, ok := wd.(Tuple); ok {
return vd.unify(wd, e, swap, uf)
}
}
// Scalar values
if vd, ok := vd.(String); ok {
if wd, ok := wd.(String); ok {
res := vd.unify(wd)
if res == nil {
e = bottomEnv
}
return res, e, nil
}
}
return nil, envSet{}, errDomains
}
func (d Def) unify(o Def, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
out := Def{fields: make(map[string]*Value)}
// Check keys of d against o.
for key, dv := range d.All() {
ov, ok := o.fields[key]
if !ok {
// ov is implicitly Top. Bypass unification.
out.fields[key] = dv
continue
}
exit := uf.enter("%s", key)
res, e2, err := dv.unify(ov, e, swap, uf)
exit.exit()
if err != nil {
return nil, envSet{}, err
} else if res.Domain == nil {
// No match.
return nil, bottomEnv, nil
}
out.fields[key] = res
e = e2
}
// Check keys of o that we didn't already check. These all implicitly match
// because we know the corresponding fields in d are all Top.
for key, dv := range o.All() {
if _, ok := d.fields[key]; !ok {
out.fields[key] = dv
}
}
return out, e, nil
}
func (v Tuple) unify(w Tuple, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
if v.repeat != nil && w.repeat != nil {
// Since we generate the content of these lazily, there's not much we
// can do but just stick them on a list to unify later.
return Tuple{repeat: concat(v.repeat, w.repeat)}, e, nil
}
// Expand any repeated tuples.
tuples := make([]Tuple, 0, 2)
if v.repeat == nil {
tuples = append(tuples, v)
} else {
v2, e2 := v.doRepeat(e, len(w.vs))
tuples = append(tuples, v2...)
e = e2
}
if w.repeat == nil {
tuples = append(tuples, w)
} else {
w2, e2 := w.doRepeat(e, len(v.vs))
tuples = append(tuples, w2...)
e = e2
}
// Now unify all of the tuples (usually this will be just 2 tuples)
out := tuples[0]
for _, t := range tuples[1:] {
if len(out.vs) != len(t.vs) {
uf.logf("tuple length mismatch")
return nil, bottomEnv, nil
}
zs := make([]*Value, len(out.vs))
for i, v1 := range out.vs {
exit := uf.enter("%d", i)
z, e2, err := v1.unify(t.vs[i], e, swap, uf)
exit.exit()
if err != nil {
return nil, envSet{}, err
} else if z.Domain == nil {
return nil, bottomEnv, nil
}
zs[i] = z
e = e2
}
out = Tuple{vs: zs}
}
return out, e, nil
}
// doRepeat expands a repeated tuple into fixed-length tuples of length n. The
// caller is expected to unify the returned tuples.
func (v Tuple) doRepeat(e envSet, n int) ([]Tuple, envSet) {
res := make([]Tuple, len(v.repeat))
for i, gen := range v.repeat {
res[i].vs = make([]*Value, n)
for j := range n {
res[i].vs[j], e = gen(e)
}
}
return res, e
}
// unify intersects the domains of two [String]s. If it can prove that this
// domain is empty, it returns nil (bottom).
//
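// For example, unifying the regex "a*" with the exact string "aa" returns
// the exact string, while unifying "a*" with "b" returns nil, since no
// string satisfies both.
//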
// TODO: Consider splitting literals and regexps into two domains.
func (v String) unify(w String) Domain {
// Unification is symmetric, so put them in order of string kind so we only
// have to deal with half the cases.
if v.kind > w.kind {
v, w = w, v
}
switch v.kind {
case stringRegex:
switch w.kind {
case stringRegex:
// Construct a match against all of the regexps
return String{kind: stringRegex, re: slices.Concat(v.re, w.re)}
case stringExact:
for _, re := range v.re {
if !re.MatchString(w.exact) {
return nil
}
}
return w
}
case stringExact:
if v.exact != w.exact {
return nil
}
return v
}
panic("bad string kind")
}
func concat[T any](s1, s2 []T) []T {
// Reuse s1 or s2 if possible.
if len(s1) == 0 {
return s2
}
return append(s1[:len(s1):len(s1)], s2...)
}

View file

@ -0,0 +1,154 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
func TestUnify(t *testing.T) {
paths, err := filepath.Glob("testdata/*")
if err != nil {
t.Fatal(err)
}
if len(paths) == 0 {
t.Fatal("no testdata found")
}
for _, path := range paths {
// Skip paths starting with _ so experimental files can be added.
base := filepath.Base(path)
if base[0] == '_' {
continue
}
if !strings.HasSuffix(base, ".yaml") {
t.Errorf("non-.yaml file in testdata: %s", base)
continue
}
base = strings.TrimSuffix(base, ".yaml")
t.Run(base, func(t *testing.T) {
testUnify(t, path)
})
}
}
func testUnify(t *testing.T, path string) {
f, err := os.Open(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
type testCase struct {
Skip bool
Name string
Unify []Closure
Want yaml.Node
All yaml.Node
}
dec := yaml.NewDecoder(f)
for i := 0; ; i++ {
var tc testCase
err := dec.Decode(&tc)
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
name := tc.Name
if name == "" {
name = fmt.Sprint(i)
}
t.Run(name, func(t *testing.T) {
if tc.Skip {
t.Skip("skip: true set in test case")
}
defer func() {
p := recover()
if p != nil || t.Failed() {
// Redo with a trace
//
// TODO: Use t.Output() in Go 1.25.
var buf bytes.Buffer
Debug.UnifyLog = &buf
func() {
defer func() {
// If the original unify panicked, the second one
// probably will, too. Ignore it and let the first panic
// bubble.
recover()
}()
Unify(tc.Unify...)
}()
Debug.UnifyLog = nil
t.Logf("Trace:\n%s", buf.String())
}
if p != nil {
panic(p)
}
}()
// Unify the test cases
//
// TODO: Try reordering the inputs also
c, err := Unify(tc.Unify...)
if err != nil {
// TODO: Tests of errors
t.Fatal(err)
}
// Encode the result back to YAML so we can check if it's structurally
// equal.
clean := func(val any) *yaml.Node {
var node yaml.Node
node.Encode(val)
for n := range allYamlNodes(&node) {
// Canonicalize the style. There may be other style flags we need to
// muck with.
n.Style &^= yaml.FlowStyle
n.HeadComment = ""
n.LineComment = ""
n.FootComment = ""
}
return &node
}
check := func(gotVal any, wantNode *yaml.Node) {
got, err := yaml.Marshal(clean(gotVal))
if err != nil {
t.Fatalf("Encoding Value back to yaml failed: %s", err)
}
want, err := yaml.Marshal(clean(wantNode))
if err != nil {
t.Fatalf("Encoding Want back to yaml failed: %s", err)
}
if !bytes.Equal(got, want) {
t.Errorf("%s:%d:\nwant:\n%sgot\n%s", f.Name(), wantNode.Line, want, got)
}
}
if tc.Want.Kind != 0 {
check(c.val, &tc.Want)
}
if tc.All.Kind != 0 {
fVal := slices.Collect(c.All())
check(fVal, &tc.All)
}
})
}
}

View file

@ -0,0 +1,167 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"iter"
"reflect"
)
// A Value represents a structured, non-deterministic value consisting of
// strings, tuples of Values, and string-keyed maps of Values. A
// non-deterministic Value will also contain variables, which are resolved via
// an environment as part of a [Closure].
//
// For debugging, a Value can also track the source position it was read from in
// an input file, and its provenance from other Values.
type Value struct {
Domain Domain
// A Value has either a pos or parents (or neither).
pos *Pos
parents *[2]*Value
}
var (
topValue = &Value{Domain: Top{}}
bottomValue = &Value{Domain: nil}
)
// NewValue returns a new [Value] with the given domain and no position
// information.
func NewValue(d Domain) *Value {
return &Value{Domain: d}
}
// NewValuePos returns a new [Value] with the given domain at position p.
func NewValuePos(d Domain, p Pos) *Value {
return &Value{Domain: d, pos: &p}
}
// newValueFrom returns a new [Value] with the given domain that copies the
// position information of p.
func newValueFrom(d Domain, p *Value) *Value {
return &Value{Domain: d, pos: p.pos, parents: p.parents}
}
func unified(d Domain, p1, p2 *Value) *Value {
return &Value{Domain: d, parents: &[2]*Value{p1, p2}}
}
func (v *Value) Pos() Pos {
if v.pos == nil {
return Pos{}
}
return *v.pos
}
func (v *Value) PosString() string {
var b []byte
for root := range v.Provenance() {
if len(b) > 0 {
b = append(b, ' ')
}
b, _ = root.pos.AppendText(b)
}
return string(b)
}
func (v *Value) WhyNotExact() string {
if v.Domain == nil {
return "v.Domain is nil"
}
return v.Domain.WhyNotExact()
}
func (v *Value) Exact() bool {
if v.Domain == nil {
return false
}
return v.Domain.Exact()
}
// Decode decodes v into a Go value.
//
// v must be exact, except that it can include Top. into must be a pointer.
// [Def]s are decoded into structs. [Tuple]s are decoded into slices. [String]s
// are decoded into strings or ints. Any field can itself be a pointer to one of
// these types. Top can be decoded into a pointer-typed field and will set the
// field to nil. Anything else will allocate a value if necessary.
//
// Any type may implement [Decoder], in which case its DecodeUnified method will
// be called instead of using the default decoding scheme.
func (v *Value) Decode(into any) error {
rv := reflect.ValueOf(into)
if rv.Kind() != reflect.Pointer {
return fmt.Errorf("cannot decode into non-pointer %T", into)
}
return decodeReflect(v, rv.Elem())
}
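// decodeExample is an illustrative sketch (hand-written, not referenced
// elsewhere). It assumes v is an exact def like {a: x, b: [1, 2]} and that
// the def keys map to the exported struct field names (the exact matching
// rule is not shown here): the def decodes into a struct, the tuple into a
// slice, and the strings into a string and ints.
func decodeExample(v *Value) error {
	var out struct {
		A string
		B []int
	}
	return v.Decode(&out)
}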
func decodeReflect(v *Value, rv reflect.Value) error {
var ptr reflect.Value
if rv.Kind() == reflect.Pointer {
if rv.IsNil() {
// Transparently allocate through pointers, *except* for Top, which
// wants to set the pointer to nil.
//
// TODO: Drop this condition if I switch to an explicit Optional[T]
// or move the Top logic into Def.
if _, ok := v.Domain.(Top); !ok {
// Allocate the value to fill in, but don't actually store it in
// the pointer until we successfully decode.
ptr = rv
rv = reflect.New(rv.Type().Elem()).Elem()
}
} else {
rv = rv.Elem()
}
}
var err error
if reflect.PointerTo(rv.Type()).Implements(decoderType) {
// Use the custom decoder.
err = rv.Addr().Interface().(Decoder).DecodeUnified(v)
} else {
err = v.Domain.decode(rv)
}
if err == nil && ptr.IsValid() {
ptr.Set(rv.Addr())
}
return err
}
// Decoder can be implemented by types as a custom implementation of [Decode]
// for that type.
type Decoder interface {
DecodeUnified(v *Value) error
}
var decoderType = reflect.TypeOf((*Decoder)(nil)).Elem()
// Provenance iterates over all of the source Values that have contributed to
// this Value.
func (v *Value) Provenance() iter.Seq[*Value] {
return func(yield func(*Value) bool) {
var rec func(d *Value) bool
rec = func(d *Value) bool {
if d.pos != nil {
if !yield(d) {
return false
}
}
if d.parents != nil {
for _, p := range d.parents {
if !rec(p) {
return false
}
}
}
return true
}
rec(v)
}
}

View file

@ -0,0 +1,50 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"reflect"
"slices"
"testing"
)
func ExampleClosure_All_tuple() {
v := mustParse(`
- !sum [1, 2]
- !sum [3, 4]
`)
printYaml(slices.Collect(v.All()))
// Output:
// - [1, 3]
// - [1, 4]
// - [2, 3]
// - [2, 4]
}
func ExampleClosure_All_def() {
v := mustParse(`
a: !sum [1, 2]
b: !sum [3, 4]
c: 5
`)
printYaml(slices.Collect(v.All()))
// Output:
// - {a: 1, b: 3, c: 5}
// - {a: 1, b: 4, c: 5}
// - {a: 2, b: 3, c: 5}
// - {a: 2, b: 4, c: 5}
}
func checkDecode[T any](t *testing.T, got *Value, want T) {
var gotT T
if err := got.Decode(&gotT); err != nil {
t.Fatalf("Decode failed: %v", err)
}
if !reflect.DeepEqual(&gotT, &want) {
t.Fatalf("got:\n%s\nwant:\n%s", prettyYaml(gotT), prettyYaml(want))
}
}

619
src/simd/_gen/unify/yaml.go Normal file
View file

@ -0,0 +1,619 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
"gopkg.in/yaml.v3"
)
// ReadOpts provides options to [Read] and related functions. The zero value is
// the default options.
type ReadOpts struct {
// FS, if non-nil, is the file system from which to resolve !import file
// names.
FS fs.FS
}
// Read reads a [Closure] in YAML format from r, using path for error messages.
//
// It maps YAML nodes into terminal Values as follows:
//
// - "_" or !top _ is the top value ([Top]).
//
// - "_|_" or !bottom _ is the bottom value. This is an error during
// unmarshaling, but can appear in marshaled values.
//
// - "$<name>" or !var <name> is a variable ([Var]). Everywhere the same name
// appears within a single unmarshal operation, it is mapped to the same
// variable. Different unmarshal operations get different variables, even if
// they have the same string name.
//
// - !regex "x" is a regular expression ([String]), as is any string that
// doesn't match "_", "_|_", or "$...". Regular expressions are implicitly
// anchored at the beginning and end. If the string doesn't contain any
// meta-characters (that is, it's a "literal" regular expression), then it's
// treated as an exact string.
//
// - !string "x", or any int, float, bool, or binary value is an exact string
// ([String]).
//
// - !regex [x, y, ...] is an intersection of regular expressions ([String]).
//
// It maps YAML nodes into non-terminal Values as follows:
//
// - Sequence nodes like [x, y, z] are tuples ([Tuple]).
//
// - !repeat [x] is a repeated tuple ([Tuple]), which is 0 or more instances of
// x. There must be exactly one element in the list.
//
// - Mapping nodes like {a: x, b: y} are defs ([Def]). Any fields not listed are
// implicitly top.
//
// - !sum [x, y, z] is a sum of its children. This can be thought of as a union
// of the values x, y, and z, or as a non-deterministic choice between x, y, and
// z. If a variable appears both inside the sum and outside of it, only the
// non-deterministic choice view really works. The unifier does not directly
// implement sums; instead, this is decoded as a fresh variable that's
// simultaneously bound to x, y, and z.
//
// - !import glob is like a !sum, but its children are read from all files
// matching the given glob pattern, which is interpreted relative to the current
// file path. Each file gets its own variable scope.
func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) {
dec := yamlDecoder{opts: opts, path: path, env: topEnv}
v, err := dec.read(r)
if err != nil {
return Closure{}, err
}
return dec.close(v), nil
}
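// readExample is an illustrative sketch (hand-written, not referenced
// elsewhere) of the mapping above: the document [a, !sum [b, c]] reads as a
// tuple whose second element is a fresh variable bound to both "b" and "c"
// in the Closure's environment.
func readExample() (Closure, error) {
	return Read(strings.NewReader("[a, !sum [b, c]]"), "<example>", ReadOpts{})
}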
// ReadFile reads a [Closure] in YAML format from a file.
//
// The file must consist of a single YAML document.
//
// If opts.FS is not set, this sets it to a FS rooted at path's directory.
//
// See [Read] for details.
func ReadFile(path string, opts ReadOpts) (Closure, error) {
f, err := os.Open(path)
if err != nil {
return Closure{}, err
}
defer f.Close()
if opts.FS == nil {
opts.FS = os.DirFS(filepath.Dir(path))
}
return Read(f, path, opts)
}
// UnmarshalYAML implements [yaml.Unmarshaler].
//
// Since there is no way to pass [ReadOpts] to this function, it assumes default
// options.
func (c *Closure) UnmarshalYAML(node *yaml.Node) error {
dec := yamlDecoder{path: "<yaml.Node>", env: topEnv}
v, err := dec.root(node)
if err != nil {
return err
}
*c = dec.close(v)
return nil
}
type yamlDecoder struct {
opts ReadOpts
path string
vars map[string]*ident
nSums int
env envSet
}
func (dec *yamlDecoder) read(r io.Reader) (*Value, error) {
n, err := readOneNode(r)
if err != nil {
return nil, fmt.Errorf("%s: %w", dec.path, err)
}
// Decode YAML node to a Value
v, err := dec.root(n)
if err != nil {
return nil, fmt.Errorf("%s: %w", dec.path, err)
}
return v, nil
}
// readOneNode reads a single YAML document from r and returns an error if there
// are more documents in r.
func readOneNode(r io.Reader) (*yaml.Node, error) {
yd := yaml.NewDecoder(r)
// Decode as a YAML node
var node yaml.Node
if err := yd.Decode(&node); err != nil {
return nil, err
}
np := &node
if np.Kind == yaml.DocumentNode {
np = node.Content[0]
}
// Ensure there are no more YAML docs in this file
if err := yd.Decode(nil); err == nil {
return nil, fmt.Errorf("must not contain multiple documents")
} else if err != io.EOF {
return nil, err
}
return np, nil
}
// root parses the root of a file.
func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) {
// Prepare for variable name resolution in this file. This may be a nested
// root, so restore the current values when we're done.
oldVars, oldNSums := dec.vars, dec.nSums
defer func() {
dec.vars, dec.nSums = oldVars, oldNSums
}()
dec.vars = make(map[string]*ident, 0)
dec.nSums = 0
return dec.value(node)
}
// close wraps a decoded [Value] into a [Closure].
func (dec *yamlDecoder) close(v *Value) Closure {
return Closure{v, dec.env}
}
func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) {
pos := &Pos{Path: dec.path, Line: node.Line}
// Resolve alias nodes.
if node.Kind == yaml.AliasNode {
node = node.Alias
}
mk := func(d Domain) (*Value, error) {
v := &Value{Domain: d, pos: pos}
return v, nil
}
mk2 := func(d Domain, err error) (*Value, error) {
if err != nil {
return nil, err
}
return mk(d)
}
// is tests the kind and long tag of node.
is := func(kind yaml.Kind, tag string) bool {
return node.Kind == kind && node.LongTag() == tag
}
isExact := func() bool {
if node.Kind != yaml.ScalarNode {
return false
}
// We treat any string-ish YAML node as a string.
switch node.LongTag() {
case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary":
return true
}
return false
}
// !!str nodes provide a short-hand syntax for several leaf domains that are
// also available under explicit tags. To simplify checking below, we set
// strVal to non-"" only for !!str nodes.
strVal := ""
isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str")
if isStr {
strVal = node.Value
}
switch {
case is(yaml.ScalarNode, "!var"):
strVal = "$" + node.Value
fallthrough
case strings.HasPrefix(strVal, "$"):
id, ok := dec.vars[strVal]
if !ok {
// We encode different idents with the same string name by adding a
// #N suffix. Strip that off so it doesn't accumulate. This isn't
// meant to be used in user-written input, though nothing stops that.
name, _, _ := strings.Cut(strVal, "#")
id = &ident{name: name}
dec.vars[strVal] = id
dec.env = dec.env.bind(id, topValue)
}
return mk(Var{id: id})
case strVal == "_" || is(yaml.ScalarNode, "!top"):
return mk(Top{})
case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"):
return nil, errors.New("found bottom")
case isExact():
val := node.Value
return mk(NewStringExact(val))
case isStr || is(yaml.ScalarNode, "!regex"):
// Any other string we treat as a regex. This will produce an exact
// string anyway if the regex is literal.
val := node.Value
return mk2(NewStringRegex(val))
case is(yaml.SequenceNode, "!regex"):
var vals []string
if err := node.Decode(&vals); err != nil {
return nil, err
}
return mk2(NewStringRegex(vals...))
case is(yaml.MappingNode, "tag:yaml.org,2002:map"):
var db DefBuilder
for i := 0; i < len(node.Content); i += 2 {
key := node.Content[i]
if key.Kind != yaml.ScalarNode {
return nil, fmt.Errorf("non-scalar key %q", key.Value)
}
val, err := dec.value(node.Content[i+1])
if err != nil {
return nil, err
}
db.Add(key.Value, val)
}
return mk(db.Build())
case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"):
elts := node.Content
vs := make([]*Value, 0, len(elts))
for _, elt := range elts {
v, err := dec.value(elt)
if err != nil {
return nil, err
}
vs = append(vs, v)
}
return mk(NewTuple(vs...))
case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"):
// !repeat must have one child. !repeat-unify is used internally for
// delayed unification and is the same, except that it is allowed to have
// more than one child.
if node.LongTag() == "!repeat" && len(node.Content) != 1 {
return nil, fmt.Errorf("!repeat must have exactly one child")
}
// Decode the children to make sure they're well-formed, but otherwise
// discard that decoding and do it again every time we need a new
// element.
var gen []func(e envSet) (*Value, envSet)
origEnv := dec.env
elts := node.Content
for i, elt := range elts {
_, err := dec.value(elt)
if err != nil {
return nil, err
}
// Undo any effects on the environment. We *do* keep any named
// variables that were added to the vars map in case they were
// introduced within the element.
dec.env = origEnv
// Add a generator function
gen = append(gen, func(e envSet) (*Value, envSet) {
dec.env = e
// TODO: If this is in a sum, this tends to generate a ton of
// fresh variables that are different on each branch of the
// parent sum. Does it make sense to hold on to the i'th value
// of the tuple after we've generated it?
v, err := dec.value(elts[i])
if err != nil {
// It worked the first time, so this really shouldn't happen.
panic("decoding repeat element failed")
}
return v, dec.env
})
}
return mk(NewRepeat(gen...))
case is(yaml.SequenceNode, "!sum"):
vs := make([]*Value, 0, len(node.Content))
for _, elt := range node.Content {
v, err := dec.value(elt)
if err != nil {
return nil, err
}
vs = append(vs, v)
}
if len(vs) == 1 {
return vs[0], nil
}
// A sum is implemented as a fresh variable that's simultaneously bound
// to each of the descendants.
id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)}
dec.nSums++
dec.env = dec.env.bind(id, vs...)
return mk(Var{id: id})
case is(yaml.ScalarNode, "!import"):
if dec.opts.FS == nil {
return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)")
}
pat := node.Value
if !fs.ValidPath(pat) {
// This will result in Glob returning no results. Give a more useful
// error message for this case.
return nil, fmt.Errorf("!import path must not contain '.' or '..'")
}
ms, err := fs.Glob(dec.opts.FS, pat)
if err != nil {
return nil, fmt.Errorf("resolving !import: %w", err)
}
if len(ms) == 0 {
return nil, fmt.Errorf("!import did not match any files")
}
// Parse each file
vs := make([]*Value, 0, len(ms))
for _, m := range ms {
v, err := dec.import1(m)
if err != nil {
return nil, err
}
vs = append(vs, v)
}
// Create a sum.
if len(vs) == 1 {
return vs[0], nil
}
id := &ident{name: "import"}
dec.env = dec.env.bind(id, vs...)
return mk(Var{id: id})
}
return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag)
}
func (dec *yamlDecoder) import1(path string) (*Value, error) {
// Make sure we can open the path first.
f, err := dec.opts.FS.Open(path)
if err != nil {
return nil, fmt.Errorf("!import failed: %w", err)
}
defer f.Close()
// Prepare the enter path.
oldFS, oldPath := dec.opts.FS, dec.path
defer func() {
dec.opts.FS, dec.path = oldFS, oldPath
}()
// Enter path, which is relative to the current path's directory.
newPath := filepath.Join(filepath.Dir(dec.path), path)
subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path))
if err != nil {
return nil, err
}
dec.opts.FS, dec.path = subFS, newPath
// Parse the file.
return dec.read(f)
}
type yamlEncoder struct {
idp identPrinter
e envSet // We track the environment for !repeat nodes.
}
// TODO: Switch some Value marshaling to Closure?
func (c Closure) MarshalYAML() (any, error) {
// TODO: If the environment is trivial, just marshal the value.
enc := &yamlEncoder{}
return enc.closure(c), nil
}
func (c Closure) String() string {
b, err := yaml.Marshal(c)
if err != nil {
return fmt.Sprintf("marshal failed: %s", err)
}
return string(b)
}
func (v *Value) MarshalYAML() (any, error) {
enc := &yamlEncoder{}
return enc.value(v), nil
}
func (v *Value) String() string {
b, err := yaml.Marshal(v)
if err != nil {
return fmt.Sprintf("marshal failed: %s", err)
}
return string(b)
}
func (enc *yamlEncoder) closure(c Closure) *yaml.Node {
enc.e = c.env
var n yaml.Node
n.Kind = yaml.MappingNode
n.Tag = "!closure"
n.Content = make([]*yaml.Node, 4)
n.Content[0] = new(yaml.Node)
n.Content[0].SetString("env")
n.Content[2] = new(yaml.Node)
n.Content[2].SetString("in")
n.Content[3] = enc.value(c.val)
// Fill in the env after we've written the value in case value encoding
// affects the env.
n.Content[1] = enc.env(enc.e)
enc.e = envSet{} // Allow GC'ing the env
return &n
}
func (enc *yamlEncoder) env(e envSet) *yaml.Node {
var encode func(e *envExpr) *yaml.Node
encode = func(e *envExpr) *yaml.Node {
var n yaml.Node
switch e.kind {
default:
panic("bad kind")
case envZero:
n.SetString("0")
case envUnit:
n.SetString("1")
case envBinding:
var id yaml.Node
id.SetString(enc.idp.unique(e.id))
n.Kind = yaml.MappingNode
n.Content = []*yaml.Node{&id, enc.value(e.val)}
case envProduct, envSum:
n.Kind = yaml.SequenceNode
if e.kind == envProduct {
n.Tag = "!product"
} else {
n.Tag = "!sum"
}
for _, e2 := range e.operands {
n.Content = append(n.Content, encode(e2))
}
}
return &n
}
return encode(e.root)
}
var yamlIntRe = regexp.MustCompile(`^-?[0-9]+$`)
func (enc *yamlEncoder) value(v *Value) *yaml.Node {
var n yaml.Node
switch d := v.Domain.(type) {
case nil:
// Not allowed by unmarshaler, but useful for understanding when
// something goes horribly wrong.
//
// TODO: We might be able to track useful provenance for this, which
// would really help with debugging unexpected bottoms.
n.SetString("_|_")
return &n
case Top:
n.SetString("_")
return &n
case Def:
n.Kind = yaml.MappingNode
for k, elt := range d.All() {
var kn yaml.Node
kn.SetString(k)
n.Content = append(n.Content, &kn, enc.value(elt))
}
n.HeadComment = v.PosString()
return &n
case Tuple:
n.Kind = yaml.SequenceNode
if d.repeat == nil {
for _, elt := range d.vs {
n.Content = append(n.Content, enc.value(elt))
}
} else {
if len(d.repeat) == 1 {
n.Tag = "!repeat"
} else {
n.Tag = "!repeat-unify"
}
// TODO: I'm not positive this will round-trip everything correctly.
for _, gen := range d.repeat {
v, e := gen(enc.e)
enc.e = e
n.Content = append(n.Content, enc.value(v))
}
}
return &n
case String:
switch d.kind {
case stringExact:
n.SetString(d.exact)
switch {
// Make this into a "nice" !!int node if I can.
case yamlIntRe.MatchString(d.exact):
n.Tag = "tag:yaml.org,2002:int"
// Or a "nice" !!bool node.
case d.exact == "false" || d.exact == "true":
n.Tag = "tag:yaml.org,2002:bool"
// If this doesn't require escaping, leave it as a str node to avoid
// the annoying YAML tags. Otherwise, mark it as an exact string.
// Alternatively, we could always emit a str node with regexp
// quoting.
case d.exact != regexp.QuoteMeta(d.exact):
n.Tag = "!string"
}
return &n
case stringRegex:
o := make([]string, 0, 1)
for _, re := range d.re {
s := re.String()
s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`)
o = append(o, s)
}
if len(o) == 1 {
n.SetString(o[0])
return &n
}
n.Encode(o)
n.Tag = "!regex"
return &n
}
panic("bad String kind")
case Var:
// TODO: If Var only appears once in the whole Value and is independent
// in the environment (part of a term that is only over Var), then emit
// this as a !sum instead.
if false {
var vs []*Value // TODO: Get values of this var.
if len(vs) == 1 {
return enc.value(vs[0])
}
n.Kind = yaml.SequenceNode
n.Tag = "!sum"
for _, elt := range vs {
n.Content = append(n.Content, enc.value(elt))
}
return &n
}
n.SetString(enc.idp.unique(d.id))
if !strings.HasPrefix(d.id.name, "$") {
n.Tag = "!var"
}
return &n
}
panic(fmt.Sprintf("unknown domain type %T", v.Domain))
}

View file

@ -0,0 +1,202 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"bytes"
"fmt"
"iter"
"log"
"strings"
"testing"
"testing/fstest"
"gopkg.in/yaml.v3"
)
func mustParse(expr string) Closure {
var c Closure
if err := yaml.Unmarshal([]byte(expr), &c); err != nil {
panic(err)
}
return c
}
func oneValue(t *testing.T, c Closure) *Value {
t.Helper()
var v *Value
var i int
for v = range c.All() {
i++
}
if i != 1 {
t.Fatalf("expected 1 value, got %d", i)
}
return v
}
func printYaml(val any) {
fmt.Println(prettyYaml(val))
}
func prettyYaml(val any) string {
b, err := yaml.Marshal(val)
if err != nil {
panic(err)
}
var node yaml.Node
if err := yaml.Unmarshal(b, &node); err != nil {
panic(err)
}
// Map lines to start offsets. We'll use this to figure out when nodes are
// "small" and should use inline style.
lines := []int{-1, 0}
for pos := 0; pos < len(b); {
next := bytes.IndexByte(b[pos:], '\n')
if next == -1 {
break
}
pos += next + 1
lines = append(lines, pos)
}
lines = append(lines, len(b))
// Strip comments and switch small nodes to inline style
cleanYaml(&node, lines, len(b))
b, err = yaml.Marshal(&node)
if err != nil {
panic(err)
}
return string(b)
}
func cleanYaml(node *yaml.Node, lines []int, endPos int) {
node.HeadComment = ""
node.FootComment = ""
node.LineComment = ""
for i, n2 := range node.Content {
end2 := endPos
if i < len(node.Content)-1 {
end2 = lines[node.Content[i+1].Line]
}
cleanYaml(n2, lines, end2)
}
// Use inline style?
switch node.Kind {
case yaml.MappingNode, yaml.SequenceNode:
if endPos-lines[node.Line] < 40 {
node.Style = yaml.FlowStyle
}
}
}
func allYamlNodes(n *yaml.Node) iter.Seq[*yaml.Node] {
return func(yield func(*yaml.Node) bool) {
if !yield(n) {
return
}
for _, n2 := range n.Content {
for n3 := range allYamlNodes(n2) {
if !yield(n3) {
return
}
}
}
}
}
func TestRoundTripString(t *testing.T) {
// Check that we can round-trip a string with regexp meta-characters in it.
const y = `!string test*`
t.Logf("input:\n%s", y)
v1 := oneValue(t, mustParse(y))
var buf1 strings.Builder
enc := yaml.NewEncoder(&buf1)
if err := enc.Encode(v1); err != nil {
log.Fatal(err)
}
enc.Close()
t.Logf("after parse 1:\n%s", buf1.String())
v2 := oneValue(t, mustParse(buf1.String()))
var buf2 strings.Builder
enc = yaml.NewEncoder(&buf2)
if err := enc.Encode(v2); err != nil {
log.Fatal(err)
}
enc.Close()
t.Logf("after parse 2:\n%s", buf2.String())
if buf1.String() != buf2.String() {
t.Fatal("parse 1 and parse 2 differ")
}
}
func TestEmptyString(t *testing.T) {
// Regression test. Make sure an empty string is parsed as an exact string,
// not a regexp.
const y = `""`
t.Logf("input:\n%s", y)
v1 := oneValue(t, mustParse(y))
if !v1.Exact() {
t.Fatal("expected exact string")
}
}
func TestImport(t *testing.T) {
// Test a basic import
main := strings.NewReader("!import x/y.yaml")
fs := fstest.MapFS{
// Test a glob import with a relative path
"x/y.yaml": {Data: []byte("!import y/*.yaml")},
"x/y/z.yaml": {Data: []byte("42")},
}
cl, err := Read(main, "x.yaml", ReadOpts{FS: fs})
if err != nil {
t.Fatal(err)
}
x := 42
checkDecode(t, oneValue(t, cl), &x)
}
func TestImportEscape(t *testing.T) {
// Make sure an import can't escape its subdirectory.
main := strings.NewReader("!import x/y.yaml")
fs := fstest.MapFS{
"x/y.yaml": {Data: []byte("!import ../y/*.yaml")},
"y/z.yaml": {Data: []byte("42")},
}
_, err := Read(main, "x.yaml", ReadOpts{FS: fs})
if err == nil {
t.Fatal("relative !import should have failed")
}
if !strings.Contains(err.Error(), "must not contain") {
t.Fatalf("unexpected error %v", err)
}
}
func TestImportScope(t *testing.T) {
// Test that imports have different variable scopes.
main := strings.NewReader("[!import y.yaml, !import y.yaml]")
fs := fstest.MapFS{
"y.yaml": {Data: []byte("$v")},
}
cl1, err := Read(main, "x.yaml", ReadOpts{FS: fs})
if err != nil {
t.Fatal(err)
}
cl2 := mustParse("[1, 2]")
res, err := Unify(cl1, cl2)
if err != nil {
t.Fatal(err)
}
checkDecode(t, oneValue(t, res), []int{1, 2})
}