mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
[dev.simd] simd/_gen: migrate simdgen from x/arch
This moves the simdgen tool and its supporting unify package from golang.org/x/arch/internal as of CL 695619 to simd/_gen in the main repo. The simdgen tool was started in x/arch to live next to xeddata and a few other assembler generators that already lived there. However, as we've been developing simdgen, we've discovered that there's a tremendous amount of process friction coordinating commits to x/arch with the corresponding generated files in the main repo. Many of the existing generators in x/arch were started before modules existed. In GOPATH world, it was impractical for them to live in the main repo because they have dependencies that are not allowed in the main repo. However, now that we have modules and can use small submodules in the main repo, we can isolate these dependencies to just the generators, making it practical for them to live in the main repo. This commit was generated by the following script: # Checks set -e if [[ ! -d src/simd ]]; then echo >&2 "$PWD is not the root of the main repo on dev.simd" exit 1 fi if [[ -z "$XEDDATA" ]]; then echo >&2 "Must set \$XEDDATA" exit 1 fi which go >/dev/null # Move simdgen from x/arch xarch=$(mktemp -d) git clone https://go.googlesource.com/arch $xarch xarchCL=$(git -C $xarch log -1 --format=%b | awk -F/ '/^Reviewed-on:/ {print $NF}') echo >&2 "x/arch CL: $xarchCL" mv $xarch/internal src/simd/_gen sed --in-place s,golang.org/x/arch/internal/,simd/_gen/, src/simd/_gen/*/*.go # Create self-contained module cat > src/simd/_gen/go.mod <<EOF module simd/_gen go 1.24 EOF cd src/simd/_gen go mod tidy git add . git gofmt # Regenerate file go run -C simdgen . -xedPath $XEDDATA -o godefs -goroot $(go env GOROOT) go.yaml types.yaml categories.yaml go run -C ../../cmd/compile/internal/ssa/_gen . Change-Id: I56dd8473e913a9eb1978d9b3b3518ed632972f6f Reviewed-on: https://go-review.googlesource.com/c/go/+/695975 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
parent
257c1356ec
commit
b7c8698549
60 changed files with 9083 additions and 0 deletions
8
src/simd/_gen/go.mod
Normal file
8
src/simd/_gen/go.mod
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
module simd/_gen
|
||||||
|
|
||||||
|
go 1.24
|
||||||
|
|
||||||
|
require (
|
||||||
|
golang.org/x/arch v0.20.0
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
)
|
||||||
6
src/simd/_gen/go.sum
Normal file
6
src/simd/_gen/go.sum
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
|
||||||
|
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
3
src/simd/_gen/simdgen/.gitignore
vendored
Normal file
3
src/simd/_gen/simdgen/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
testdata/*
|
||||||
|
.gemini/*
|
||||||
|
.gemini*
|
||||||
107
src/simd/_gen/simdgen/asm.yaml.toy
Normal file
107
src/simd/_gen/simdgen/asm.yaml.toy
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Hand-written toy input like -xedPath would generate.
|
||||||
|
# This input can be substituted for -xedPath.
|
||||||
|
!sum
|
||||||
|
- asm: ADDPS
|
||||||
|
goarch: amd64
|
||||||
|
feature: "SSE2"
|
||||||
|
in:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: float
|
||||||
|
elemBits: 32
|
||||||
|
bits: 128
|
||||||
|
- asmPos: 1
|
||||||
|
class: vreg
|
||||||
|
base: float
|
||||||
|
elemBits: 32
|
||||||
|
bits: 128
|
||||||
|
out:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: float
|
||||||
|
elemBits: 32
|
||||||
|
bits: 128
|
||||||
|
|
||||||
|
- asm: ADDPD
|
||||||
|
goarch: amd64
|
||||||
|
feature: "SSE2"
|
||||||
|
in:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: float
|
||||||
|
elemBits: 64
|
||||||
|
bits: 128
|
||||||
|
- asmPos: 1
|
||||||
|
class: vreg
|
||||||
|
base: float
|
||||||
|
elemBits: 64
|
||||||
|
bits: 128
|
||||||
|
out:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: float
|
||||||
|
elemBits: 64
|
||||||
|
bits: 128
|
||||||
|
|
||||||
|
- asm: PADDB
|
||||||
|
goarch: amd64
|
||||||
|
feature: "SSE2"
|
||||||
|
in:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 32
|
||||||
|
bits: 128
|
||||||
|
- asmPos: 1
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 32
|
||||||
|
bits: 128
|
||||||
|
out:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 32
|
||||||
|
bits: 128
|
||||||
|
|
||||||
|
- asm: VPADDB
|
||||||
|
goarch: amd64
|
||||||
|
feature: "AVX"
|
||||||
|
in:
|
||||||
|
- asmPos: 1
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 8
|
||||||
|
bits: 128
|
||||||
|
- asmPos: 2
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 8
|
||||||
|
bits: 128
|
||||||
|
out:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 8
|
||||||
|
bits: 128
|
||||||
|
|
||||||
|
- asm: VPADDB
|
||||||
|
goarch: amd64
|
||||||
|
feature: "AVX2"
|
||||||
|
in:
|
||||||
|
- asmPos: 1
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 8
|
||||||
|
bits: 256
|
||||||
|
- asmPos: 2
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 8
|
||||||
|
bits: 256
|
||||||
|
out:
|
||||||
|
- asmPos: 0
|
||||||
|
class: vreg
|
||||||
|
base: int|uint
|
||||||
|
elemBits: 8
|
||||||
|
bits: 256
|
||||||
1
src/simd/_gen/simdgen/categories.yaml
Normal file
1
src/simd/_gen/simdgen/categories.yaml
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
!import ops/*/categories.yaml
|
||||||
33
src/simd/_gen/simdgen/etetest.sh
Executable file
33
src/simd/_gen/simdgen/etetest.sh
Executable file
|
|
@ -0,0 +1,33 @@
|
||||||
|
#!/bin/bash -x
|
||||||
|
|
||||||
|
cat <<\\EOF
|
||||||
|
|
||||||
|
This is an end-to-end test of Go SIMD. It checks out a fresh Go
|
||||||
|
repository from the go.simd branch, then generates the SIMD input
|
||||||
|
files and runs simdgen writing into the fresh repository.
|
||||||
|
|
||||||
|
After that it generates the modified ssa pattern matching files, then
|
||||||
|
builds the compiler.
|
||||||
|
|
||||||
|
\EOF
|
||||||
|
|
||||||
|
rm -rf go-test
|
||||||
|
git clone https://go.googlesource.com/go -b dev.simd go-test
|
||||||
|
go run . -xedPath xeddata -o godefs -goroot ./go-test go.yaml types.yaml categories.yaml
|
||||||
|
(cd go-test/src/cmd/compile/internal/ssa/_gen ; go run *.go )
|
||||||
|
(cd go-test/src ; GOEXPERIMENT=simd ./make.bash )
|
||||||
|
(cd go-test/bin; b=`pwd` ; cd ../src/simd/testdata; GOARCH=amd64 $b/go run .)
|
||||||
|
(cd go-test/bin; b=`pwd` ; cd ../src ;
|
||||||
|
GOEXPERIMENT=simd GOARCH=amd64 $b/go test -v simd
|
||||||
|
GOEXPERIMENT=simd $b/go test go/doc
|
||||||
|
GOEXPERIMENT=simd $b/go test go/build
|
||||||
|
GOEXPERIMENT=simd $b/go test cmd/api -v -check
|
||||||
|
$b/go test go/doc
|
||||||
|
$b/go test go/build
|
||||||
|
$b/go test cmd/api -v -check
|
||||||
|
|
||||||
|
$b/go test cmd/compile/internal/ssagen -simd=0
|
||||||
|
GOEXPERIMENT=simd $b/go test cmd/compile/internal/ssagen -simd=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# next, add some tests of SIMD itself
|
||||||
70
src/simd/_gen/simdgen/gen_simdGenericOps.go
Normal file
70
src/simd/_gen/simdgen/gen_simdGenericOps.go
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
const simdGenericOpsTmpl = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
func simdGenericOps() []opData {
|
||||||
|
return []opData{
|
||||||
|
{{- range .Ops }}
|
||||||
|
{name: "{{.OpName}}", argLength: {{.OpInLen}}, commutative: {{.Comm}}},
|
||||||
|
{{- end }}
|
||||||
|
{{- range .OpsImm }}
|
||||||
|
{name: "{{.OpName}}", argLength: {{.OpInLen}}, commutative: {{.Comm}}, aux: "UInt8"},
|
||||||
|
{{- end }}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
// writeSIMDGenericOps generates the generic ops and writes it to simdAMD64ops.go
|
||||||
|
// within the specified directory.
|
||||||
|
func writeSIMDGenericOps(ops []Operation) *bytes.Buffer {
|
||||||
|
t := templateOf(simdGenericOpsTmpl, "simdgenericOps")
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(generatedHeader)
|
||||||
|
|
||||||
|
type genericOpsData struct {
|
||||||
|
OpName string
|
||||||
|
OpInLen int
|
||||||
|
Comm bool
|
||||||
|
}
|
||||||
|
type opData struct {
|
||||||
|
Ops []genericOpsData
|
||||||
|
OpsImm []genericOpsData
|
||||||
|
}
|
||||||
|
var opsData opData
|
||||||
|
for _, op := range ops {
|
||||||
|
if op.NoGenericOps != nil && *op.NoGenericOps == "true" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, _, _, immType, gOp := op.shape()
|
||||||
|
gOpData := genericOpsData{gOp.GenericName(), len(gOp.In), op.Commutative}
|
||||||
|
if immType == VarImm || immType == ConstVarImm {
|
||||||
|
opsData.OpsImm = append(opsData.OpsImm, gOpData)
|
||||||
|
} else {
|
||||||
|
opsData.Ops = append(opsData.Ops, gOpData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(opsData.Ops, func(i, j int) bool {
|
||||||
|
return compareNatural(opsData.Ops[i].OpName, opsData.Ops[j].OpName) < 0
|
||||||
|
})
|
||||||
|
sort.Slice(opsData.OpsImm, func(i, j int) bool {
|
||||||
|
return compareNatural(opsData.OpsImm[i].OpName, opsData.OpsImm[j].OpName) < 0
|
||||||
|
})
|
||||||
|
|
||||||
|
err := t.Execute(buffer, opsData)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
151
src/simd/_gen/simdgen/gen_simdIntrinsics.go
Normal file
151
src/simd/_gen/simdgen/gen_simdIntrinsics.go
Normal file
|
|
@ -0,0 +1,151 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
const simdIntrinsicsTmpl = `
|
||||||
|
{{define "header"}}
|
||||||
|
package ssagen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmd/compile/internal/ir"
|
||||||
|
"cmd/compile/internal/ssa"
|
||||||
|
"cmd/compile/internal/types"
|
||||||
|
"cmd/internal/sys"
|
||||||
|
)
|
||||||
|
|
||||||
|
const simdPackage = "` + simdPackage + `"
|
||||||
|
|
||||||
|
func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies ...sys.ArchFamily)) {
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op1"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen1(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op2"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen2(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op2_21"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen2_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op2_21Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen3(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3_21"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen3_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3_21Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen3_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op4"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen4(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op4_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen4_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op4_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen4_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op1Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen1Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op2Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op2Imm8_2I"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2Imm8_2I(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op3Imm8_2I"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3Imm8_2I(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
{{define "op4Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen4Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "vectorConversion"}} addF(simdPackage, "{{.Tsrc.Name}}.As{{.Tdst.Name}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "loadStore"}} addF(simdPackage, "Load{{.Name}}", simdLoad(), sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}.Store", simdStore(), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "maskedLoadStore"}} addF(simdPackage, "LoadMasked{{.Name}}", simdMaskedLoad(ssa.OpLoadMasked{{.ElemBits}}), sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}.StoreMasked", simdMaskedStore(ssa.OpStoreMasked{{.ElemBits}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "mask"}} addF(simdPackage, "{{.Name}}.As{{.VectorCounterpart}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.VectorCounterpart}}.As{{.Name}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}.And", opLen2(ssa.OpAnd{{.ReshapedVectorWithAndOr}}, types.TypeVec{{.Size}}), sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}.Or", opLen2(ssa.OpOr{{.ReshapedVectorWithAndOr}}, types.TypeVec{{.Size}}), sys.AMD64)
|
||||||
|
addF(simdPackage, "Load{{.Name}}FromBits", simdLoadMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}.StoreToBits", simdStoreMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}FromBits", simdCvtVToMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
|
||||||
|
addF(simdPackage, "{{.Name}}.ToBits", simdCvtMaskToV({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "footer"}}}
|
||||||
|
{{end}}
|
||||||
|
`
|
||||||
|
|
||||||
|
// writeSIMDIntrinsics generates the intrinsic mappings and writes it to simdintrinsics.go
|
||||||
|
// within the specified directory.
|
||||||
|
func writeSIMDIntrinsics(ops []Operation, typeMap simdTypeMap) *bytes.Buffer {
|
||||||
|
t := templateOf(simdIntrinsicsTmpl, "simdintrinsics")
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(generatedHeader)
|
||||||
|
|
||||||
|
if err := t.ExecuteTemplate(buffer, "header", nil); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute header template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(ops, compareOperations)
|
||||||
|
|
||||||
|
for _, op := range ops {
|
||||||
|
if op.NoTypes != nil && *op.NoTypes == "true" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s, op, err := classifyOp(op); err == nil {
|
||||||
|
if err := t.ExecuteTemplate(buffer, s, op); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute template %s for op %s: %w", s, op.Go, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, conv := range vConvertFromTypeMap(typeMap) {
|
||||||
|
if err := t.ExecuteTemplate(buffer, "vectorConversion", conv); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, typ := range typesFromTypeMap(typeMap) {
|
||||||
|
if typ.Type != "mask" {
|
||||||
|
if err := t.ExecuteTemplate(buffer, "loadStore", typ); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute loadStore template: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, typ := range typesFromTypeMap(typeMap) {
|
||||||
|
if typ.MaskedLoadStoreFilter() {
|
||||||
|
if err := t.ExecuteTemplate(buffer, "maskedLoadStore", typ); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute maskedLoadStore template: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, mask := range masksFromTypeMap(typeMap) {
|
||||||
|
if err := t.ExecuteTemplate(buffer, "mask", mask); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute mask template: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := t.ExecuteTemplate(buffer, "footer", nil); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute footer template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
122
src/simd/_gen/simdgen/gen_simdMachineOps.go
Normal file
122
src/simd/_gen/simdgen/gen_simdMachineOps.go
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const simdMachineOpsTmpl = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw regInfo) []opData {
|
||||||
|
return []opData{
|
||||||
|
{{- range .OpsData }}
|
||||||
|
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
|
||||||
|
{{- end }}
|
||||||
|
{{- range .OpsDataImm }}
|
||||||
|
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
|
||||||
|
{{- end }}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
// writeSIMDMachineOps generates the machine ops and writes it to simdAMD64ops.go
|
||||||
|
// within the specified directory.
|
||||||
|
func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
|
||||||
|
t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(generatedHeader)
|
||||||
|
|
||||||
|
type opData struct {
|
||||||
|
OpName string
|
||||||
|
Asm string
|
||||||
|
OpInLen int
|
||||||
|
RegInfo string
|
||||||
|
Comm bool
|
||||||
|
Type string
|
||||||
|
ResultInArg0 bool
|
||||||
|
}
|
||||||
|
type machineOpsData struct {
|
||||||
|
OpsData []opData
|
||||||
|
OpsDataImm []opData
|
||||||
|
}
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
regInfoSet := map[string]bool{
|
||||||
|
"v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
|
||||||
|
"w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true}
|
||||||
|
opsData := make([]opData, 0)
|
||||||
|
opsDataImm := make([]opData, 0)
|
||||||
|
for _, op := range ops {
|
||||||
|
shapeIn, shapeOut, maskType, _, gOp := op.shape()
|
||||||
|
asm := machineOpName(maskType, gOp)
|
||||||
|
|
||||||
|
// TODO: all our masked operations are now zeroing, we need to generate machine ops with merging masks, maybe copy
|
||||||
|
// one here with a name suffix "Merging". The rewrite rules will need them.
|
||||||
|
if _, ok := seen[asm]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[asm] = struct{}{}
|
||||||
|
regInfo, err := op.regShape()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
idx, err := checkVecAsScalar(op)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if idx != -1 {
|
||||||
|
if regInfo == "v21" {
|
||||||
|
regInfo = "vfpv"
|
||||||
|
} else if regInfo == "v2kv" {
|
||||||
|
regInfo = "vfpkv"
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Makes AVX512 operations use upper registers
|
||||||
|
if strings.Contains(op.CPUFeature, "AVX512") {
|
||||||
|
regInfo = strings.ReplaceAll(regInfo, "v", "w")
|
||||||
|
}
|
||||||
|
if _, ok := regInfoSet[regInfo]; !ok {
|
||||||
|
panic(fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s. Op is %s", regInfo, op))
|
||||||
|
}
|
||||||
|
var outType string
|
||||||
|
if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
|
||||||
|
// If class overwrite is happening, that's not really a mask but a vreg.
|
||||||
|
outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
|
||||||
|
} else if shapeOut == OneGregOut {
|
||||||
|
outType = gOp.GoType() // this is a straight Go type, not a VecNNN type
|
||||||
|
} else if shapeOut == OneKmaskOut {
|
||||||
|
outType = "Mask"
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
|
||||||
|
}
|
||||||
|
resultInArg0 := false
|
||||||
|
if shapeOut == OneVregOutAtIn {
|
||||||
|
resultInArg0 = true
|
||||||
|
}
|
||||||
|
if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
|
||||||
|
opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
|
||||||
|
} else {
|
||||||
|
opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(opsData, func(i, j int) bool {
|
||||||
|
return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
|
||||||
|
})
|
||||||
|
sort.Slice(opsDataImm, func(i, j int) bool {
|
||||||
|
return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
|
||||||
|
})
|
||||||
|
err := t.Execute(buffer, machineOpsData{opsData, opsDataImm})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
631
src/simd/_gen/simdgen/gen_simdTypes.go
Normal file
631
src/simd/_gen/simdgen/gen_simdTypes.go
Normal file
|
|
@ -0,0 +1,631 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type simdType struct {
|
||||||
|
Name string // The go type name of this simd type, for example Int32x4.
|
||||||
|
Lanes int // The number of elements in this vector/mask.
|
||||||
|
Base string // The element's type, like for Int32x4 it will be int32.
|
||||||
|
Fields string // The struct fields, it should be right formatted.
|
||||||
|
Type string // Either "mask" or "vreg"
|
||||||
|
VectorCounterpart string // For mask use only: just replacing the "Mask" in [simdType.Name] with "Int"
|
||||||
|
ReshapedVectorWithAndOr string // For mask use only: vector AND and OR are only available in some shape with element width 32.
|
||||||
|
Size int // The size of the vector type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x simdType) ElemBits() int {
|
||||||
|
return x.Size / x.Lanes
|
||||||
|
}
|
||||||
|
|
||||||
|
// LanesContainer returns the smallest int/uint bit size that is
|
||||||
|
// large enough to hold one bit for each lane. E.g., Mask32x4
|
||||||
|
// is 4 lanes, and a uint8 is the smallest uint that has 4 bits.
|
||||||
|
func (x simdType) LanesContainer() int {
|
||||||
|
if x.Lanes > 64 {
|
||||||
|
panic("too many lanes")
|
||||||
|
}
|
||||||
|
if x.Lanes > 32 {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
if x.Lanes > 16 {
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
if x.Lanes > 8 {
|
||||||
|
return 16
|
||||||
|
}
|
||||||
|
return 8
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskedLoadStoreFilter encodes which simd type type currently
|
||||||
|
// get masked loads/stores generated, it is used in two places,
|
||||||
|
// this forces coordination.
|
||||||
|
func (x simdType) MaskedLoadStoreFilter() bool {
|
||||||
|
return x.Size == 512 || x.ElemBits() >= 32 && x.Type != "mask"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x simdType) IntelSizeSuffix() string {
|
||||||
|
switch x.ElemBits() {
|
||||||
|
case 8:
|
||||||
|
return "B"
|
||||||
|
case 16:
|
||||||
|
return "W"
|
||||||
|
case 32:
|
||||||
|
return "D"
|
||||||
|
case 64:
|
||||||
|
return "Q"
|
||||||
|
}
|
||||||
|
panic("oops")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x simdType) MaskedLoadDoc() string {
|
||||||
|
if x.Size == 512 || x.ElemBits() < 32 {
|
||||||
|
return fmt.Sprintf("// Asm: VMOVDQU%d.Z, CPU Feature: AVX512", x.ElemBits())
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x simdType) MaskedStoreDoc() string {
|
||||||
|
if x.Size == 512 || x.ElemBits() < 32 {
|
||||||
|
return fmt.Sprintf("// Asm: VMOVDQU%d, CPU Feature: AVX512", x.ElemBits())
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareSimdTypes(x, y simdType) int {
|
||||||
|
// "vreg" then "mask"
|
||||||
|
if c := -compareNatural(x.Type, y.Type); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
// want "flo" < "int" < "uin" (and then 8 < 16 < 32 < 64),
|
||||||
|
// not "int16" < "int32" < "int64" < "int8")
|
||||||
|
// so limit comparison to first 3 bytes in string.
|
||||||
|
if c := compareNatural(x.Base[:3], y.Base[:3]); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
// base type size, 8 < 16 < 32 < 64
|
||||||
|
if c := x.ElemBits() - y.ElemBits(); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
// vector size last
|
||||||
|
return x.Size - y.Size
|
||||||
|
}
|
||||||
|
|
||||||
|
type simdTypeMap map[int][]simdType
|
||||||
|
|
||||||
|
type simdTypePair struct {
|
||||||
|
Tsrc simdType
|
||||||
|
Tdst simdType
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareSimdTypePairs(x, y simdTypePair) int {
|
||||||
|
c := compareSimdTypes(x.Tsrc, y.Tsrc)
|
||||||
|
if c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return compareSimdTypes(x.Tdst, y.Tdst)
|
||||||
|
}
|
||||||
|
|
||||||
|
const simdPackageHeader = generatedHeader + `
|
||||||
|
//go:build goexperiment.simd
|
||||||
|
|
||||||
|
package simd
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdTypesTemplates = `
|
||||||
|
{{define "sizeTmpl"}}
|
||||||
|
// v{{.}} is a tag type that tells the compiler that this is really {{.}}-bit SIMD
|
||||||
|
type v{{.}} struct {
|
||||||
|
_{{.}} struct{}
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "typeTmpl"}}
|
||||||
|
// {{.Name}} is a {{.Size}}-bit SIMD vector of {{.Lanes}} {{.Base}}
|
||||||
|
type {{.Name}} struct {
|
||||||
|
{{.Fields}}
|
||||||
|
}
|
||||||
|
|
||||||
|
{{end}}
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdFeaturesTemplate = `
|
||||||
|
import "internal/cpu"
|
||||||
|
|
||||||
|
{{range .}}
|
||||||
|
{{- if eq .Feature "AVX512"}}
|
||||||
|
// Has{{.Feature}} returns whether the CPU supports the AVX512F+CD+BW+DQ+VL features.
|
||||||
|
//
|
||||||
|
// These five CPU features are bundled together, and no use of AVX-512
|
||||||
|
// is allowed unless all of these features are supported together.
|
||||||
|
// Nearly every CPU that has shipped with any support for AVX-512 has
|
||||||
|
// supported all five of these features.
|
||||||
|
{{- else -}}
|
||||||
|
// Has{{.Feature}} returns whether the CPU supports the {{.Feature}} feature.
|
||||||
|
{{- end}}
|
||||||
|
//
|
||||||
|
// Has{{.Feature}} is defined on all GOARCHes, but will only return true on
|
||||||
|
// GOARCH {{.GoArch}}.
|
||||||
|
func Has{{.Feature}}() bool {
|
||||||
|
return cpu.X86.Has{{.Feature}}
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdLoadStoreTemplate = `
|
||||||
|
// Len returns the number of elements in a {{.Name}}
|
||||||
|
func (x {{.Name}}) Len() int { return {{.Lanes}} }
|
||||||
|
|
||||||
|
// Load{{.Name}} loads a {{.Name}} from an array
|
||||||
|
//
|
||||||
|
//go:noescape
|
||||||
|
func Load{{.Name}}(y *[{{.Lanes}}]{{.Base}}) {{.Name}}
|
||||||
|
|
||||||
|
// Store stores a {{.Name}} to an array
|
||||||
|
//
|
||||||
|
//go:noescape
|
||||||
|
func (x {{.Name}}) Store(y *[{{.Lanes}}]{{.Base}})
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdMaskFromBitsTemplate = `
|
||||||
|
// Load{{.Name}}FromBits constructs a {{.Name}} from a bitmap, where 1 means set for the indexed element, 0 means unset.
|
||||||
|
// Only the lower {{.Lanes}} bits of y are used.
|
||||||
|
//
|
||||||
|
// CPU Features: AVX512
|
||||||
|
//go:noescape
|
||||||
|
func Load{{.Name}}FromBits(y *uint64) {{.Name}}
|
||||||
|
|
||||||
|
// StoreToBits stores a {{.Name}} as a bitmap, where 1 means set for the indexed element, 0 means unset.
|
||||||
|
// Only the lower {{.Lanes}} bits of y are used.
|
||||||
|
//
|
||||||
|
// CPU Features: AVX512
|
||||||
|
//go:noescape
|
||||||
|
func (x {{.Name}}) StoreToBits(y *uint64)
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdMaskFromValTemplate = `
|
||||||
|
// {{.Name}}FromBits constructs a {{.Name}} from a bitmap value, where 1 means set for the indexed element, 0 means unset.
|
||||||
|
// Only the lower {{.Lanes}} bits of y are used.
|
||||||
|
//
|
||||||
|
// Asm: KMOV{{.IntelSizeSuffix}}, CPU Feature: AVX512
|
||||||
|
func {{.Name}}FromBits(y uint{{.LanesContainer}}) {{.Name}}
|
||||||
|
|
||||||
|
// ToBits constructs a bitmap from a {{.Name}}, where 1 means set for the indexed element, 0 means unset.
|
||||||
|
// Only the lower {{.Lanes}} bits of y are used.
|
||||||
|
//
|
||||||
|
// Asm: KMOV{{.IntelSizeSuffix}}, CPU Features: AVX512
|
||||||
|
func (x {{.Name}}) ToBits() uint{{.LanesContainer}}
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdMaskedLoadStoreTemplate = `
|
||||||
|
// LoadMasked{{.Name}} loads a {{.Name}} from an array,
|
||||||
|
// at those elements enabled by mask
|
||||||
|
//
|
||||||
|
{{.MaskedLoadDoc}}
|
||||||
|
//
|
||||||
|
//go:noescape
|
||||||
|
func LoadMasked{{.Name}}(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}}) {{.Name}}
|
||||||
|
|
||||||
|
// StoreMasked stores a {{.Name}} to an array,
|
||||||
|
// at those elements enabled by mask
|
||||||
|
//
|
||||||
|
{{.MaskedStoreDoc}}
|
||||||
|
//
|
||||||
|
//go:noescape
|
||||||
|
func (x {{.Name}}) StoreMasked(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}})
|
||||||
|
`
|
||||||
|
|
||||||
|
const simdStubsTmpl = `
|
||||||
|
{{define "op1"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op0NameAndType "x"}}) {{.Go}}() {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op2"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op2_21"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op2_21Type1"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3_31"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3_21"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3_21Type1"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3_231Type1"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op2VecAsScalar"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}) {{(index .Out 0).Go}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3VecAsScalar"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}, {{.Op2NameAndType "z"}}) {{(index .Out 0).Go}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op4"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op4_231Type1"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op4_31"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op1Imm8"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
|
||||||
|
//
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op2Imm8"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
|
||||||
|
//
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op2Imm8_2I"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
|
||||||
|
//
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
|
||||||
|
{{define "op3Imm8"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
|
||||||
|
//
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "op3Imm8_2I"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
|
||||||
|
//
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8, {{.Op3NameAndType "z"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
|
||||||
|
{{define "op4Imm8"}}
|
||||||
|
{{if .Documentation}}{{.Documentation}}
|
||||||
|
//{{end}}
|
||||||
|
// {{.ImmName}} results in better performance when it's a constant, a non-constant value will be translated into a jump table.
|
||||||
|
//
|
||||||
|
// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
|
||||||
|
func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}, {{.Op4NameAndType "u"}}) {{.GoType}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "vectorConversion"}}
|
||||||
|
// {{.Tdst.Name}} converts from {{.Tsrc.Name}} to {{.Tdst.Name}}
|
||||||
|
func (from {{.Tsrc.Name}}) As{{.Tdst.Name}}() (to {{.Tdst.Name}})
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "mask"}}
|
||||||
|
// converts from {{.Name}} to {{.VectorCounterpart}}
|
||||||
|
func (from {{.Name}}) As{{.VectorCounterpart}}() (to {{.VectorCounterpart}})
|
||||||
|
|
||||||
|
// converts from {{.VectorCounterpart}} to {{.Name}}
|
||||||
|
func (from {{.VectorCounterpart}}) As{{.Name}}() (to {{.Name}})
|
||||||
|
|
||||||
|
func (x {{.Name}}) And(y {{.Name}}) {{.Name}}
|
||||||
|
|
||||||
|
func (x {{.Name}}) Or(y {{.Name}}) {{.Name}}
|
||||||
|
{{end}}
|
||||||
|
`
|
||||||
|
|
||||||
|
// parseSIMDTypes groups go simd types by their vector sizes, and
|
||||||
|
// returns a map whose key is the vector size, value is the simd type.
|
||||||
|
func parseSIMDTypes(ops []Operation) simdTypeMap {
|
||||||
|
// TODO: maybe instead of going over ops, let's try go over types.yaml.
|
||||||
|
ret := map[int][]simdType{}
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
processArg := func(arg Operand) {
|
||||||
|
if arg.Class == "immediate" || arg.Class == "greg" {
|
||||||
|
// Immediates are not encoded as vector types.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[*arg.Go]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[*arg.Go] = struct{}{}
|
||||||
|
|
||||||
|
lanes := *arg.Lanes
|
||||||
|
base := fmt.Sprintf("%s%d", *arg.Base, *arg.ElemBits)
|
||||||
|
tagFieldNameS := fmt.Sprintf("%sx%d", base, lanes)
|
||||||
|
tagFieldS := fmt.Sprintf("%s v%d", tagFieldNameS, *arg.Bits)
|
||||||
|
valFieldS := fmt.Sprintf("vals%s[%d]%s", strings.Repeat(" ", len(tagFieldNameS)-3), lanes, base)
|
||||||
|
fields := fmt.Sprintf("\t%s\n\t%s", tagFieldS, valFieldS)
|
||||||
|
if arg.Class == "mask" {
|
||||||
|
vectorCounterpart := strings.ReplaceAll(*arg.Go, "Mask", "Int")
|
||||||
|
reshapedVectorWithAndOr := fmt.Sprintf("Int32x%d", *arg.Bits/32)
|
||||||
|
ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, vectorCounterpart, reshapedVectorWithAndOr, *arg.Bits})
|
||||||
|
// In case the vector counterpart of a mask is not present, put its vector counterpart typedef into the map as well.
|
||||||
|
if _, ok := seen[vectorCounterpart]; !ok {
|
||||||
|
seen[vectorCounterpart] = struct{}{}
|
||||||
|
ret[*arg.Bits] = append(ret[*arg.Bits], simdType{vectorCounterpart, lanes, base, fields, "vreg", "", "", *arg.Bits})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, "", "", *arg.Bits})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, op := range ops {
|
||||||
|
for _, arg := range op.In {
|
||||||
|
processArg(arg)
|
||||||
|
}
|
||||||
|
for _, arg := range op.Out {
|
||||||
|
processArg(arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func vConvertFromTypeMap(typeMap simdTypeMap) []simdTypePair {
|
||||||
|
v := []simdTypePair{}
|
||||||
|
for _, ts := range typeMap {
|
||||||
|
for i, tsrc := range ts {
|
||||||
|
for j, tdst := range ts {
|
||||||
|
if i != j && tsrc.Type == tdst.Type && tsrc.Type == "vreg" &&
|
||||||
|
tsrc.Lanes > 1 && tdst.Lanes > 1 {
|
||||||
|
v = append(v, simdTypePair{tsrc, tdst})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.SortFunc(v, compareSimdTypePairs)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func masksFromTypeMap(typeMap simdTypeMap) []simdType {
|
||||||
|
m := []simdType{}
|
||||||
|
for _, ts := range typeMap {
|
||||||
|
for _, tsrc := range ts {
|
||||||
|
if tsrc.Type == "mask" {
|
||||||
|
m = append(m, tsrc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.SortFunc(m, compareSimdTypes)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func typesFromTypeMap(typeMap simdTypeMap) []simdType {
|
||||||
|
m := []simdType{}
|
||||||
|
for _, ts := range typeMap {
|
||||||
|
for _, tsrc := range ts {
|
||||||
|
if tsrc.Lanes > 1 {
|
||||||
|
m = append(m, tsrc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.SortFunc(m, compareSimdTypes)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSIMDTypes generates the simd vector types into a bytes.Buffer
|
||||||
|
func writeSIMDTypes(typeMap simdTypeMap) *bytes.Buffer {
|
||||||
|
t := templateOf(simdTypesTemplates, "types_amd64")
|
||||||
|
loadStore := templateOf(simdLoadStoreTemplate, "loadstore_amd64")
|
||||||
|
maskedLoadStore := templateOf(simdMaskedLoadStoreTemplate, "maskedloadstore_amd64")
|
||||||
|
maskFromBits := templateOf(simdMaskFromBitsTemplate, "maskFromBits_amd64")
|
||||||
|
maskFromVal := templateOf(simdMaskFromValTemplate, "maskFromVal_amd64")
|
||||||
|
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(simdPackageHeader)
|
||||||
|
|
||||||
|
sizes := make([]int, 0, len(typeMap))
|
||||||
|
for size, types := range typeMap {
|
||||||
|
slices.SortFunc(types, compareSimdTypes)
|
||||||
|
sizes = append(sizes, size)
|
||||||
|
}
|
||||||
|
sort.Ints(sizes)
|
||||||
|
|
||||||
|
for _, size := range sizes {
|
||||||
|
if size <= 64 {
|
||||||
|
// these are scalar
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := t.ExecuteTemplate(buffer, "sizeTmpl", size); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute size template for size %d: %w", size, err))
|
||||||
|
}
|
||||||
|
for _, typeDef := range typeMap[size] {
|
||||||
|
if typeDef.Lanes == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := t.ExecuteTemplate(buffer, "typeTmpl", typeDef); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute type template for type %s: %w", typeDef.Name, err))
|
||||||
|
}
|
||||||
|
if typeDef.Type != "mask" {
|
||||||
|
if err := loadStore.ExecuteTemplate(buffer, "loadstore_amd64", typeDef); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute loadstore template for type %s: %w", typeDef.Name, err))
|
||||||
|
}
|
||||||
|
// restrict to AVX2 masked loads/stores first.
|
||||||
|
if typeDef.MaskedLoadStoreFilter() {
|
||||||
|
if err := maskedLoadStore.ExecuteTemplate(buffer, "maskedloadstore_amd64", typeDef); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute maskedloadstore template for type %s: %w", typeDef.Name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := maskFromBits.ExecuteTemplate(buffer, "maskFromBits_amd64", typeDef); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute maskFromBits template for type %s: %w", typeDef.Name, err))
|
||||||
|
}
|
||||||
|
if err := maskFromVal.ExecuteTemplate(buffer, "maskFromVal_amd64", typeDef); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute maskFromVal template for type %s: %w", typeDef.Name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSIMDFeatures(ops []Operation) *bytes.Buffer {
|
||||||
|
// Gather all features
|
||||||
|
type featureKey struct {
|
||||||
|
GoArch string
|
||||||
|
Feature string
|
||||||
|
}
|
||||||
|
featureSet := make(map[featureKey]struct{})
|
||||||
|
for _, op := range ops {
|
||||||
|
featureSet[featureKey{op.GoArch, op.CPUFeature}] = struct{}{}
|
||||||
|
}
|
||||||
|
features := slices.SortedFunc(maps.Keys(featureSet), func(a, b featureKey) int {
|
||||||
|
if c := cmp.Compare(a.GoArch, b.GoArch); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return compareNatural(a.Feature, b.Feature)
|
||||||
|
})
|
||||||
|
|
||||||
|
// If we ever have the same feature name on more than one GOARCH, we'll have
|
||||||
|
// to be more careful about this.
|
||||||
|
t := templateOf(simdFeaturesTemplate, "features")
|
||||||
|
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(simdPackageHeader)
|
||||||
|
|
||||||
|
if err := t.Execute(buffer, features); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute features template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSIMDStubs generates the simd vector intrinsic stubs and writes it to ops_amd64.go and ops_internal_amd64.go
|
||||||
|
// within the specified directory.
|
||||||
|
func writeSIMDStubs(ops []Operation, typeMap simdTypeMap) *bytes.Buffer {
|
||||||
|
t := templateOf(simdStubsTmpl, "simdStubs")
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(simdPackageHeader)
|
||||||
|
|
||||||
|
slices.SortFunc(ops, compareOperations)
|
||||||
|
|
||||||
|
for i, op := range ops {
|
||||||
|
if op.NoTypes != nil && *op.NoTypes == "true" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idxVecAsScalar, err := checkVecAsScalar(op)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if s, op, err := classifyOp(op); err == nil {
|
||||||
|
if idxVecAsScalar != -1 {
|
||||||
|
if s == "op2" || s == "op3" {
|
||||||
|
s += "VecAsScalar"
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen only supports op2 or op3 with TreatLikeAScalarOfSize"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i == 0 || op.Go != ops[i-1].Go {
|
||||||
|
fmt.Fprintf(buffer, "\n/* %s */\n", op.Go)
|
||||||
|
}
|
||||||
|
if err := t.ExecuteTemplate(buffer, s, op); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute template %s for op %v: %w", s, op, err))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vectorConversions := vConvertFromTypeMap(typeMap)
|
||||||
|
for _, conv := range vectorConversions {
|
||||||
|
if err := t.ExecuteTemplate(buffer, "vectorConversion", conv); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
masks := masksFromTypeMap(typeMap)
|
||||||
|
for _, mask := range masks {
|
||||||
|
if err := t.ExecuteTemplate(buffer, "mask", mask); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute mask template for mask %s: %w", mask.Name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
211
src/simd/_gen/simdgen/gen_simdrules.go
Normal file
211
src/simd/_gen/simdgen/gen_simdrules.go
Normal file
|
|
@ -0,0 +1,211 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tplRuleData struct {
|
||||||
|
tplName string // e.g. "sftimm"
|
||||||
|
GoOp string // e.g. "ShiftAllLeft"
|
||||||
|
GoType string // e.g. "Uint32x8"
|
||||||
|
Args string // e.g. "x y"
|
||||||
|
Asm string // e.g. "VPSLLD256"
|
||||||
|
ArgsOut string // e.g. "x y"
|
||||||
|
MaskInConvert string // e.g. "VPMOVVec32x8ToM"
|
||||||
|
MaskOutConvert string // e.g. "VPMOVMToVec32x8"
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ruleTemplates = template.Must(template.New("simdRules").Parse(`
|
||||||
|
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
|
||||||
|
{{end}}
|
||||||
|
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
|
||||||
|
{{end}}
|
||||||
|
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
|
||||||
|
{{end}}
|
||||||
|
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
|
||||||
|
{{end}}
|
||||||
|
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
|
||||||
|
{{end}}
|
||||||
|
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
|
||||||
|
{{end}}
|
||||||
|
`))
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSA rewrite rules need to appear in a most-to-least-specific order. This works for that.
|
||||||
|
var tmplOrder = map[string]int{
|
||||||
|
"masksftimm": 0,
|
||||||
|
"sftimm": 1,
|
||||||
|
"maskInMaskOut": 2,
|
||||||
|
"maskOut": 3,
|
||||||
|
"maskIn": 4,
|
||||||
|
"pureVreg": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareTplRuleData(x, y tplRuleData) int {
|
||||||
|
if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := compareNatural(x.GoType, y.GoType); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := compareNatural(x.Args, y.Args); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if x.tplName == y.tplName {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
xo, xok := tmplOrder[x.tplName]
|
||||||
|
yo, yok := tmplOrder[y.tplName]
|
||||||
|
if !xok {
|
||||||
|
panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
|
||||||
|
}
|
||||||
|
if !yok {
|
||||||
|
panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
|
||||||
|
}
|
||||||
|
return xo - yo
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
|
||||||
|
// within the specified directory.
|
||||||
|
func writeSIMDRules(ops []Operation) *bytes.Buffer {
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
buffer.WriteString(generatedHeader + "\n")
|
||||||
|
|
||||||
|
var allData []tplRuleData
|
||||||
|
|
||||||
|
for _, opr := range ops {
|
||||||
|
if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opInShape, opOutShape, maskType, immType, gOp := opr.shape()
|
||||||
|
asm := machineOpName(maskType, gOp)
|
||||||
|
vregInCnt := len(gOp.In)
|
||||||
|
if maskType == OneMask {
|
||||||
|
vregInCnt--
|
||||||
|
}
|
||||||
|
|
||||||
|
data := tplRuleData{
|
||||||
|
GoOp: gOp.Go,
|
||||||
|
Asm: asm,
|
||||||
|
}
|
||||||
|
|
||||||
|
if vregInCnt == 1 {
|
||||||
|
data.Args = "x"
|
||||||
|
data.ArgsOut = data.Args
|
||||||
|
} else if vregInCnt == 2 {
|
||||||
|
data.Args = "x y"
|
||||||
|
data.ArgsOut = data.Args
|
||||||
|
} else if vregInCnt == 3 {
|
||||||
|
data.Args = "x y z"
|
||||||
|
data.ArgsOut = data.Args
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
|
||||||
|
}
|
||||||
|
if immType == ConstImm {
|
||||||
|
data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
|
||||||
|
} else if immType == VarImm {
|
||||||
|
data.Args = fmt.Sprintf("[a] %s", data.Args)
|
||||||
|
data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
|
||||||
|
} else if immType == ConstVarImm {
|
||||||
|
data.Args = fmt.Sprintf("[a] %s", data.Args)
|
||||||
|
data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
goType := func(op Operation) string {
|
||||||
|
if op.OperandOrder != nil {
|
||||||
|
switch *op.OperandOrder {
|
||||||
|
case "21Type1", "231Type1":
|
||||||
|
// Permute uses operand[1] for method receiver.
|
||||||
|
return *op.In[1].Go
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *op.In[0].Go
|
||||||
|
}
|
||||||
|
var tplName string
|
||||||
|
// If class overwrite is happening, that's not really a mask but a vreg.
|
||||||
|
if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
|
||||||
|
switch opInShape {
|
||||||
|
case OneImmIn:
|
||||||
|
tplName = "pureVreg"
|
||||||
|
data.GoType = goType(gOp)
|
||||||
|
case PureVregIn:
|
||||||
|
tplName = "pureVreg"
|
||||||
|
data.GoType = goType(gOp)
|
||||||
|
case OneKmaskImmIn:
|
||||||
|
fallthrough
|
||||||
|
case OneKmaskIn:
|
||||||
|
tplName = "maskIn"
|
||||||
|
data.GoType = goType(gOp)
|
||||||
|
rearIdx := len(gOp.In) - 1
|
||||||
|
// Mask is at the end.
|
||||||
|
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
|
||||||
|
case PureKmaskIn:
|
||||||
|
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
|
||||||
|
}
|
||||||
|
} else if opOutShape == OneGregOut {
|
||||||
|
tplName = "pureVreg" // TODO this will be wrong
|
||||||
|
data.GoType = goType(gOp)
|
||||||
|
} else {
|
||||||
|
// OneKmaskOut case
|
||||||
|
data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
|
||||||
|
switch opInShape {
|
||||||
|
case OneImmIn:
|
||||||
|
fallthrough
|
||||||
|
case PureVregIn:
|
||||||
|
tplName = "maskOut"
|
||||||
|
data.GoType = goType(gOp)
|
||||||
|
case OneKmaskImmIn:
|
||||||
|
fallthrough
|
||||||
|
case OneKmaskIn:
|
||||||
|
tplName = "maskInMaskOut"
|
||||||
|
data.GoType = goType(gOp)
|
||||||
|
rearIdx := len(gOp.In) - 1
|
||||||
|
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
|
||||||
|
case PureKmaskIn:
|
||||||
|
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gOp.SpecialLower != nil {
|
||||||
|
if *gOp.SpecialLower == "sftimm" {
|
||||||
|
if data.GoType[0] == 'I' {
|
||||||
|
// only do these for signed types, it is a duplicate rewrite for unsigned
|
||||||
|
sftImmData := data
|
||||||
|
if tplName == "maskIn" {
|
||||||
|
sftImmData.tplName = "masksftimm"
|
||||||
|
} else {
|
||||||
|
sftImmData.tplName = "sftimm"
|
||||||
|
}
|
||||||
|
allData = append(allData, sftImmData)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tplName == "pureVreg" && data.Args == data.ArgsOut {
|
||||||
|
data.Args = "..."
|
||||||
|
data.ArgsOut = "..."
|
||||||
|
}
|
||||||
|
data.tplName = tplName
|
||||||
|
allData = append(allData, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(allData, compareTplRuleData)
|
||||||
|
|
||||||
|
for _, data := range allData {
|
||||||
|
if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
173
src/simd/_gen/simdgen/gen_simdssa.go
Normal file
173
src/simd/_gen/simdgen/gen_simdssa.go
Normal file
|
|
@ -0,0 +1,173 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ssaTemplates = template.Must(template.New("simdSSA").Parse(`
|
||||||
|
{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
|
||||||
|
|
||||||
|
package amd64
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmd/compile/internal/ssa"
|
||||||
|
"cmd/compile/internal/ssagen"
|
||||||
|
"cmd/internal/obj"
|
||||||
|
"cmd/internal/obj/x86"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
|
||||||
|
var p *obj.Prog
|
||||||
|
switch v.Op {{"{"}}{{end}}
|
||||||
|
{{define "case"}}
|
||||||
|
case {{.Cases}}:
|
||||||
|
p = {{.Helper}}(s, v)
|
||||||
|
{{end}}
|
||||||
|
{{define "footer"}}
|
||||||
|
default:
|
||||||
|
// Unknown reg shape
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
{{define "zeroing"}}
|
||||||
|
// Masked operation are always compiled with zeroing.
|
||||||
|
switch v.Op {
|
||||||
|
case {{.}}:
|
||||||
|
x86.ParseSuffix(p, "Z")
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
{{define "ending"}}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
{{end}}`))
|
||||||
|
)
|
||||||
|
|
||||||
|
type tplSSAData struct {
|
||||||
|
Cases string
|
||||||
|
Helper string
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSIMDSSA generates the ssa to prog lowering codes and writes it to simdssa.go
|
||||||
|
// within the specified directory.
|
||||||
|
func writeSIMDSSA(ops []Operation) *bytes.Buffer {
|
||||||
|
var ZeroingMask []string
|
||||||
|
regInfoKeys := []string{
|
||||||
|
"v11",
|
||||||
|
"v21",
|
||||||
|
"v2k",
|
||||||
|
"v2kv",
|
||||||
|
"v2kk",
|
||||||
|
"vkv",
|
||||||
|
"v31",
|
||||||
|
"v3kv",
|
||||||
|
"v11Imm8",
|
||||||
|
"vkvImm8",
|
||||||
|
"v21Imm8",
|
||||||
|
"v2kImm8",
|
||||||
|
"v2kkImm8",
|
||||||
|
"v31ResultInArg0",
|
||||||
|
"v3kvResultInArg0",
|
||||||
|
"vfpv",
|
||||||
|
"vfpkv",
|
||||||
|
"vgpvImm8",
|
||||||
|
"vgpImm8",
|
||||||
|
"v2kvImm8",
|
||||||
|
}
|
||||||
|
regInfoSet := map[string][]string{}
|
||||||
|
for _, key := range regInfoKeys {
|
||||||
|
regInfoSet[key] = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
allUnseen := make(map[string][]Operation)
|
||||||
|
for _, op := range ops {
|
||||||
|
shapeIn, shapeOut, maskType, _, gOp := op.shape()
|
||||||
|
asm := machineOpName(maskType, gOp)
|
||||||
|
|
||||||
|
if _, ok := seen[asm]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[asm] = struct{}{}
|
||||||
|
caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
|
||||||
|
if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
|
||||||
|
if gOp.Zeroing == nil {
|
||||||
|
ZeroingMask = append(ZeroingMask, caseStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
regShape, err := op.regShape()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if shapeOut == OneVregOutAtIn {
|
||||||
|
regShape += "ResultInArg0"
|
||||||
|
}
|
||||||
|
if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
|
||||||
|
regShape += "Imm8"
|
||||||
|
}
|
||||||
|
idx, err := checkVecAsScalar(op)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if idx != -1 {
|
||||||
|
if regShape == "v21" {
|
||||||
|
regShape = "vfpv"
|
||||||
|
} else if regShape == "v2kv" {
|
||||||
|
regShape = "vfpkv"
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regShape, op))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := regInfoSet[regShape]; !ok {
|
||||||
|
allUnseen[regShape] = append(allUnseen[regShape], op)
|
||||||
|
}
|
||||||
|
regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
|
||||||
|
}
|
||||||
|
if len(allUnseen) != 0 {
|
||||||
|
panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v", allUnseen))
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
|
||||||
|
if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute header template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, regShape := range regInfoKeys {
|
||||||
|
// Stable traversal of regInfoSet
|
||||||
|
cases := regInfoSet[regShape]
|
||||||
|
if len(cases) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := tplSSAData{
|
||||||
|
Cases: strings.Join(cases, ",\n\t\t"),
|
||||||
|
Helper: "simd" + capitalizeFirst(regShape),
|
||||||
|
}
|
||||||
|
if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute footer template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ZeroingMask) != 0 {
|
||||||
|
if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute footer template: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute footer template: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
729
src/simd/_gen/simdgen/gen_utility.go
Normal file
729
src/simd/_gen/simdgen/gen_utility.go
Normal file
|
|
@ -0,0 +1,729 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"go/format"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func templateOf(temp, name string) *template.Template {
|
||||||
|
t, err := template.New(name).Parse(temp)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to parse template %s: %w", name, err))
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPath(goroot string, file string) (*os.File, error) {
|
||||||
|
fp := filepath.Join(goroot, file)
|
||||||
|
dir := filepath.Dir(fp)
|
||||||
|
err := os.MkdirAll(dir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
f, err := os.Create(fp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create file %s: %w", fp, err)
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
|
||||||
|
b, err := format.Source(out.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
|
||||||
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
|
panic(err)
|
||||||
|
} else {
|
||||||
|
writeAndClose(b, goroot, file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeAndClose(b []byte, goroot string, file string) {
|
||||||
|
ofile, err := createPath(goroot, file)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ofile.Write(b)
|
||||||
|
ofile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// numberLines takes a slice of bytes, and returns a string where each line
|
||||||
|
// is numbered, starting from 1.
|
||||||
|
func numberLines(data []byte) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
s := bufio.NewScanner(r)
|
||||||
|
for i := 1; s.Scan(); i++ {
|
||||||
|
fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type inShape uint8
|
||||||
|
type outShape uint8
|
||||||
|
type maskShape uint8
|
||||||
|
type immShape uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
InvalidIn inShape = iota
|
||||||
|
PureVregIn // vector register input only
|
||||||
|
OneKmaskIn // vector and kmask input
|
||||||
|
OneImmIn // vector and immediate input
|
||||||
|
OneKmaskImmIn // vector, kmask, and immediate inputs
|
||||||
|
PureKmaskIn // only mask inputs.
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
InvalidOut outShape = iota
|
||||||
|
NoOut // no output
|
||||||
|
OneVregOut // (one) vector register output
|
||||||
|
OneGregOut // (one) general register output
|
||||||
|
OneKmaskOut // mask output
|
||||||
|
OneVregOutAtIn // the first input is also the output
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
InvalidMask maskShape = iota
|
||||||
|
NoMask // no mask
|
||||||
|
OneMask // with mask (K1 to K7)
|
||||||
|
AllMasks // a K mask instruction (K0-K7)
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
InvalidImm immShape = iota
|
||||||
|
NoImm // no immediate
|
||||||
|
ConstImm // const only immediate
|
||||||
|
VarImm // pure imm argument provided by the users
|
||||||
|
ConstVarImm // a combination of user arg and const
|
||||||
|
)
|
||||||
|
|
||||||
|
// opShape returns the several integers describing the shape of the operation,
|
||||||
|
// and modified versions of the op:
|
||||||
|
//
|
||||||
|
// opNoImm is op with its inputs excluding the const imm.
|
||||||
|
//
|
||||||
|
// This function does not modify op.
|
||||||
|
func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
|
||||||
|
opNoImm Operation) {
|
||||||
|
if len(op.Out) > 1 {
|
||||||
|
panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
|
||||||
|
}
|
||||||
|
var outputReg int
|
||||||
|
if len(op.Out) == 1 {
|
||||||
|
outputReg = op.Out[0].AsmPos
|
||||||
|
if op.Out[0].Class == "vreg" {
|
||||||
|
shapeOut = OneVregOut
|
||||||
|
} else if op.Out[0].Class == "greg" {
|
||||||
|
shapeOut = OneGregOut
|
||||||
|
} else if op.Out[0].Class == "mask" {
|
||||||
|
shapeOut = OneKmaskOut
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
shapeOut = NoOut
|
||||||
|
// TODO: are these only Load/Stores?
|
||||||
|
// We manually supported two Load and Store, are those enough?
|
||||||
|
panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
|
||||||
|
}
|
||||||
|
hasImm := false
|
||||||
|
maskCount := 0
|
||||||
|
hasVreg := false
|
||||||
|
for _, in := range op.In {
|
||||||
|
if in.AsmPos == outputReg {
|
||||||
|
if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
|
||||||
|
shapeOut = OneVregOutAtIn
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if in.Class == "immediate" {
|
||||||
|
// A manual check on XED data found that AMD64 SIMD instructions at most
|
||||||
|
// have 1 immediates. So we don't need to check this here.
|
||||||
|
if *in.Bits != 8 {
|
||||||
|
panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
|
||||||
|
}
|
||||||
|
hasImm = true
|
||||||
|
} else if in.Class == "mask" {
|
||||||
|
maskCount++
|
||||||
|
} else {
|
||||||
|
hasVreg = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opNoImm = *op
|
||||||
|
|
||||||
|
removeImm := func(o *Operation) {
|
||||||
|
o.In = o.In[1:]
|
||||||
|
}
|
||||||
|
if hasImm {
|
||||||
|
removeImm(&opNoImm)
|
||||||
|
if op.In[0].Const != nil {
|
||||||
|
if op.In[0].ImmOffset != nil {
|
||||||
|
immType = ConstVarImm
|
||||||
|
} else {
|
||||||
|
immType = ConstImm
|
||||||
|
}
|
||||||
|
} else if op.In[0].ImmOffset != nil {
|
||||||
|
immType = VarImm
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
immType = NoImm
|
||||||
|
}
|
||||||
|
if maskCount == 0 {
|
||||||
|
maskType = NoMask
|
||||||
|
} else {
|
||||||
|
maskType = OneMask
|
||||||
|
}
|
||||||
|
checkPureMask := func() bool {
|
||||||
|
if hasImm {
|
||||||
|
panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
|
||||||
|
}
|
||||||
|
if hasVreg {
|
||||||
|
panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !hasImm && maskCount == 0 {
|
||||||
|
shapeIn = PureVregIn
|
||||||
|
} else if !hasImm && maskCount > 0 {
|
||||||
|
if maskCount == 1 {
|
||||||
|
shapeIn = OneKmaskIn
|
||||||
|
} else {
|
||||||
|
if checkPureMask() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
shapeIn = PureKmaskIn
|
||||||
|
maskType = AllMasks
|
||||||
|
}
|
||||||
|
} else if hasImm && maskCount == 0 {
|
||||||
|
shapeIn = OneImmIn
|
||||||
|
} else {
|
||||||
|
if maskCount == 1 {
|
||||||
|
shapeIn = OneKmaskImmIn
|
||||||
|
} else {
|
||||||
|
checkPureMask()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// regShape returns a string representation of the register shape.
|
||||||
|
func (op *Operation) regShape() (string, error) {
|
||||||
|
_, _, _, _, gOp := op.shape()
|
||||||
|
var regInfo string
|
||||||
|
var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt int
|
||||||
|
for _, in := range gOp.In {
|
||||||
|
if in.Class == "vreg" {
|
||||||
|
vRegInCnt++
|
||||||
|
} else if in.Class == "greg" {
|
||||||
|
gRegInCnt++
|
||||||
|
} else if in.Class == "mask" {
|
||||||
|
kMaskInCnt++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, out := range gOp.Out {
|
||||||
|
// If class overwrite is happening, that's not really a mask but a vreg.
|
||||||
|
if out.Class == "vreg" || out.OverwriteClass != nil {
|
||||||
|
vRegOutCnt++
|
||||||
|
} else if out.Class == "greg" {
|
||||||
|
gRegOutCnt++
|
||||||
|
} else if out.Class == "mask" {
|
||||||
|
kMaskOutCnt++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var inRegs, inMasks, outRegs, outMasks string
|
||||||
|
|
||||||
|
rmAbbrev := func(s string, i int) string {
|
||||||
|
if i == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if i == 1 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s%d", s, i)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
inRegs = rmAbbrev("v", vRegInCnt)
|
||||||
|
inRegs += rmAbbrev("gp", gRegInCnt)
|
||||||
|
inMasks = rmAbbrev("k", kMaskInCnt)
|
||||||
|
|
||||||
|
outRegs = rmAbbrev("v", vRegOutCnt)
|
||||||
|
outRegs += rmAbbrev("gp", gRegOutCnt)
|
||||||
|
outMasks = rmAbbrev("k", kMaskOutCnt)
|
||||||
|
|
||||||
|
if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
|
||||||
|
// For pure v we can abbreviate it as v%d%d.
|
||||||
|
regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
|
||||||
|
} else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
|
||||||
|
regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
|
||||||
|
} else {
|
||||||
|
regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
|
||||||
|
}
|
||||||
|
return regInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortOperand sorts op.In by putting immediates first, then vreg, and mask the last.
|
||||||
|
// TODO: verify that this is a safe assumption of the prog structure.
|
||||||
|
// from my observation looks like in asm, imms are always the first,
|
||||||
|
// masks are always the last, with vreg in between.
|
||||||
|
func (op *Operation) sortOperand() {
|
||||||
|
priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
|
||||||
|
sort.SliceStable(op.In, func(i, j int) bool {
|
||||||
|
pi := priority[op.In[i].Class]
|
||||||
|
pj := priority[op.In[j].Class]
|
||||||
|
if pi != pj {
|
||||||
|
return pi < pj
|
||||||
|
}
|
||||||
|
return op.In[i].AsmPos < op.In[j].AsmPos
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// goNormalType returns the Go type name for the result of an Op that
|
||||||
|
// does not return a vector, i.e., that returns a result in a general
|
||||||
|
// register. Currently there's only one family of Ops in Go's simd library
|
||||||
|
// that does this (GetElem), and so this is specialized to work for that,
|
||||||
|
// but the problem (mismatch betwen hardware register width and Go type
|
||||||
|
// width) seems likely to recur if there are any other cases.
|
||||||
|
func (op Operation) goNormalType() string {
|
||||||
|
if op.Go == "GetElem" {
|
||||||
|
// GetElem returns an element of the vector into a general register
|
||||||
|
// but as far as the hardware is concerned, that result is either 32
|
||||||
|
// or 64 bits wide, no matter what the vector element width is.
|
||||||
|
// This is not "wrong" but it is not the right answer for Go source code.
|
||||||
|
// To get the Go type right, combine the base type ("int", "uint", "float"),
|
||||||
|
// with the input vector element width in bits (8,16,32,64).
|
||||||
|
|
||||||
|
at := 0 // proper value of at depends on whether immediate was stripped or not
|
||||||
|
if op.In[at].Class == "immediate" {
|
||||||
|
at++
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
|
||||||
|
}
|
||||||
|
panic(fmt.Errorf("Implement goNormalType for %v", op))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSAType returns the string for the type reference in SSA generation,
|
||||||
|
// for example in the intrinsics generating template.
|
||||||
|
func (op Operation) SSAType() string {
|
||||||
|
if op.Out[0].Class == "greg" {
|
||||||
|
return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoType returns the Go type returned by this operation (relative to the simd package),
|
||||||
|
// for example "int32" or "Int8x16". This is used in a template.
|
||||||
|
func (op Operation) GoType() string {
|
||||||
|
if op.Out[0].Class == "greg" {
|
||||||
|
return op.goNormalType()
|
||||||
|
}
|
||||||
|
return *op.Out[0].Go
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImmName returns the name to use for an operation's immediate operand.
|
||||||
|
// This can be overriden in the yaml with "name" on an operand,
|
||||||
|
// otherwise, for now, "constant"
|
||||||
|
func (op Operation) ImmName() string {
|
||||||
|
return op.Op0Name("constant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Operand) OpName(s string) string {
|
||||||
|
if n := o.Name; n != nil {
|
||||||
|
return *n
|
||||||
|
}
|
||||||
|
if o.Class == "mask" {
|
||||||
|
return "mask"
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Operand) OpNameAndType(s string) string {
|
||||||
|
return o.OpName(s) + " " + *o.Go
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoExported returns [Go] with first character capitalized.
|
||||||
|
func (op Operation) GoExported() string {
|
||||||
|
return capitalizeFirst(op.Go)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocumentationExported returns [Documentation] with method name capitalized.
|
||||||
|
func (op Operation) DocumentationExported() string {
|
||||||
|
return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op0Name returns the name to use for the 0 operand,
|
||||||
|
// if any is present, otherwise the parameter is used.
|
||||||
|
func (op Operation) Op0Name(s string) string {
|
||||||
|
return op.In[0].OpName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op1Name returns the name to use for the 1 operand,
|
||||||
|
// if any is present, otherwise the parameter is used.
|
||||||
|
func (op Operation) Op1Name(s string) string {
|
||||||
|
return op.In[1].OpName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op2Name returns the name to use for the 2 operand,
|
||||||
|
// if any is present, otherwise the parameter is used.
|
||||||
|
func (op Operation) Op2Name(s string) string {
|
||||||
|
return op.In[2].OpName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op3Name returns the name to use for the 3 operand,
|
||||||
|
// if any is present, otherwise the parameter is used.
|
||||||
|
func (op Operation) Op3Name(s string) string {
|
||||||
|
return op.In[3].OpName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op0NameAndType returns the name and type to use for
|
||||||
|
// the 0 operand, if a name is provided, otherwise
|
||||||
|
// the parameter value is used as the default.
|
||||||
|
func (op Operation) Op0NameAndType(s string) string {
|
||||||
|
return op.In[0].OpNameAndType(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op1NameAndType returns the name and type to use for
|
||||||
|
// the 1 operand, if a name is provided, otherwise
|
||||||
|
// the parameter value is used as the default.
|
||||||
|
func (op Operation) Op1NameAndType(s string) string {
|
||||||
|
return op.In[1].OpNameAndType(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op2NameAndType returns the name and type to use for
|
||||||
|
// the 2 operand, if a name is provided, otherwise
|
||||||
|
// the parameter value is used as the default.
|
||||||
|
func (op Operation) Op2NameAndType(s string) string {
|
||||||
|
return op.In[2].OpNameAndType(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op3NameAndType returns the name and type to use for
|
||||||
|
// the 3 operand, if a name is provided, otherwise
|
||||||
|
// the parameter value is used as the default.
|
||||||
|
func (op Operation) Op3NameAndType(s string) string {
|
||||||
|
return op.In[3].OpNameAndType(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op4NameAndType returns the name and type to use for
|
||||||
|
// the 4 operand, if a name is provided, otherwise
|
||||||
|
// the parameter value is used as the default.
|
||||||
|
func (op Operation) Op4NameAndType(s string) string {
|
||||||
|
return op.In[4].OpNameAndType(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
var immClasses []string = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}
|
||||||
|
var classes []string = []string{"BAD0", "op1", "op2", "op3", "op4"}
|
||||||
|
|
||||||
|
// classifyOp returns a classification string, modified operation, and perhaps error based
|
||||||
|
// on the stub and intrinsic shape for the operation.
|
||||||
|
// The classification string is in the regular expression set "op[1234](Imm8)?(_<order>)?"
|
||||||
|
// where the "<order>" suffix is optionally attached to the Operation in its input yaml.
|
||||||
|
// The classification string is used to select a template or a clause of a template
|
||||||
|
// for intrinsics declaration and the ssagen intrinisics glue code in the compiler.
|
||||||
|
func classifyOp(op Operation) (string, Operation, error) {
|
||||||
|
_, _, _, immType, gOp := op.shape()
|
||||||
|
|
||||||
|
var class string
|
||||||
|
|
||||||
|
if immType == VarImm || immType == ConstVarImm {
|
||||||
|
switch l := len(op.In); l {
|
||||||
|
case 1:
|
||||||
|
return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
|
||||||
|
case 2, 3, 4, 5:
|
||||||
|
class = immClasses[l]
|
||||||
|
default:
|
||||||
|
return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
|
||||||
|
}
|
||||||
|
if order := op.OperandOrder; order != nil {
|
||||||
|
class += "_" + *order
|
||||||
|
}
|
||||||
|
return class, op, nil
|
||||||
|
} else {
|
||||||
|
switch l := len(gOp.In); l {
|
||||||
|
case 1, 2, 3, 4:
|
||||||
|
class = classes[l]
|
||||||
|
default:
|
||||||
|
return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
|
||||||
|
}
|
||||||
|
if order := op.OperandOrder; order != nil {
|
||||||
|
class += "_" + *order
|
||||||
|
}
|
||||||
|
return class, gOp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkVecAsScalar(op Operation) (idx int, err error) {
|
||||||
|
idx = -1
|
||||||
|
sSize := 0
|
||||||
|
for i, o := range op.In {
|
||||||
|
if o.TreatLikeAScalarOfSize != nil {
|
||||||
|
if idx == -1 {
|
||||||
|
idx = i
|
||||||
|
sSize = *o.TreatLikeAScalarOfSize
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx >= 0 {
|
||||||
|
if idx != 1 {
|
||||||
|
err = fmt.Errorf("simdgen only supports TreatLikeAScalarOfSize at the 2nd arg of the arg list: %s", op)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
|
||||||
|
err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// dedup is deduping operations in the full structure level.
|
||||||
|
func dedup(ops []Operation) (deduped []Operation) {
|
||||||
|
for _, op := range ops {
|
||||||
|
seen := false
|
||||||
|
for _, dop := range deduped {
|
||||||
|
if reflect.DeepEqual(op, dop) {
|
||||||
|
seen = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !seen {
|
||||||
|
deduped = append(deduped, op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (op Operation) GenericName() string {
|
||||||
|
if op.OperandOrder != nil {
|
||||||
|
switch *op.OperandOrder {
|
||||||
|
case "21Type1", "231Type1":
|
||||||
|
// Permute uses operand[1] for method receiver.
|
||||||
|
return op.Go + *op.In[1].Go
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if op.In[0].Class == "immediate" {
|
||||||
|
return op.Go + *op.In[1].Go
|
||||||
|
}
|
||||||
|
return op.Go + *op.In[0].Go
|
||||||
|
}
|
||||||
|
|
||||||
|
// dedupGodef is deduping operations in [Op.Go]+[*Op.In[0].Go] level.
|
||||||
|
// By deduping, it means picking the least advanced architecture that satisfy the requirement:
|
||||||
|
// AVX512 will be least preferred.
|
||||||
|
// If FlagNoDedup is set, it will report the duplicates to the console.
|
||||||
|
func dedupGodef(ops []Operation) ([]Operation, error) {
|
||||||
|
seen := map[string][]Operation{}
|
||||||
|
for _, op := range ops {
|
||||||
|
_, _, _, _, gOp := op.shape()
|
||||||
|
|
||||||
|
gN := gOp.GenericName()
|
||||||
|
seen[gN] = append(seen[gN], op)
|
||||||
|
}
|
||||||
|
if *FlagReportDup {
|
||||||
|
for gName, dup := range seen {
|
||||||
|
if len(dup) > 1 {
|
||||||
|
log.Printf("Duplicate for %s:\n", gName)
|
||||||
|
for _, op := range dup {
|
||||||
|
log.Printf("%s\n", op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ops, nil
|
||||||
|
}
|
||||||
|
isAVX512 := func(op Operation) bool {
|
||||||
|
return strings.Contains(op.CPUFeature, "AVX512")
|
||||||
|
}
|
||||||
|
deduped := []Operation{}
|
||||||
|
for _, dup := range seen {
|
||||||
|
if len(dup) > 1 {
|
||||||
|
slices.SortFunc(dup, func(i, j Operation) int {
|
||||||
|
// Put non-AVX512 candidates at the beginning
|
||||||
|
if !isAVX512(i) && isAVX512(j) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if isAVX512(i) && !isAVX512(j) {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return strings.Compare(i.CPUFeature, j.CPUFeature)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
deduped = append(deduped, dup[0])
|
||||||
|
}
|
||||||
|
slices.SortFunc(deduped, compareOperations)
|
||||||
|
return deduped, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy op.ConstImm to op.In[0].Const
|
||||||
|
// This is a hack to reduce the size of defs we need for const imm operations.
|
||||||
|
func copyConstImm(ops []Operation) error {
|
||||||
|
for _, op := range ops {
|
||||||
|
if op.ConstImm == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, _, _, immType, _ := op.shape()
|
||||||
|
|
||||||
|
if immType == ConstImm || immType == ConstVarImm {
|
||||||
|
op.In[0].Const = op.ConstImm
|
||||||
|
}
|
||||||
|
// Otherwise, just not port it - e.g. {VPCMP[BWDQ] imm=0} and {VPCMPEQ[BWDQ]} are
|
||||||
|
// the same operations "Equal", [dedupgodef] should be able to distinguish them.
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func capitalizeFirst(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Convert the string to a slice of runes to handle multi-byte characters correctly.
|
||||||
|
r := []rune(s)
|
||||||
|
r[0] = unicode.ToUpper(r[0])
|
||||||
|
return string(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// overwrite corrects some errors due to:
|
||||||
|
// - The XED data is wrong
|
||||||
|
// - Go's SIMD API requirement, for example AVX2 compares should also produce masks.
|
||||||
|
// This rewrite has strict constraints, please see the error message.
|
||||||
|
// These constraints are also explointed in [writeSIMDRules], [writeSIMDMachineOps]
|
||||||
|
// and [writeSIMDSSA], please be careful when updating these constraints.
|
||||||
|
func overwrite(ops []Operation) error {
|
||||||
|
hasClassOverwrite := false
|
||||||
|
overwrite := func(op []Operand, idx int, o Operation) error {
|
||||||
|
if op[idx].OverwriteElementBits != nil {
|
||||||
|
if op[idx].ElemBits == nil {
|
||||||
|
panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
|
||||||
|
}
|
||||||
|
*op[idx].ElemBits = *op[idx].OverwriteElementBits
|
||||||
|
*op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
|
||||||
|
*op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
|
||||||
|
}
|
||||||
|
if op[idx].OverwriteClass != nil {
|
||||||
|
if op[idx].OverwriteBase == nil {
|
||||||
|
panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
|
||||||
|
}
|
||||||
|
oBase := *op[idx].OverwriteBase
|
||||||
|
oClass := *op[idx].OverwriteClass
|
||||||
|
if oClass != "mask" {
|
||||||
|
panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
|
||||||
|
}
|
||||||
|
if oBase != "int" {
|
||||||
|
panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
|
||||||
|
}
|
||||||
|
if op[idx].Class != "vreg" {
|
||||||
|
panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
|
||||||
|
}
|
||||||
|
hasClassOverwrite = true
|
||||||
|
*op[idx].Base = oBase
|
||||||
|
op[idx].Class = oClass
|
||||||
|
*op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
|
||||||
|
} else if op[idx].OverwriteBase != nil {
|
||||||
|
oBase := *op[idx].OverwriteBase
|
||||||
|
*op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
|
||||||
|
if op[idx].Class == "greg" {
|
||||||
|
*op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
|
||||||
|
}
|
||||||
|
*op[idx].Base = oBase
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for i, o := range ops {
|
||||||
|
hasClassOverwrite = false
|
||||||
|
for j := range ops[i].In {
|
||||||
|
if err := overwrite(ops[i].In, j, o); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hasClassOverwrite {
|
||||||
|
return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j := range ops[i].Out {
|
||||||
|
if err := overwrite(ops[i].Out, j, o); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasClassOverwrite {
|
||||||
|
for _, in := range ops[i].In {
|
||||||
|
if in.Class == "mask" {
|
||||||
|
return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reportXEDInconsistency reports potential XED inconsistencies.
|
||||||
|
// We can add more fields to [Operation] to enable more checks and implement it here.
|
||||||
|
// Supported checks:
|
||||||
|
// [NameAndSizeCheck]: NAME[BWDQ] should set the elemBits accordingly.
|
||||||
|
// This check is useful to find inconsistencies, then we can add overwrite fields to
|
||||||
|
// those defs to correct them manually.
|
||||||
|
func reportXEDInconsistency(ops []Operation) error {
|
||||||
|
for _, o := range ops {
|
||||||
|
if o.NameAndSizeCheck != nil {
|
||||||
|
suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
|
||||||
|
checkOperand := func(opr Operand) error {
|
||||||
|
if opr.ElemBits == nil {
|
||||||
|
return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
|
||||||
|
}
|
||||||
|
if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
|
||||||
|
return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
|
||||||
|
} else {
|
||||||
|
if v != *opr.ElemBits {
|
||||||
|
return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, in := range o.In {
|
||||||
|
if in.Class != "vreg" && in.Class != "mask" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if in.TreatLikeAScalarOfSize != nil {
|
||||||
|
// This is an irregular operand, don't check it.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := checkOperand(in); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, out := range o.Out {
|
||||||
|
if err := checkOperand(out); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Operation) String() string {
|
||||||
|
return pprints(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (op Operand) String() string {
|
||||||
|
return pprints(op)
|
||||||
|
}
|
||||||
1
src/simd/_gen/simdgen/go.yaml
Normal file
1
src/simd/_gen/simdgen/go.yaml
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
!import ops/*/go.yaml
|
||||||
379
src/simd/_gen/simdgen/godefs.go
Normal file
379
src/simd/_gen/simdgen/godefs.go
Normal file
|
|
@ -0,0 +1,379 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"simd/_gen/unify"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Operation struct {
|
||||||
|
rawOperation
|
||||||
|
|
||||||
|
// Go is the Go method name of this operation.
|
||||||
|
//
|
||||||
|
// It is derived from the raw Go method name by adding optional suffixes.
|
||||||
|
// Currently, "Masked" is the only suffix.
|
||||||
|
Go string
|
||||||
|
|
||||||
|
// Documentation is the doc string for this API.
|
||||||
|
//
|
||||||
|
// It is computed from the raw documentation:
|
||||||
|
//
|
||||||
|
// - "NAME" is replaced by the Go method name.
|
||||||
|
//
|
||||||
|
// - For masked operation, a sentence about masking is added.
|
||||||
|
Documentation string
|
||||||
|
|
||||||
|
// In is the sequence of parameters to the Go method.
|
||||||
|
//
|
||||||
|
// For masked operations, this will have the mask operand appended.
|
||||||
|
In []Operand
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawOperation is the unifier representation of an [Operation]. It is
|
||||||
|
// translated into a more parsed form after unifier decoding.
|
||||||
|
type rawOperation struct {
|
||||||
|
Go string // Base Go method name
|
||||||
|
|
||||||
|
GoArch string // GOARCH for this definition
|
||||||
|
Asm string // Assembly mnemonic
|
||||||
|
OperandOrder *string // optional Operand order for better Go declarations
|
||||||
|
// Optional tag to indicate this operation is paired with special generic->machine ssa lowering rules.
|
||||||
|
// Should be paired with special templates in gen_simdrules.go
|
||||||
|
SpecialLower *string
|
||||||
|
|
||||||
|
In []Operand // Parameters
|
||||||
|
InVariant []Operand // Optional parameters
|
||||||
|
Out []Operand // Results
|
||||||
|
Commutative bool // Commutativity
|
||||||
|
CPUFeature string // CPUID/Has* feature name
|
||||||
|
Zeroing *bool // nil => use asm suffix ".Z"; false => do not use asm suffix ".Z"
|
||||||
|
Documentation *string // Documentation will be appended to the stubs comments.
|
||||||
|
// ConstMask is a hack to reduce the size of defs the user writes for const-immediate
|
||||||
|
// If present, it will be copied to [In[0].Const].
|
||||||
|
ConstImm *string
|
||||||
|
// NameAndSizeCheck is used to check [BWDQ] maps to (8|16|32|64) elemBits.
|
||||||
|
NameAndSizeCheck *bool
|
||||||
|
// If non-nil, all generation in gen_simdTypes.go and gen_intrinsics will be skipped.
|
||||||
|
NoTypes *string
|
||||||
|
// If non-nil, all generation in gen_simdGenericOps and gen_simdrules will be skipped.
|
||||||
|
NoGenericOps *string
|
||||||
|
// If non-nil, this string will be attached to the machine ssa op name.
|
||||||
|
SSAVariant *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Operation) DecodeUnified(v *unify.Value) error {
|
||||||
|
if err := v.Decode(&o.rawOperation); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isMasked := false
|
||||||
|
if len(o.InVariant) == 0 {
|
||||||
|
// No variant
|
||||||
|
} else if len(o.InVariant) == 1 && o.InVariant[0].Class == "mask" {
|
||||||
|
isMasked = true
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("unknown inVariant")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute full Go method name.
|
||||||
|
o.Go = o.rawOperation.Go
|
||||||
|
if isMasked {
|
||||||
|
o.Go += "Masked"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute doc string.
|
||||||
|
if o.rawOperation.Documentation != nil {
|
||||||
|
o.Documentation = *o.rawOperation.Documentation
|
||||||
|
} else {
|
||||||
|
o.Documentation = "// UNDOCUMENTED"
|
||||||
|
}
|
||||||
|
o.Documentation = regexp.MustCompile(`\bNAME\b`).ReplaceAllString(o.Documentation, o.Go)
|
||||||
|
if isMasked {
|
||||||
|
o.Documentation += "\n//\n// This operation is applied selectively under a write mask."
|
||||||
|
}
|
||||||
|
|
||||||
|
o.In = append(o.rawOperation.In, o.rawOperation.InVariant...)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Operation) VectorWidth() int {
|
||||||
|
out := o.Out[0]
|
||||||
|
if out.Class == "vreg" {
|
||||||
|
return *out.Bits
|
||||||
|
} else if out.Class == "greg" || out.Class == "mask" {
|
||||||
|
for i := range o.In {
|
||||||
|
if o.In[i].Class == "vreg" {
|
||||||
|
return *o.In[i].Bits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic(fmt.Errorf("Figure out what the vector width is for %v and implement it", *o))
|
||||||
|
}
|
||||||
|
|
||||||
|
func machineOpName(maskType maskShape, gOp Operation) string {
|
||||||
|
asm := gOp.Asm
|
||||||
|
if maskType == 2 {
|
||||||
|
asm += "Masked"
|
||||||
|
}
|
||||||
|
asm = fmt.Sprintf("%s%d", asm, gOp.VectorWidth())
|
||||||
|
if gOp.SSAVariant != nil {
|
||||||
|
asm += *gOp.SSAVariant
|
||||||
|
}
|
||||||
|
return asm
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareStringPointers(x, y *string) int {
|
||||||
|
if x != nil && y != nil {
|
||||||
|
return compareNatural(*x, *y)
|
||||||
|
}
|
||||||
|
if x == nil && y == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if x == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareIntPointers(x, y *int) int {
|
||||||
|
if x != nil && y != nil {
|
||||||
|
return *x - *y
|
||||||
|
}
|
||||||
|
if x == nil && y == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if x == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareOperations(x, y Operation) int {
|
||||||
|
if c := compareNatural(x.Go, y.Go); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
xIn, yIn := x.In, y.In
|
||||||
|
|
||||||
|
if len(xIn) > len(yIn) && xIn[len(xIn)-1].Class == "mask" {
|
||||||
|
xIn = xIn[:len(xIn)-1]
|
||||||
|
} else if len(xIn) < len(yIn) && yIn[len(yIn)-1].Class == "mask" {
|
||||||
|
yIn = yIn[:len(yIn)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(xIn) < len(yIn) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if len(xIn) > len(yIn) {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if len(x.Out) < len(y.Out) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if len(x.Out) > len(y.Out) {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
for i := range xIn {
|
||||||
|
ox, oy := &xIn[i], &yIn[i]
|
||||||
|
if c := compareOperands(ox, oy); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareOperands(x, y *Operand) int {
|
||||||
|
if c := compareNatural(x.Class, y.Class); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if x.Class == "immediate" {
|
||||||
|
return compareStringPointers(x.ImmOffset, y.ImmOffset)
|
||||||
|
} else {
|
||||||
|
if c := compareStringPointers(x.Base, y.Base); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := compareIntPointers(x.ElemBits, y.ElemBits); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c := compareIntPointers(x.Bits, y.Bits); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Operand struct {
|
||||||
|
Class string // One of "mask", "immediate", "vreg", "greg", and "mem"
|
||||||
|
|
||||||
|
Go *string // Go type of this operand
|
||||||
|
AsmPos int // Position of this operand in the assembly instruction
|
||||||
|
|
||||||
|
Base *string // Base Go type ("int", "uint", "float")
|
||||||
|
ElemBits *int // Element bit width
|
||||||
|
Bits *int // Total vector bit width
|
||||||
|
|
||||||
|
Const *string // Optional constant value for immediates.
|
||||||
|
// Optional immediate arg offsets. If this field is non-nil,
|
||||||
|
// This operand will be an immediate operand:
|
||||||
|
// The compiler will right-shift the user-passed value by ImmOffset and set it as the AuxInt
|
||||||
|
// field of the operation.
|
||||||
|
ImmOffset *string
|
||||||
|
Name *string // optional name in the Go intrinsic declaration
|
||||||
|
Lanes *int // *Lanes equals Bits/ElemBits except for scalars, when *Lanes == 1
|
||||||
|
// TreatLikeAScalarOfSize means only the lower $TreatLikeAScalarOfSize bits of the vector
|
||||||
|
// is used, so at the API level we can make it just a scalar value of this size; Then we
|
||||||
|
// can overwrite it to a vector of the right size during intrinsics stage.
|
||||||
|
TreatLikeAScalarOfSize *int
|
||||||
|
// If non-nil, it means the [Class] field is overwritten here, right now this is used to
|
||||||
|
// overwrite the results of AVX2 compares to masks.
|
||||||
|
OverwriteClass *string
|
||||||
|
// If non-nil, it means the [Base] field is overwritten here. This field exist solely
|
||||||
|
// because Intel's XED data is inconsistent. e.g. VANDNP[SD] marks its operand int.
|
||||||
|
OverwriteBase *string
|
||||||
|
// If non-nil, it means the [ElementBits] field is overwritten. This field exist solely
|
||||||
|
// because Intel's XED data is inconsistent. e.g. AVX512 VPMADDUBSW marks its operand
|
||||||
|
// elemBits 16, which should be 8.
|
||||||
|
OverwriteElementBits *int
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDigit returns true if the byte is an ASCII digit.
|
||||||
|
func isDigit(b byte) bool {
|
||||||
|
return b >= '0' && b <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareNatural performs a "natural sort" comparison of two strings.
|
||||||
|
// It compares non-digit sections lexicographically and digit sections
|
||||||
|
// numerically. In the case of string-unequal "equal" strings like
|
||||||
|
// "a01b" and "a1b", strings.Compare breaks the tie.
|
||||||
|
//
|
||||||
|
// It returns:
|
||||||
|
//
|
||||||
|
// -1 if s1 < s2
|
||||||
|
// 0 if s1 == s2
|
||||||
|
// +1 if s1 > s2
|
||||||
|
func compareNatural(s1, s2 string) int {
|
||||||
|
i, j := 0, 0
|
||||||
|
len1, len2 := len(s1), len(s2)
|
||||||
|
|
||||||
|
for i < len1 && j < len2 {
|
||||||
|
// Find a non-digit segment or a number segment in both strings.
|
||||||
|
if isDigit(s1[i]) && isDigit(s2[j]) {
|
||||||
|
// Number segment comparison.
|
||||||
|
numStart1 := i
|
||||||
|
for i < len1 && isDigit(s1[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
num1, _ := strconv.Atoi(s1[numStart1:i])
|
||||||
|
|
||||||
|
numStart2 := j
|
||||||
|
for j < len2 && isDigit(s2[j]) {
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
num2, _ := strconv.Atoi(s2[numStart2:j])
|
||||||
|
|
||||||
|
if num1 < num2 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if num1 > num2 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
// If numbers are equal, continue to the next segment.
|
||||||
|
} else {
|
||||||
|
// Non-digit comparison.
|
||||||
|
if s1[i] < s2[j] {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if s1[i] > s2[j] {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// deal with a01b vs a1b; there needs to be an order.
|
||||||
|
return strings.Compare(s1, s2)
|
||||||
|
}
|
||||||
|
|
||||||
|
const generatedHeader = `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
|
||||||
|
`
|
||||||
|
|
||||||
|
func writeGoDefs(path string, cl unify.Closure) error {
|
||||||
|
// TODO: Merge operations with the same signature but multiple
|
||||||
|
// implementations (e.g., SSE vs AVX)
|
||||||
|
var ops []Operation
|
||||||
|
for def := range cl.All() {
|
||||||
|
var op Operation
|
||||||
|
if !def.Exact() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := def.Decode(&op); err != nil {
|
||||||
|
log.Println(err.Error())
|
||||||
|
log.Println(def)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// TODO: verify that this is safe.
|
||||||
|
op.sortOperand()
|
||||||
|
ops = append(ops, op)
|
||||||
|
}
|
||||||
|
slices.SortFunc(ops, compareOperations)
|
||||||
|
// The parsed XED data might contain duplicates, like
|
||||||
|
// 512 bits VPADDP.
|
||||||
|
deduped := dedup(ops)
|
||||||
|
slices.SortFunc(deduped, compareOperations)
|
||||||
|
|
||||||
|
if *Verbose {
|
||||||
|
log.Printf("dedup len: %d\n", len(ops))
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if err = overwrite(deduped); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if *Verbose {
|
||||||
|
log.Printf("dedup len: %d\n", len(deduped))
|
||||||
|
}
|
||||||
|
if *Verbose {
|
||||||
|
log.Printf("dedup len: %d\n", len(deduped))
|
||||||
|
}
|
||||||
|
if !*FlagNoDedup {
|
||||||
|
// TODO: This can hide mistakes in the API definitions, especially when
|
||||||
|
// multiple patterns result in the same API unintentionally. Make it stricter.
|
||||||
|
if deduped, err = dedupGodef(deduped); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if *Verbose {
|
||||||
|
log.Printf("dedup len: %d\n", len(deduped))
|
||||||
|
}
|
||||||
|
if !*FlagNoConstImmPorting {
|
||||||
|
if err = copyConstImm(deduped); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if *Verbose {
|
||||||
|
log.Printf("dedup len: %d\n", len(deduped))
|
||||||
|
}
|
||||||
|
reportXEDInconsistency(deduped)
|
||||||
|
typeMap := parseSIMDTypes(deduped)
|
||||||
|
|
||||||
|
formatWriteAndClose(writeSIMDTypes(typeMap), path, "src/"+simdPackage+"/types_amd64.go")
|
||||||
|
formatWriteAndClose(writeSIMDFeatures(deduped), path, "src/"+simdPackage+"/cpu.go")
|
||||||
|
formatWriteAndClose(writeSIMDStubs(deduped, typeMap), path, "src/"+simdPackage+"/ops_amd64.go")
|
||||||
|
formatWriteAndClose(writeSIMDIntrinsics(deduped, typeMap), path, "src/cmd/compile/internal/ssagen/simdintrinsics.go")
|
||||||
|
formatWriteAndClose(writeSIMDGenericOps(deduped), path, "src/cmd/compile/internal/ssa/_gen/simdgenericOps.go")
|
||||||
|
formatWriteAndClose(writeSIMDMachineOps(deduped), path, "src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go")
|
||||||
|
formatWriteAndClose(writeSIMDSSA(deduped), path, "src/cmd/compile/internal/amd64/simdssa.go")
|
||||||
|
writeAndClose(writeSIMDRules(deduped).Bytes(), path, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
280
src/simd/_gen/simdgen/main.go
Normal file
280
src/simd/_gen/simdgen/main.go
Normal file
|
|
@ -0,0 +1,280 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// simdgen is an experiment in generating Go <-> asm SIMD mappings.
|
||||||
|
//
|
||||||
|
// Usage: simdgen [-xedPath=path] [-q=query] input.yaml...
|
||||||
|
//
|
||||||
|
// If -xedPath is provided, one of the inputs is a sum of op-code definitions
|
||||||
|
// generated from the Intel XED data at path.
|
||||||
|
//
|
||||||
|
// If input YAML files are provided, each file is read as an input value. See
|
||||||
|
// [unify.Closure.UnmarshalYAML] or "go doc unify.Closure.UnmarshalYAML" for the
|
||||||
|
// format of these files.
|
||||||
|
//
|
||||||
|
// TODO: Example definitions and values.
|
||||||
|
//
|
||||||
|
// The command unifies across all of the inputs and prints all possible results
|
||||||
|
// of this unification.
|
||||||
|
//
|
||||||
|
// If the -q flag is provided, its string value is parsed as a value and treated
|
||||||
|
// as another input to unification. This is intended as a way to "query" the
|
||||||
|
// result, typically by narrowing it down to a small subset of results.
|
||||||
|
//
|
||||||
|
// Typical usage:
|
||||||
|
//
|
||||||
|
// go run . -xedPath $XEDPATH *.yaml
|
||||||
|
//
|
||||||
|
// To see just the definitions generated from XED, run:
|
||||||
|
//
|
||||||
|
// go run . -xedPath $XEDPATH
|
||||||
|
//
|
||||||
|
// (This works because if there's only one input, there's nothing to unify it
|
||||||
|
// with, so the result is simply itself.)
|
||||||
|
//
|
||||||
|
// To see just the definitions for VPADDQ:
|
||||||
|
//
|
||||||
|
// go run . -xedPath $XEDPATH -q '{asm: VPADDQ}'
|
||||||
|
//
|
||||||
|
// simdgen can also generate Go definitions of SIMD mappings:
|
||||||
|
// To generate go files to the go root, run:
|
||||||
|
//
|
||||||
|
// go run . -xedPath $XEDPATH -o godefs -goroot $PATH/TO/go go.yaml categories.yaml types.yaml
|
||||||
|
//
|
||||||
|
// types.yaml is already written, it specifies the shapes of vectors.
|
||||||
|
// categories.yaml and go.yaml contains definitions that unifies with types.yaml and XED
|
||||||
|
// data, you can find an example in ops/AddSub/.
|
||||||
|
//
|
||||||
|
// When generating Go definitions, simdgen do 3 "magic"s:
|
||||||
|
// - It splits masked operations(with op's [Masked] field set) to const and non const:
|
||||||
|
// - One is a normal masked operation, the original
|
||||||
|
// - The other has its mask operand's [Const] fields set to "K0".
|
||||||
|
// - This way the user does not need to provide a separate "K0"-masked operation def.
|
||||||
|
//
|
||||||
|
// - It deduplicates intrinsic names that have duplicates:
|
||||||
|
// - If there are two operations that shares the same signature, one is AVX512 the other
|
||||||
|
// is before AVX512, the other will be selected.
|
||||||
|
// - This happens often when some operations are defined both before AVX512 and after.
|
||||||
|
// This way the user does not need to provide a separate "K0" operation for the
|
||||||
|
// AVX512 counterpart.
|
||||||
|
//
|
||||||
|
// - It copies the op's [ConstImm] field to its immediate operand's [Const] field.
|
||||||
|
// - This way the user does not need to provide verbose op definition while only
|
||||||
|
// the const immediate field is different. This is useful to reduce verbosity of
|
||||||
|
// compares with imm control predicates.
|
||||||
|
//
|
||||||
|
// These 3 magics could be disabled by enabling -nosplitmask, -nodedup or
|
||||||
|
// -noconstimmporting flags.
|
||||||
|
//
|
||||||
|
// simdgen right now only supports amd64, -arch=$OTHERARCH will trigger a fatal error.
|
||||||
|
package main
|
||||||
|
|
||||||
|
// Big TODOs:
|
||||||
|
//
|
||||||
|
// - This can produce duplicates, which can also lead to less efficient
|
||||||
|
// environment merging. Add hashing and use it for deduplication. Be careful
|
||||||
|
// about how this shows up in debug traces, since it could make things
|
||||||
|
// confusing if we don't show it happening.
|
||||||
|
//
|
||||||
|
// - Do I need Closure, Value, and Domain? It feels like I should only need two
|
||||||
|
// types.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime/pprof"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
"simd/_gen/unify"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
xedPath = flag.String("xedPath", "", "load XED datafiles from `path`")
|
||||||
|
flagQ = flag.String("q", "", "query: read `def` as another input (skips final validation)")
|
||||||
|
flagO = flag.String("o", "yaml", "output type: yaml, godefs (generate definitions into a Go source tree")
|
||||||
|
flagGoDefRoot = flag.String("goroot", ".", "the path to the Go dev directory that will receive the generated files")
|
||||||
|
FlagNoDedup = flag.Bool("nodedup", false, "disable deduplicating godefs of 2 qualifying operations from different extensions")
|
||||||
|
FlagNoConstImmPorting = flag.Bool("noconstimmporting", false, "disable const immediate porting from op to imm operand")
|
||||||
|
FlagArch = flag.String("arch", "amd64", "the target architecture")
|
||||||
|
|
||||||
|
Verbose = flag.Bool("v", false, "verbose")
|
||||||
|
|
||||||
|
flagDebugXED = flag.Bool("debug-xed", false, "show XED instructions")
|
||||||
|
flagDebugUnify = flag.Bool("debug-unify", false, "print unification trace")
|
||||||
|
flagDebugHTML = flag.String("debug-html", "", "write unification trace to `file.html`")
|
||||||
|
FlagReportDup = flag.Bool("reportdup", false, "report the duplicate godefs")
|
||||||
|
|
||||||
|
flagCPUProfile = flag.String("cpuprofile", "", "write CPU profile to `file`")
|
||||||
|
flagMemProfile = flag.String("memprofile", "", "write memory profile to `file`")
|
||||||
|
)
|
||||||
|
|
||||||
|
const simdPackage = "simd"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *flagCPUProfile != "" {
|
||||||
|
f, err := os.Create(*flagCPUProfile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("-cpuprofile: %s", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
pprof.StartCPUProfile(f)
|
||||||
|
defer pprof.StopCPUProfile()
|
||||||
|
}
|
||||||
|
if *flagMemProfile != "" {
|
||||||
|
f, err := os.Create(*flagMemProfile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("-memprofile: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
pprof.WriteHeapProfile(f)
|
||||||
|
f.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputs []unify.Closure
|
||||||
|
|
||||||
|
if *FlagArch != "amd64" {
|
||||||
|
log.Fatalf("simdgen only supports amd64")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load XED into a defs set.
|
||||||
|
if *xedPath != "" {
|
||||||
|
xedDefs := loadXED(*xedPath)
|
||||||
|
inputs = append(inputs, unify.NewSum(xedDefs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load query.
|
||||||
|
if *flagQ != "" {
|
||||||
|
r := strings.NewReader(*flagQ)
|
||||||
|
def, err := unify.Read(r, "<query>", unify.ReadOpts{})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("parsing -q: %s", err)
|
||||||
|
}
|
||||||
|
inputs = append(inputs, def)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load defs files.
|
||||||
|
must := make(map[*unify.Value]struct{})
|
||||||
|
for _, path := range flag.Args() {
|
||||||
|
defs, err := unify.ReadFile(path, unify.ReadOpts{})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
inputs = append(inputs, defs)
|
||||||
|
|
||||||
|
if filepath.Base(path) == "go.yaml" {
|
||||||
|
// These must all be used in the final result
|
||||||
|
for def := range defs.Summands() {
|
||||||
|
must[def] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare for unification
|
||||||
|
if *flagDebugUnify {
|
||||||
|
unify.Debug.UnifyLog = os.Stderr
|
||||||
|
}
|
||||||
|
if *flagDebugHTML != "" {
|
||||||
|
f, err := os.Create(*flagDebugHTML)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
unify.Debug.HTML = f
|
||||||
|
defer f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unify!
|
||||||
|
unified, err := unify.Unify(inputs...)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print results.
|
||||||
|
switch *flagO {
|
||||||
|
case "yaml":
|
||||||
|
// Produce a result that looks like encoding a slice, but stream it.
|
||||||
|
fmt.Println("!sum")
|
||||||
|
var val1 [1]*unify.Value
|
||||||
|
for val := range unified.All() {
|
||||||
|
val1[0] = val
|
||||||
|
// We have to make a new encoder each time or it'll print a document
|
||||||
|
// separator between each object.
|
||||||
|
enc := yaml.NewEncoder(os.Stdout)
|
||||||
|
if err := enc.Encode(val1); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
enc.Close()
|
||||||
|
}
|
||||||
|
case "godefs":
|
||||||
|
if err := writeGoDefs(*flagGoDefRoot, unified); err != nil {
|
||||||
|
log.Fatalf("Failed writing godefs: %+v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !*Verbose && *xedPath != "" {
|
||||||
|
if operandRemarks == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "XED decoding generated no errors, which is unusual.\n")
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "XED decoding generated %d \"errors\" which is not cause for alarm, use -v for details.\n", operandRemarks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate results.
|
||||||
|
//
|
||||||
|
// Don't validate if this is a command-line query because that tends to
|
||||||
|
// eliminate lots of required defs and is used in cases where maybe defs
|
||||||
|
// aren't enumerable anyway.
|
||||||
|
if *flagQ == "" && len(must) > 0 {
|
||||||
|
validate(unified, must)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validate(cl unify.Closure, required map[*unify.Value]struct{}) {
|
||||||
|
// Validate that:
|
||||||
|
// 1. All final defs are exact
|
||||||
|
// 2. All required defs are used
|
||||||
|
for def := range cl.All() {
|
||||||
|
if _, ok := def.Domain.(unify.Def); !ok {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s: expected Def, got %T\n", def.PosString(), def.Domain)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !def.Exact() {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s: def not reduced to an exact value, why is %s:\n", def.PosString(), def.WhyNotExact())
|
||||||
|
fmt.Fprintf(os.Stderr, "\t%s\n", strings.ReplaceAll(def.String(), "\n", "\n\t"))
|
||||||
|
}
|
||||||
|
|
||||||
|
for root := range def.Provenance() {
|
||||||
|
delete(required, root)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Report unused defs
|
||||||
|
unused := slices.SortedFunc(maps.Keys(required),
|
||||||
|
func(a, b *unify.Value) int {
|
||||||
|
return cmp.Or(
|
||||||
|
cmp.Compare(a.Pos().Path, b.Pos().Path),
|
||||||
|
cmp.Compare(a.Pos().Line, b.Pos().Line),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
for _, def := range unused {
|
||||||
|
// TODO: Can we say anything more actionable? This is always a problem
|
||||||
|
// with unification: if it fails, it's very hard to point a finger at
|
||||||
|
// any particular reason. We could go back and try unifying this again
|
||||||
|
// with each subset of the inputs (starting with individual inputs) to
|
||||||
|
// at least say "it doesn't unify with anything in x.yaml". That's a lot
|
||||||
|
// of work, but if we have trouble debugging unification failure it may
|
||||||
|
// be worth it.
|
||||||
|
fmt.Fprintf(os.Stderr, "%s: def required, but did not unify (%v)\n",
|
||||||
|
def.PosString(), def)
|
||||||
|
}
|
||||||
|
}
|
||||||
37
src/simd/_gen/simdgen/ops/AddSub/categories.yaml
Normal file
37
src/simd/_gen/simdgen/ops/AddSub/categories.yaml
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
!sum
|
||||||
|
- go: Add
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME adds corresponding elements of two vectors.
|
||||||
|
- go: AddSaturated
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME adds corresponding elements of two vectors with saturation.
|
||||||
|
- go: Sub
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME subtracts corresponding elements of two vectors.
|
||||||
|
- go: SubSaturated
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME subtracts corresponding elements of two vectors with saturation.
|
||||||
|
- go: AddPairs
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME horizontally adds adjacent pairs of elements.
|
||||||
|
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0+y1, y2+y3, ..., x0+x1, x2+x3, ...].
|
||||||
|
- go: SubPairs
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME horizontally subtracts adjacent pairs of elements.
|
||||||
|
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0-y1, y2-y3, ..., x0-x1, x2-x3, ...].
|
||||||
|
- go: AddPairsSaturated
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME horizontally adds adjacent pairs of elements with saturation.
|
||||||
|
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0+y1, y2+y3, ..., x0+x1, x2+x3, ...].
|
||||||
|
- go: SubPairsSaturated
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME horizontally subtracts adjacent pairs of elements with saturation.
|
||||||
|
// For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0-y1, y2-y3, ..., x0-x1, x2-x3, ...].
|
||||||
77
src/simd/_gen/simdgen/ops/AddSub/go.yaml
Normal file
77
src/simd/_gen/simdgen/ops/AddSub/go.yaml
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
!sum
|
||||||
|
# Add
|
||||||
|
- go: Add
|
||||||
|
asm: "VPADD[BWDQ]|VADDP[SD]"
|
||||||
|
in:
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
# Add Saturated
|
||||||
|
- go: AddSaturated
|
||||||
|
asm: "VPADDS[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
out:
|
||||||
|
- *int
|
||||||
|
- go: AddSaturated
|
||||||
|
asm: "VPADDUS[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &uint
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *uint
|
||||||
|
out:
|
||||||
|
- *uint
|
||||||
|
|
||||||
|
# Sub
|
||||||
|
- go: Sub
|
||||||
|
asm: "VPSUB[BWDQ]|VSUBP[SD]"
|
||||||
|
in: &2any
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out: &1any
|
||||||
|
- *any
|
||||||
|
# Sub Saturated
|
||||||
|
- go: SubSaturated
|
||||||
|
asm: "VPSUBS[BWDQ]"
|
||||||
|
in: &2int
|
||||||
|
- *int
|
||||||
|
- *int
|
||||||
|
out: &1int
|
||||||
|
- *int
|
||||||
|
- go: SubSaturated
|
||||||
|
asm: "VPSUBUS[BWDQ]"
|
||||||
|
in:
|
||||||
|
- *uint
|
||||||
|
- *uint
|
||||||
|
out:
|
||||||
|
- *uint
|
||||||
|
- go: AddPairs
|
||||||
|
asm: "VPHADD[DW]"
|
||||||
|
in: *2any
|
||||||
|
out: *1any
|
||||||
|
- go: SubPairs
|
||||||
|
asm: "VPHSUB[DW]"
|
||||||
|
in: *2any
|
||||||
|
out: *1any
|
||||||
|
- go: AddPairs
|
||||||
|
asm: "VHADDP[SD]" # floats
|
||||||
|
in: *2any
|
||||||
|
out: *1any
|
||||||
|
- go: SubPairs
|
||||||
|
asm: "VHSUBP[SD]" # floats
|
||||||
|
in: *2any
|
||||||
|
out: *1any
|
||||||
|
- go: AddPairsSaturated
|
||||||
|
asm: "VPHADDS[DW]"
|
||||||
|
in: *2int
|
||||||
|
out: *1int
|
||||||
|
- go: SubPairsSaturated
|
||||||
|
asm: "VPHSUBS[DW]"
|
||||||
|
in: *2int
|
||||||
|
out: *1int
|
||||||
20
src/simd/_gen/simdgen/ops/BitwiseLogic/categories.yaml
Normal file
20
src/simd/_gen/simdgen/ops/BitwiseLogic/categories.yaml
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
!sum
|
||||||
|
- go: And
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a bitwise AND operation between two vectors.
|
||||||
|
- go: Or
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a bitwise OR operation between two vectors.
|
||||||
|
- go: AndNot
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a bitwise x &^ y.
|
||||||
|
- go: Xor
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a bitwise XOR operation between two vectors.
|
||||||
|
|
||||||
|
# We also have PTEST and VPTERNLOG, those should be hidden from the users
|
||||||
|
# and only appear in rewrite rules.
|
||||||
128
src/simd/_gen/simdgen/ops/BitwiseLogic/go.yaml
Normal file
128
src/simd/_gen/simdgen/ops/BitwiseLogic/go.yaml
Normal file
|
|
@ -0,0 +1,128 @@
|
||||||
|
!sum
|
||||||
|
# In the XED data, *all* floating point bitwise logic operation has their
|
||||||
|
# operand type marked as uint. We are not trying to understand why Intel
|
||||||
|
# decided that they want FP bit-wise logic operations, but this irregularity
|
||||||
|
# has to be dealed with in separate rules with some overwrites.
|
||||||
|
|
||||||
|
# For many bit-wise operations, we have the following non-orthogonal
|
||||||
|
# choices:
|
||||||
|
#
|
||||||
|
# - Non-masked AVX operations have no element width (because it
|
||||||
|
# doesn't matter), but only cover 128 and 256 bit vectors.
|
||||||
|
#
|
||||||
|
# - Masked AVX-512 operations have an element width (because it needs
|
||||||
|
# to know how to interpret the mask), and cover 128, 256, and 512 bit
|
||||||
|
# vectors. These only cover 32- and 64-bit element widths.
|
||||||
|
#
|
||||||
|
# - Non-masked AVX-512 operations still have an element width (because
|
||||||
|
# they're just the masked operations with an implicit K0 mask) but it
|
||||||
|
# doesn't matter! This is the only option for non-masked 512 bit
|
||||||
|
# operations, and we can pick any of the element widths.
|
||||||
|
#
|
||||||
|
# We unify with ALL of these operations and the compiler generator
|
||||||
|
# picks when there are multiple options.
|
||||||
|
|
||||||
|
# TODO: We don't currently generate unmasked bit-wise operations on 512 bit
|
||||||
|
# vectors of 8- or 16-bit elements. AVX-512 only has *masked* bit-wise
|
||||||
|
# operations for 32- and 64-bit elements; while the element width doesn't matter
|
||||||
|
# for unmasked operations, right now we don't realize that we can just use the
|
||||||
|
# 32- or 64-bit version for the unmasked form. Maybe in the XED decoder we
|
||||||
|
# should recognize bit-wise operations when generating unmasked versions and
|
||||||
|
# omit the element width.
|
||||||
|
|
||||||
|
# For binary operations, we constrain their two inputs and one output to the
|
||||||
|
# same Go type using a variable.
|
||||||
|
|
||||||
|
- go: And
|
||||||
|
asm: "VPAND[DQ]?"
|
||||||
|
in:
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: And
|
||||||
|
asm: "VPANDD" # Fill in the gap, And is missing for Uint8x64 and Int8x64
|
||||||
|
inVariant: []
|
||||||
|
in: &twoI8x64
|
||||||
|
- &i8x64
|
||||||
|
go: $t
|
||||||
|
overwriteElementBits: 8
|
||||||
|
- *i8x64
|
||||||
|
out: &oneI8x64
|
||||||
|
- *i8x64
|
||||||
|
|
||||||
|
- go: And
|
||||||
|
asm: "VPANDD" # Fill in the gap, And is missing for Uint16x32 and Int16x32
|
||||||
|
inVariant: []
|
||||||
|
in: &twoI16x32
|
||||||
|
- &i16x32
|
||||||
|
go: $t
|
||||||
|
overwriteElementBits: 16
|
||||||
|
- *i16x32
|
||||||
|
out: &oneI16x32
|
||||||
|
- *i16x32
|
||||||
|
|
||||||
|
- go: AndNot
|
||||||
|
asm: "VPANDN[DQ]?"
|
||||||
|
operandOrder: "21" # switch the arg order
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: AndNot
|
||||||
|
asm: "VPANDND" # Fill in the gap, AndNot is missing for Uint8x64 and Int8x64
|
||||||
|
operandOrder: "21" # switch the arg order
|
||||||
|
inVariant: []
|
||||||
|
in: *twoI8x64
|
||||||
|
out: *oneI8x64
|
||||||
|
|
||||||
|
- go: AndNot
|
||||||
|
asm: "VPANDND" # Fill in the gap, AndNot is missing for Uint16x32 and Int16x32
|
||||||
|
operandOrder: "21" # switch the arg order
|
||||||
|
inVariant: []
|
||||||
|
in: *twoI16x32
|
||||||
|
out: *oneI16x32
|
||||||
|
|
||||||
|
- go: Or
|
||||||
|
asm: "VPOR[DQ]?"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: Or
|
||||||
|
asm: "VPORD" # Fill in the gap, Or is missing for Uint8x64 and Int8x64
|
||||||
|
inVariant: []
|
||||||
|
in: *twoI8x64
|
||||||
|
out: *oneI8x64
|
||||||
|
|
||||||
|
- go: Or
|
||||||
|
asm: "VPORD" # Fill in the gap, Or is missing for Uint16x32 and Int16x32
|
||||||
|
inVariant: []
|
||||||
|
in: *twoI16x32
|
||||||
|
out: *oneI16x32
|
||||||
|
|
||||||
|
- go: Xor
|
||||||
|
asm: "VPXOR[DQ]?"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: Xor
|
||||||
|
asm: "VPXORD" # Fill in the gap, Or is missing for Uint8x64 and Int8x64
|
||||||
|
inVariant: []
|
||||||
|
in: *twoI8x64
|
||||||
|
out: *oneI8x64
|
||||||
|
|
||||||
|
- go: Xor
|
||||||
|
asm: "VPXORD" # Fill in the gap, Or is missing for Uint16x32 and Int16x32
|
||||||
|
inVariant: []
|
||||||
|
in: *twoI16x32
|
||||||
|
out: *oneI16x32
|
||||||
43
src/simd/_gen/simdgen/ops/Compares/categories.yaml
Normal file
43
src/simd/_gen/simdgen/ops/Compares/categories.yaml
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
!sum
|
||||||
|
# const imm predicate(holds for both float and int|uint):
|
||||||
|
# 0: Equal
|
||||||
|
# 1: Less
|
||||||
|
# 2: LessEqual
|
||||||
|
# 4: NotEqual
|
||||||
|
# 5: GreaterEqual
|
||||||
|
# 6: Greater
|
||||||
|
- go: Equal
|
||||||
|
constImm: 0
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME compares for equality.
|
||||||
|
- go: Less
|
||||||
|
constImm: 1
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME compares for less than.
|
||||||
|
- go: LessEqual
|
||||||
|
constImm: 2
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME compares for less than or equal.
|
||||||
|
- go: IsNan # For float only.
|
||||||
|
constImm: 3
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME checks if elements are NaN. Use as x.IsNan(x).
|
||||||
|
- go: NotEqual
|
||||||
|
constImm: 4
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME compares for inequality.
|
||||||
|
- go: GreaterEqual
|
||||||
|
constImm: 13
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME compares for greater than or equal.
|
||||||
|
- go: Greater
|
||||||
|
constImm: 14
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME compares for greater than.
|
||||||
141
src/simd/_gen/simdgen/ops/Compares/go.yaml
Normal file
141
src/simd/_gen/simdgen/ops/Compares/go.yaml
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
!sum
|
||||||
|
# Ints
|
||||||
|
- go: Equal
|
||||||
|
asm: "V?PCMPEQ[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- &anyvregToMask
|
||||||
|
go: $t
|
||||||
|
overwriteBase: int
|
||||||
|
overwriteClass: mask
|
||||||
|
- go: Greater
|
||||||
|
asm: "V?PCMPGT[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
out:
|
||||||
|
- *anyvregToMask
|
||||||
|
# 256-bit VCMPGTQ's output elemBits is marked 32-bit in the XED data, we
|
||||||
|
# believe this is an error, so add this definition to overwrite.
|
||||||
|
- go: Greater
|
||||||
|
asm: "VPCMPGTQ"
|
||||||
|
in:
|
||||||
|
- &int64
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
elemBits: 64
|
||||||
|
- *int64
|
||||||
|
out:
|
||||||
|
- base: int
|
||||||
|
elemBits: 32
|
||||||
|
overwriteElementBits: 64
|
||||||
|
overwriteClass: mask
|
||||||
|
overwriteBase: int
|
||||||
|
|
||||||
|
# TODO these are redundant with VPCMP operations.
|
||||||
|
# AVX-512 compares produce masks.
|
||||||
|
- go: Equal
|
||||||
|
asm: "V?PCMPEQ[BWDQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
|
- go: Greater
|
||||||
|
asm: "V?PCMPGT[BWDQ]"
|
||||||
|
in:
|
||||||
|
- *int
|
||||||
|
- *int
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
|
|
||||||
|
# MASKED signed comparisons for X/Y registers
|
||||||
|
# unmasked would clash with emulations on AVX2
|
||||||
|
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
|
||||||
|
asm: "VPCMP[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
bits: (128|256)
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
- class: immediate
|
||||||
|
const: 0 # Just a placeholder, will be overwritten by const imm porting.
|
||||||
|
inVariant:
|
||||||
|
- class: mask
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
|
|
||||||
|
# MASKED unsigned comparisons for X/Y registers
|
||||||
|
# unmasked would clash with emulations on AVX2
|
||||||
|
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
|
||||||
|
asm: "VPCMPU[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &uint
|
||||||
|
bits: (128|256)
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *uint
|
||||||
|
- class: immediate
|
||||||
|
const: 0
|
||||||
|
inVariant:
|
||||||
|
- class: mask
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
|
|
||||||
|
# masked/unmasked signed comparisons for Z registers
|
||||||
|
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
|
||||||
|
asm: "VPCMP[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
bits: 512
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
- class: immediate
|
||||||
|
const: 0 # Just a placeholder, will be overwritten by const imm porting.
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
|
|
||||||
|
# masked/unmasked unsigned comparisons for Z registers
|
||||||
|
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
|
||||||
|
asm: "VPCMPU[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &uint
|
||||||
|
bits: 512
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *uint
|
||||||
|
- class: immediate
|
||||||
|
const: 0
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
|
|
||||||
|
# Floats
|
||||||
|
- go: Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual|IsNan
|
||||||
|
asm: "VCMPP[SD]"
|
||||||
|
in:
|
||||||
|
- &float
|
||||||
|
go: $t
|
||||||
|
base: float
|
||||||
|
- *float
|
||||||
|
- class: immediate
|
||||||
|
const: 0
|
||||||
|
out:
|
||||||
|
- go: $t
|
||||||
|
overwriteBase: int
|
||||||
|
overwriteClass: mask
|
||||||
|
- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual|IsNan)
|
||||||
|
asm: "VCMPP[SD]"
|
||||||
|
in:
|
||||||
|
- *float
|
||||||
|
- *float
|
||||||
|
- class: immediate
|
||||||
|
const: 0
|
||||||
|
out:
|
||||||
|
- class: mask
|
||||||
10
src/simd/_gen/simdgen/ops/Converts/categories.yaml
Normal file
10
src/simd/_gen/simdgen/ops/Converts/categories.yaml
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
!sum
|
||||||
|
- go: ConvertToInt32
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// ConvertToInt32 converts element values to int32.
|
||||||
|
|
||||||
|
- go: ConvertToUint32
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// ConvertToUint32Masked converts element values to uint32.
|
||||||
21
src/simd/_gen/simdgen/ops/Converts/go.yaml
Normal file
21
src/simd/_gen/simdgen/ops/Converts/go.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
!sum
|
||||||
|
- go: ConvertToInt32
|
||||||
|
asm: "VCVTTPS2DQ"
|
||||||
|
in:
|
||||||
|
- &fp
|
||||||
|
go: $t
|
||||||
|
base: float
|
||||||
|
out:
|
||||||
|
- &i32
|
||||||
|
go: $u
|
||||||
|
base: int
|
||||||
|
elemBits: 32
|
||||||
|
- go: ConvertToUint32
|
||||||
|
asm: "VCVTPS2UDQ"
|
||||||
|
in:
|
||||||
|
- *fp
|
||||||
|
out:
|
||||||
|
- &u32
|
||||||
|
go: $u
|
||||||
|
base: uint
|
||||||
|
elemBits: 32
|
||||||
85
src/simd/_gen/simdgen/ops/FPonlyArith/categories.yaml
Normal file
85
src/simd/_gen/simdgen/ops/FPonlyArith/categories.yaml
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
!sum
|
||||||
|
- go: Div
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME divides elements of two vectors.
|
||||||
|
- go: Sqrt
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the square root of each element.
|
||||||
|
- go: Reciprocal
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes an approximate reciprocal of each element.
|
||||||
|
- go: ReciprocalSqrt
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes an approximate reciprocal of the square root of each element.
|
||||||
|
- go: Scale
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies elements by a power of 2.
|
||||||
|
- go: RoundToEven
|
||||||
|
commutative: false
|
||||||
|
constImm: 0
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rounds elements to the nearest integer.
|
||||||
|
- go: RoundToEvenScaled
|
||||||
|
commutative: false
|
||||||
|
constImm: 0
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rounds elements with specified precision.
|
||||||
|
- go: RoundToEvenScaledResidue
|
||||||
|
commutative: false
|
||||||
|
constImm: 0
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the difference after rounding with specified precision.
|
||||||
|
- go: Floor
|
||||||
|
commutative: false
|
||||||
|
constImm: 1
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rounds elements down to the nearest integer.
|
||||||
|
- go: FloorScaled
|
||||||
|
commutative: false
|
||||||
|
constImm: 1
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rounds elements down with specified precision.
|
||||||
|
- go: FloorScaledResidue
|
||||||
|
commutative: false
|
||||||
|
constImm: 1
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the difference after flooring with specified precision.
|
||||||
|
- go: Ceil
|
||||||
|
commutative: false
|
||||||
|
constImm: 2
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rounds elements up to the nearest integer.
|
||||||
|
- go: CeilScaled
|
||||||
|
commutative: false
|
||||||
|
constImm: 2
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rounds elements up with specified precision.
|
||||||
|
- go: CeilScaledResidue
|
||||||
|
commutative: false
|
||||||
|
constImm: 2
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the difference after ceiling with specified precision.
|
||||||
|
- go: Trunc
|
||||||
|
commutative: false
|
||||||
|
constImm: 3
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME truncates elements towards zero.
|
||||||
|
- go: TruncScaled
|
||||||
|
commutative: false
|
||||||
|
constImm: 3
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME truncates elements with specified precision.
|
||||||
|
- go: TruncScaledResidue
|
||||||
|
commutative: false
|
||||||
|
constImm: 3
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the difference after truncating with specified precision.
|
||||||
|
- go: AddSub
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME subtracts even elements and adds odd elements of two vectors.
|
||||||
62
src/simd/_gen/simdgen/ops/FPonlyArith/go.yaml
Normal file
62
src/simd/_gen/simdgen/ops/FPonlyArith/go.yaml
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
!sum
|
||||||
|
- go: Div
|
||||||
|
asm: "V?DIVP[SD]"
|
||||||
|
in: &2fp
|
||||||
|
- &fp
|
||||||
|
go: $t
|
||||||
|
base: float
|
||||||
|
- *fp
|
||||||
|
out: &1fp
|
||||||
|
- *fp
|
||||||
|
- go: Sqrt
|
||||||
|
asm: "V?SQRTP[SD]"
|
||||||
|
in: *1fp
|
||||||
|
out: *1fp
|
||||||
|
# TODO: Provide separate methods for 12-bit precision and 14-bit precision?
|
||||||
|
- go: Reciprocal
|
||||||
|
asm: "VRCP(14)?P[SD]"
|
||||||
|
in: *1fp
|
||||||
|
out: *1fp
|
||||||
|
- go: ReciprocalSqrt
|
||||||
|
asm: "V?RSQRT(14)?P[SD]"
|
||||||
|
in: *1fp
|
||||||
|
out: *1fp
|
||||||
|
- go: Scale
|
||||||
|
asm: "VSCALEFP[SD]"
|
||||||
|
in: *2fp
|
||||||
|
out: *1fp
|
||||||
|
|
||||||
|
- go: "RoundToEven|Ceil|Floor|Trunc"
|
||||||
|
asm: "VROUNDP[SD]"
|
||||||
|
in:
|
||||||
|
- *fp
|
||||||
|
- class: immediate
|
||||||
|
const: 0 # place holder
|
||||||
|
out: *1fp
|
||||||
|
|
||||||
|
- go: "(RoundToEven|Ceil|Floor|Trunc)Scaled"
|
||||||
|
asm: "VRNDSCALEP[SD]"
|
||||||
|
in:
|
||||||
|
- *fp
|
||||||
|
- class: immediate
|
||||||
|
const: 0 # place holder
|
||||||
|
immOffset: 4 # "M", round to numbers with M digits after dot(by means of binary number).
|
||||||
|
name: prec
|
||||||
|
out: *1fp
|
||||||
|
- go: "(RoundToEven|Ceil|Floor|Trunc)ScaledResidue"
|
||||||
|
asm: "VREDUCEP[SD]"
|
||||||
|
in:
|
||||||
|
- *fp
|
||||||
|
- class: immediate
|
||||||
|
const: 0 # place holder
|
||||||
|
immOffset: 4 # "M", round to numbers with M digits after dot(by means of binary number).
|
||||||
|
name: prec
|
||||||
|
out: *1fp
|
||||||
|
|
||||||
|
- go: "AddSub"
|
||||||
|
asm: "VADDSUBP[SD]"
|
||||||
|
in:
|
||||||
|
- *fp
|
||||||
|
- *fp
|
||||||
|
out:
|
||||||
|
- *fp
|
||||||
21
src/simd/_gen/simdgen/ops/GaloisField/categories.yaml
Normal file
21
src/simd/_gen/simdgen/ops/GaloisField/categories.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
!sum
|
||||||
|
- go: GaloisFieldAffineTransform
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes an affine transformation in GF(2^8):
|
||||||
|
// x is a vector of 8-bit vectors, with each adjacent 8 as a group; y is a vector of 8x8 1-bit matrixes;
|
||||||
|
// b is an 8-bit vector. The affine transformation is y * x + b, with each element of y
|
||||||
|
// corresponding to a group of 8 elements in x.
|
||||||
|
- go: GaloisFieldAffineTransformInverse
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes an affine transformation in GF(2^8),
|
||||||
|
// with x inverted with respect to reduction polynomial x^8 + x^4 + x^3 + x + 1:
|
||||||
|
// x is a vector of 8-bit vectors, with each adjacent 8 as a group; y is a vector of 8x8 1-bit matrixes;
|
||||||
|
// b is an 8-bit vector. The affine transformation is y * x + b, with each element of y
|
||||||
|
// corresponding to a group of 8 elements in x.
|
||||||
|
- go: GaloisFieldMul
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes element-wise GF(2^8) multiplication with
|
||||||
|
// reduction polynomial x^8 + x^4 + x^3 + x + 1.
|
||||||
32
src/simd/_gen/simdgen/ops/GaloisField/go.yaml
Normal file
32
src/simd/_gen/simdgen/ops/GaloisField/go.yaml
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
!sum
|
||||||
|
- go: GaloisFieldAffineTransform
|
||||||
|
asm: VGF2P8AFFINEQB
|
||||||
|
operandOrder: 2I # 2nd operand, then immediate
|
||||||
|
in: &AffineArgs
|
||||||
|
- &uint8
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- &uint8x8
|
||||||
|
go: $t2
|
||||||
|
base: uint
|
||||||
|
- &pureImmVar
|
||||||
|
class: immediate
|
||||||
|
immOffset: 0
|
||||||
|
name: b
|
||||||
|
out:
|
||||||
|
- *uint8
|
||||||
|
|
||||||
|
- go: GaloisFieldAffineTransformInverse
|
||||||
|
asm: VGF2P8AFFINEINVQB
|
||||||
|
operandOrder: 2I # 2nd operand, then immediate
|
||||||
|
in: *AffineArgs
|
||||||
|
out:
|
||||||
|
- *uint8
|
||||||
|
|
||||||
|
- go: GaloisFieldMul
|
||||||
|
asm: VGF2P8MULB
|
||||||
|
in:
|
||||||
|
- *uint8
|
||||||
|
- *uint8
|
||||||
|
out:
|
||||||
|
- *uint8
|
||||||
21
src/simd/_gen/simdgen/ops/IntOnlyArith/categories.yaml
Normal file
21
src/simd/_gen/simdgen/ops/IntOnlyArith/categories.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
!sum
|
||||||
|
- go: Average
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the rounded average of corresponding elements.
|
||||||
|
- go: Abs
|
||||||
|
commutative: false
|
||||||
|
# Unary operation, not commutative
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the absolute value of each element.
|
||||||
|
- go: CopySign
|
||||||
|
# Applies sign of second operand to first: sign(val, sign_src)
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME returns the product of the first operand with -1, 0, or 1,
|
||||||
|
// whichever constant is nearest to the value of the second operand.
|
||||||
|
# Sign does not have masked version
|
||||||
|
- go: OnesCount
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME counts the number of set bits in each element.
|
||||||
45
src/simd/_gen/simdgen/ops/IntOnlyArith/go.yaml
Normal file
45
src/simd/_gen/simdgen/ops/IntOnlyArith/go.yaml
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
!sum
|
||||||
|
# Average (unsigned byte, unsigned word)
|
||||||
|
# Instructions: VPAVGB, VPAVGW
|
||||||
|
- go: Average
|
||||||
|
asm: "VPAVG[BW]" # Matches VPAVGB (byte) and VPAVGW (word)
|
||||||
|
in:
|
||||||
|
- &uint_t # $t will be Uint8xN for VPAVGB, Uint16xN for VPAVGW
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *uint_t
|
||||||
|
out:
|
||||||
|
- *uint_t
|
||||||
|
|
||||||
|
# Absolute Value (signed byte, word, dword, qword)
|
||||||
|
# Instructions: VPABSB, VPABSW, VPABSD, VPABSQ
|
||||||
|
- go: Abs
|
||||||
|
asm: "VPABS[BWDQ]" # Matches VPABSB, VPABSW, VPABSD, VPABSQ
|
||||||
|
in:
|
||||||
|
- &int_t # $t will be Int8xN, Int16xN, Int32xN, Int64xN
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
out:
|
||||||
|
- *int_t # Output is magnitude, fits in the same signed type
|
||||||
|
|
||||||
|
# Sign Operation (signed byte, word, dword)
|
||||||
|
# Applies sign of second operand to the first.
|
||||||
|
# Instructions: VPSIGNB, VPSIGNW, VPSIGND
|
||||||
|
- go: CopySign
|
||||||
|
asm: "VPSIGN[BWD]" # Matches VPSIGNB, VPSIGNW, VPSIGND
|
||||||
|
in:
|
||||||
|
- *int_t # value to apply sign to
|
||||||
|
- *int_t # value from which to take the sign
|
||||||
|
out:
|
||||||
|
- *int_t
|
||||||
|
|
||||||
|
# Population Count (count set bits in each element)
|
||||||
|
# Instructions: VPOPCNTB, VPOPCNTW (AVX512_BITALG)
|
||||||
|
# VPOPCNTD, VPOPCNTQ (AVX512_VPOPCNTDQ)
|
||||||
|
- go: OnesCount
|
||||||
|
asm: "VPOPCNT[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
47
src/simd/_gen/simdgen/ops/MLOps/categories.yaml
Normal file
47
src/simd/_gen/simdgen/ops/MLOps/categories.yaml
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
!sum
|
||||||
|
- go: DotProdPairs
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies the elements and add the pairs together,
|
||||||
|
// yielding a vector of half as many elements with twice the input element size.
|
||||||
|
# TODO: maybe simplify this name within the receiver-type + method-naming scheme we use.
|
||||||
|
- go: DotProdPairsSaturated
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies the elements and add the pairs together with saturation,
|
||||||
|
// yielding a vector of half as many elements with twice the input element size.
|
||||||
|
# QuadDotProd, i.e. VPDPBUSD(S) are operations with src/dst on the same register, we are not supporting this as of now.
|
||||||
|
# - go: DotProdBroadcast
|
||||||
|
# commutative: true
|
||||||
|
# # documentation: !string |-
|
||||||
|
# // NAME multiplies all elements and broadcasts the sum.
|
||||||
|
- go: AddDotProdQuadruple
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs dot products on groups of 4 elements of x and y and then adds z.
|
||||||
|
- go: AddDotProdQuadrupleSaturated
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies performs dot products on groups of 4 elements of x and y and then adds z.
|
||||||
|
- go: AddDotProdPairs
|
||||||
|
commutative: false
|
||||||
|
noTypes: "true"
|
||||||
|
noGenericOps: "true"
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs dot products on pairs of elements of y and z and then adds x.
|
||||||
|
- go: AddDotProdPairsSaturated
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs dot products on pairs of elements of y and z and then adds x.
|
||||||
|
- go: MulAdd
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a fused (x * y) + z.
|
||||||
|
- go: MulAddSub
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a fused (x * y) - z for odd-indexed elements, and (x * y) + z for even-indexed elements.
|
||||||
|
- go: MulSubAdd
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a fused (x * y) + z for odd-indexed elements, and (x * y) - z for even-indexed elements.
|
||||||
113
src/simd/_gen/simdgen/ops/MLOps/go.yaml
Normal file
113
src/simd/_gen/simdgen/ops/MLOps/go.yaml
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
!sum
|
||||||
|
- go: DotProdPairs
|
||||||
|
asm: VPMADDWD
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
out:
|
||||||
|
- &int2 # The elemBits are different
|
||||||
|
go: $t2
|
||||||
|
base: int
|
||||||
|
- go: DotProdPairsSaturated
|
||||||
|
asm: VPMADDUBSW
|
||||||
|
in:
|
||||||
|
- &uint
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
overwriteElementBits: 8
|
||||||
|
- &int3
|
||||||
|
go: $t3
|
||||||
|
base: int
|
||||||
|
overwriteElementBits: 8
|
||||||
|
out:
|
||||||
|
- *int2
|
||||||
|
# - go: DotProdBroadcast
|
||||||
|
# asm: VDPP[SD]
|
||||||
|
# in:
|
||||||
|
# - &dpb_src
|
||||||
|
# go: $t
|
||||||
|
# - *dpb_src
|
||||||
|
# - class: immediate
|
||||||
|
# const: 127
|
||||||
|
# out:
|
||||||
|
# - *dpb_src
|
||||||
|
- go: AddDotProdQuadruple
|
||||||
|
asm: "VPDPBUSD"
|
||||||
|
operandOrder: "31" # switch operand 3 and 1
|
||||||
|
in:
|
||||||
|
- &qdpa_acc
|
||||||
|
go: $t_acc
|
||||||
|
base: int
|
||||||
|
elemBits: 32
|
||||||
|
- &qdpa_src1
|
||||||
|
go: $t_src1
|
||||||
|
base: uint
|
||||||
|
overwriteElementBits: 8
|
||||||
|
- &qdpa_src2
|
||||||
|
go: $t_src2
|
||||||
|
base: int
|
||||||
|
overwriteElementBits: 8
|
||||||
|
out:
|
||||||
|
- *qdpa_acc
|
||||||
|
- go: AddDotProdQuadrupleSaturated
|
||||||
|
asm: "VPDPBUSDS"
|
||||||
|
operandOrder: "31" # switch operand 3 and 1
|
||||||
|
in:
|
||||||
|
- *qdpa_acc
|
||||||
|
- *qdpa_src1
|
||||||
|
- *qdpa_src2
|
||||||
|
out:
|
||||||
|
- *qdpa_acc
|
||||||
|
- go: AddDotProdPairs
|
||||||
|
asm: "VPDPWSSD"
|
||||||
|
in:
|
||||||
|
- &pdpa_acc
|
||||||
|
go: $t_acc
|
||||||
|
base: int
|
||||||
|
elemBits: 32
|
||||||
|
- &pdpa_src1
|
||||||
|
go: $t_src1
|
||||||
|
base: int
|
||||||
|
overwriteElementBits: 16
|
||||||
|
- &pdpa_src2
|
||||||
|
go: $t_src2
|
||||||
|
base: int
|
||||||
|
overwriteElementBits: 16
|
||||||
|
out:
|
||||||
|
- *pdpa_acc
|
||||||
|
- go: AddDotProdPairsSaturated
|
||||||
|
asm: "VPDPWSSDS"
|
||||||
|
in:
|
||||||
|
- *pdpa_acc
|
||||||
|
- *pdpa_src1
|
||||||
|
- *pdpa_src2
|
||||||
|
out:
|
||||||
|
- *pdpa_acc
|
||||||
|
- go: MulAdd
|
||||||
|
asm: "VFMADD213PS|VFMADD213PD"
|
||||||
|
in:
|
||||||
|
- &fma_op
|
||||||
|
go: $t
|
||||||
|
base: float
|
||||||
|
- *fma_op
|
||||||
|
- *fma_op
|
||||||
|
out:
|
||||||
|
- *fma_op
|
||||||
|
- go: MulAddSub
|
||||||
|
asm: "VFMADDSUB213PS|VFMADDSUB213PD"
|
||||||
|
in:
|
||||||
|
- *fma_op
|
||||||
|
- *fma_op
|
||||||
|
- *fma_op
|
||||||
|
out:
|
||||||
|
- *fma_op
|
||||||
|
- go: MulSubAdd
|
||||||
|
asm: "VFMSUBADD213PS|VFMSUBADD213PD"
|
||||||
|
in:
|
||||||
|
- *fma_op
|
||||||
|
- *fma_op
|
||||||
|
- *fma_op
|
||||||
|
out:
|
||||||
|
- *fma_op
|
||||||
9
src/simd/_gen/simdgen/ops/MinMax/categories.yaml
Normal file
9
src/simd/_gen/simdgen/ops/MinMax/categories.yaml
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
!sum
|
||||||
|
- go: Max
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the maximum of corresponding elements.
|
||||||
|
- go: Min
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME computes the minimum of corresponding elements.
|
||||||
42
src/simd/_gen/simdgen/ops/MinMax/go.yaml
Normal file
42
src/simd/_gen/simdgen/ops/MinMax/go.yaml
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
!sum
|
||||||
|
- go: Max
|
||||||
|
asm: "V?PMAXS[BWDQ]"
|
||||||
|
in: &2int
|
||||||
|
- &int
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
out: &1int
|
||||||
|
- *int
|
||||||
|
- go: Max
|
||||||
|
asm: "V?PMAXU[BWDQ]"
|
||||||
|
in: &2uint
|
||||||
|
- &uint
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *uint
|
||||||
|
out: &1uint
|
||||||
|
- *uint
|
||||||
|
|
||||||
|
- go: Min
|
||||||
|
asm: "V?PMINS[BWDQ]"
|
||||||
|
in: *2int
|
||||||
|
out: *1int
|
||||||
|
- go: Min
|
||||||
|
asm: "V?PMINU[BWDQ]"
|
||||||
|
in: *2uint
|
||||||
|
out: *1uint
|
||||||
|
|
||||||
|
- go: Max
|
||||||
|
asm: "V?MAXP[SD]"
|
||||||
|
in: &2float
|
||||||
|
- &float
|
||||||
|
go: $t
|
||||||
|
base: float
|
||||||
|
- *float
|
||||||
|
out: &1float
|
||||||
|
- *float
|
||||||
|
- go: Min
|
||||||
|
asm: "V?MINP[SD]"
|
||||||
|
in: *2float
|
||||||
|
out: *1float
|
||||||
72
src/simd/_gen/simdgen/ops/Moves/categories.yaml
Normal file
72
src/simd/_gen/simdgen/ops/Moves/categories.yaml
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
!sum
|
||||||
|
- go: SetElem
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME sets a single constant-indexed element's value.
|
||||||
|
- go: GetElem
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME retrieves a single constant-indexed element's value.
|
||||||
|
- go: SetLo
|
||||||
|
commutative: false
|
||||||
|
constImm: 0
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME returns x with its lower half set to y.
|
||||||
|
- go: GetLo
|
||||||
|
commutative: false
|
||||||
|
constImm: 0
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME returns the lower half of x.
|
||||||
|
- go: SetHi
|
||||||
|
commutative: false
|
||||||
|
constImm: 1
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME returns x with its upper half set to y.
|
||||||
|
- go: GetHi
|
||||||
|
commutative: false
|
||||||
|
constImm: 1
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME returns the upper half of x.
|
||||||
|
- go: Permute
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a full permutation of vector x using indices:
|
||||||
|
// result := {x[indices[0]], x[indices[1]], ..., x[indices[n]]}
|
||||||
|
// Only the needed bits to represent x's index are used in indices' elements.
|
||||||
|
- go: Permute2 # Permute2 is only available on or after AVX512
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a full permutation of vector x, y using indices:
|
||||||
|
// result := {xy[indices[0]], xy[indices[1]], ..., xy[indices[n]]}
|
||||||
|
// where xy is x appending y.
|
||||||
|
// Only the needed bits to represent xy's index are used in indices' elements.
|
||||||
|
- go: Compress
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs a compression on vector x using mask by
|
||||||
|
// selecting elements as indicated by mask, and pack them to lower indexed elements.
|
||||||
|
- go: blend
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME blends two vectors based on mask values, choosing either
|
||||||
|
// the first or the second based on whether the third is false or true
|
||||||
|
- go: Expand
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME performs an expansion on a vector x whose elements are packed to lower parts.
|
||||||
|
// The expansion is to distribute elements as indexed by mask, from lower mask elements to upper in order.
|
||||||
|
- go: Broadcast128
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME copies element zero of its (128-bit) input to all elements of
|
||||||
|
// the 128-bit output vector.
|
||||||
|
- go: Broadcast256
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME copies element zero of its (128-bit) input to all elements of
|
||||||
|
// the 256-bit output vector.
|
||||||
|
- go: Broadcast512
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME copies element zero of its (128-bit) input to all elements of
|
||||||
|
// the 512-bit output vector.
|
||||||
372
src/simd/_gen/simdgen/ops/Moves/go.yaml
Normal file
372
src/simd/_gen/simdgen/ops/Moves/go.yaml
Normal file
|
|
@ -0,0 +1,372 @@
|
||||||
|
!sum
|
||||||
|
- go: SetElem
|
||||||
|
asm: "VPINSR[BWDQ]"
|
||||||
|
in:
|
||||||
|
- &t
|
||||||
|
class: vreg
|
||||||
|
base: $b
|
||||||
|
- class: greg
|
||||||
|
base: $b
|
||||||
|
lanes: 1 # Scalar, darn it!
|
||||||
|
- &imm
|
||||||
|
class: immediate
|
||||||
|
immOffset: 0
|
||||||
|
name: index
|
||||||
|
out:
|
||||||
|
- *t
|
||||||
|
|
||||||
|
- go: SetElem
|
||||||
|
asm: "VPINSR[DQ]"
|
||||||
|
in:
|
||||||
|
- &t
|
||||||
|
class: vreg
|
||||||
|
base: int
|
||||||
|
OverwriteBase: float
|
||||||
|
- class: greg
|
||||||
|
base: int
|
||||||
|
OverwriteBase: float
|
||||||
|
lanes: 1 # Scalar, darn it!
|
||||||
|
- &imm
|
||||||
|
class: immediate
|
||||||
|
immOffset: 0
|
||||||
|
name: index
|
||||||
|
out:
|
||||||
|
- *t
|
||||||
|
|
||||||
|
- go: GetElem
|
||||||
|
asm: "VPEXTR[BWDQ]"
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
base: $b
|
||||||
|
elemBits: $e
|
||||||
|
- *imm
|
||||||
|
out:
|
||||||
|
- class: greg
|
||||||
|
base: $b
|
||||||
|
bits: $e
|
||||||
|
|
||||||
|
- go: "SetHi|SetLo"
|
||||||
|
asm: "VINSERTI128|VINSERTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- &i8x2N
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 8
|
||||||
|
- &i8xN
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 8
|
||||||
|
- &imm01 # This immediate should be only 0 or 1
|
||||||
|
class: immediate
|
||||||
|
const: 0 # place holder
|
||||||
|
name: index
|
||||||
|
out:
|
||||||
|
- *i8x2N
|
||||||
|
|
||||||
|
- go: "GetHi|GetLo"
|
||||||
|
asm: "VEXTRACTI128|VEXTRACTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- *i8x2N
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i8xN
|
||||||
|
|
||||||
|
- go: "SetHi|SetLo"
|
||||||
|
asm: "VINSERTI128|VINSERTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- &i16x2N
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 16
|
||||||
|
- &i16xN
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 16
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i16x2N
|
||||||
|
|
||||||
|
- go: "GetHi|GetLo"
|
||||||
|
asm: "VEXTRACTI128|VEXTRACTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- *i16x2N
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i16xN
|
||||||
|
|
||||||
|
- go: "SetHi|SetLo"
|
||||||
|
asm: "VINSERTI128|VINSERTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- &i32x2N
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 32
|
||||||
|
- &i32xN
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 32
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i32x2N
|
||||||
|
|
||||||
|
- go: "GetHi|GetLo"
|
||||||
|
asm: "VEXTRACTI128|VEXTRACTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- *i32x2N
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i32xN
|
||||||
|
|
||||||
|
- go: "SetHi|SetLo"
|
||||||
|
asm: "VINSERTI128|VINSERTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- &i64x2N
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 64
|
||||||
|
- &i64xN
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 64
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i64x2N
|
||||||
|
|
||||||
|
- go: "GetHi|GetLo"
|
||||||
|
asm: "VEXTRACTI128|VEXTRACTI64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- *i64x2N
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *i64xN
|
||||||
|
|
||||||
|
- go: "SetHi|SetLo"
|
||||||
|
asm: "VINSERTF128|VINSERTF64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- &f32x2N
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 32
|
||||||
|
- &f32xN
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 32
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *f32x2N
|
||||||
|
|
||||||
|
- go: "GetHi|GetLo"
|
||||||
|
asm: "VEXTRACTF128|VEXTRACTF64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- *f32x2N
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *f32xN
|
||||||
|
|
||||||
|
- go: "SetHi|SetLo"
|
||||||
|
asm: "VINSERTF128|VINSERTF64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- &f64x2N
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 64
|
||||||
|
- &f64xN
|
||||||
|
class: vreg
|
||||||
|
base: $t
|
||||||
|
OverwriteElementBits: 64
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *f64x2N
|
||||||
|
|
||||||
|
- go: "GetHi|GetLo"
|
||||||
|
asm: "VEXTRACTF128|VEXTRACTF64X4"
|
||||||
|
inVariant: []
|
||||||
|
in:
|
||||||
|
- *f64x2N
|
||||||
|
- *imm01
|
||||||
|
out:
|
||||||
|
- *f64xN
|
||||||
|
|
||||||
|
- go: Permute
|
||||||
|
asm: "VPERM[BWDQ]|VPERMP[SD]"
|
||||||
|
operandOrder: "21Type1"
|
||||||
|
in:
|
||||||
|
- &anyindices
|
||||||
|
go: $t
|
||||||
|
name: indices
|
||||||
|
overwriteBase: uint
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: Permute2
|
||||||
|
asm: "VPERMI2[BWDQ]|VPERMI2P[SD]"
|
||||||
|
# Because we are overwriting the receiver's type, we
|
||||||
|
# have to move the receiver to be a parameter so that
|
||||||
|
# we can have no duplication.
|
||||||
|
operandOrder: "231Type1"
|
||||||
|
in:
|
||||||
|
- *anyindices # result in arg 0
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: Compress
|
||||||
|
asm: "VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]"
|
||||||
|
in:
|
||||||
|
# The mask in Compress is a control mask rather than a write mask, so it's not optional.
|
||||||
|
- class: mask
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
# For now a non-public method because
|
||||||
|
# (1) [OverwriteClass] must be set together with [OverwriteBase]
|
||||||
|
# (2) "simdgen does not support [OverwriteClass] in inputs".
|
||||||
|
# That means the signature is wrong.
|
||||||
|
- go: blend
|
||||||
|
asm: VPBLENDVB
|
||||||
|
in:
|
||||||
|
- &v
|
||||||
|
go: $t
|
||||||
|
class: vreg
|
||||||
|
base: int
|
||||||
|
- *v
|
||||||
|
-
|
||||||
|
class: vreg
|
||||||
|
base: int
|
||||||
|
name: mask
|
||||||
|
out:
|
||||||
|
- *v
|
||||||
|
|
||||||
|
# For AVX512
|
||||||
|
- go: blend
|
||||||
|
asm: VPBLENDM[BWDQ]
|
||||||
|
in:
|
||||||
|
- &v
|
||||||
|
go: $t
|
||||||
|
bits: 512
|
||||||
|
class: vreg
|
||||||
|
base: int
|
||||||
|
- *v
|
||||||
|
inVariant:
|
||||||
|
-
|
||||||
|
class: mask
|
||||||
|
out:
|
||||||
|
- *v
|
||||||
|
|
||||||
|
- go: Expand
|
||||||
|
asm: "VPEXPAND[BWDQ]|VEXPANDP[SD]"
|
||||||
|
in:
|
||||||
|
# The mask in Expand is a control mask rather than a write mask, so it's not optional.
|
||||||
|
- class: mask
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
- go: Broadcast128
|
||||||
|
asm: VPBROADCAST[BWDQ]
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
|
||||||
|
# weirdly, this one case on AVX2 is memory-operand-only
|
||||||
|
- go: Broadcast128
|
||||||
|
asm: VPBROADCASTQ
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: 64
|
||||||
|
base: int
|
||||||
|
OverwriteBase: float
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: 64
|
||||||
|
base: int
|
||||||
|
OverwriteBase: float
|
||||||
|
|
||||||
|
- go: Broadcast256
|
||||||
|
asm: VPBROADCAST[BWDQ]
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 256
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
|
||||||
|
- go: Broadcast512
|
||||||
|
asm: VPBROADCAST[BWDQ]
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 512
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
|
||||||
|
- go: Broadcast128
|
||||||
|
asm: VBROADCASTS[SD]
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
|
||||||
|
- go: Broadcast256
|
||||||
|
asm: VBROADCASTS[SD]
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 256
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
|
||||||
|
- go: Broadcast512
|
||||||
|
asm: VBROADCASTS[SD]
|
||||||
|
in:
|
||||||
|
- class: vreg
|
||||||
|
bits: 128
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
|
out:
|
||||||
|
- class: vreg
|
||||||
|
bits: 512
|
||||||
|
elemBits: $e
|
||||||
|
base: $b
|
||||||
14
src/simd/_gen/simdgen/ops/Mul/categories.yaml
Normal file
14
src/simd/_gen/simdgen/ops/Mul/categories.yaml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
!sum
|
||||||
|
- go: Mul
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies corresponding elements of two vectors.
|
||||||
|
- go: MulEvenWiden
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies even-indexed elements, widening the result.
|
||||||
|
// Result[i] = v1.Even[i] * v2.Even[i].
|
||||||
|
- go: MulHigh
|
||||||
|
commutative: true
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME multiplies elements and stores the high part of the result.
|
||||||
73
src/simd/_gen/simdgen/ops/Mul/go.yaml
Normal file
73
src/simd/_gen/simdgen/ops/Mul/go.yaml
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
!sum
|
||||||
|
# "Normal" multiplication is only available for floats.
|
||||||
|
# This only covers the single and double precision.
|
||||||
|
- go: Mul
|
||||||
|
asm: "VMULP[SD]"
|
||||||
|
in:
|
||||||
|
- &fp
|
||||||
|
go: $t
|
||||||
|
base: float
|
||||||
|
- *fp
|
||||||
|
out:
|
||||||
|
- *fp
|
||||||
|
|
||||||
|
# Integer multiplications.
|
||||||
|
|
||||||
|
# MulEvenWiden
|
||||||
|
# Dword only.
|
||||||
|
- go: MulEvenWiden
|
||||||
|
asm: "VPMULDQ"
|
||||||
|
in:
|
||||||
|
- &intNot64
|
||||||
|
go: $t
|
||||||
|
elemBits: 8|16|32
|
||||||
|
base: int
|
||||||
|
- *intNot64
|
||||||
|
out:
|
||||||
|
- &int2
|
||||||
|
go: $t2
|
||||||
|
base: int
|
||||||
|
- go: MulEvenWiden
|
||||||
|
asm: "VPMULUDQ"
|
||||||
|
in:
|
||||||
|
- &uintNot64
|
||||||
|
go: $t
|
||||||
|
elemBits: 8|16|32
|
||||||
|
base: uint
|
||||||
|
- *uintNot64
|
||||||
|
out:
|
||||||
|
- &uint2
|
||||||
|
go: $t2
|
||||||
|
base: uint
|
||||||
|
|
||||||
|
# MulHigh
|
||||||
|
# Word only.
|
||||||
|
- go: MulHigh
|
||||||
|
asm: "VPMULHW"
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *int
|
||||||
|
out:
|
||||||
|
- *int
|
||||||
|
- go: MulHigh
|
||||||
|
asm: "VPMULHUW"
|
||||||
|
in:
|
||||||
|
- &uint
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *uint
|
||||||
|
out:
|
||||||
|
- *uint
|
||||||
|
|
||||||
|
# MulLow
|
||||||
|
# signed and unsigned are the same for lower bits.
|
||||||
|
- go: Mul
|
||||||
|
asm: "VPMULL[WDQ]"
|
||||||
|
in:
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
103
src/simd/_gen/simdgen/ops/ShiftRotate/categories.yaml
Normal file
103
src/simd/_gen/simdgen/ops/ShiftRotate/categories.yaml
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
!sum
|
||||||
|
- go: ShiftAllLeft
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
specialLower: sftimm
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element to the left by the specified number of bits. Emptied lower bits are zeroed.
|
||||||
|
- go: ShiftAllRight
|
||||||
|
signed: false
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
specialLower: sftimm
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element to the right by the specified number of bits. Emptied upper bits are zeroed.
|
||||||
|
- go: ShiftAllRight
|
||||||
|
signed: true
|
||||||
|
specialLower: sftimm
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element to the right by the specified number of bits. Emptied upper bits are filled with the sign bit.
|
||||||
|
- go: shiftAllLeftConst # no APIs, only ssa ops.
|
||||||
|
noTypes: "true"
|
||||||
|
noGenericOps: "true"
|
||||||
|
SSAVariant: "const" # to avoid its name colliding with reg version of this instruction, amend this to its ssa op name.
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
- go: shiftAllRightConst # no APIs, only ssa ops.
|
||||||
|
noTypes: "true"
|
||||||
|
noGenericOps: "true"
|
||||||
|
SSAVariant: "const"
|
||||||
|
signed: false
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
- go: shiftAllRightConst # no APIs, only ssa ops.
|
||||||
|
noTypes: "true"
|
||||||
|
noGenericOps: "true"
|
||||||
|
SSAVariant: "const"
|
||||||
|
signed: true
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
|
||||||
|
- go: ShiftLeft
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element in x to the left by the number of bits specified in y's corresponding elements. Emptied lower bits are zeroed.
|
||||||
|
- go: ShiftRight
|
||||||
|
signed: false
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element in x to the right by the number of bits specified in y's corresponding elements. Emptied upper bits are zeroed.
|
||||||
|
- go: ShiftRight
|
||||||
|
signed: true
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element in x to the right by the number of bits specified in y's corresponding elements. Emptied upper bits are filled with the sign bit.
|
||||||
|
- go: RotateAllLeft
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rotates each element to the left by the number of bits specified by the immediate.
|
||||||
|
- go: RotateLeft
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rotates each element in x to the left by the number of bits specified by y's corresponding elements.
|
||||||
|
- go: RotateAllRight
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rotates each element to the right by the number of bits specified by the immediate.
|
||||||
|
- go: RotateRight
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME rotates each element in x to the right by the number of bits specified by y's corresponding elements.
|
||||||
|
- go: ShiftAllLeftConcat
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element of x to the left by the number of bits specified by the
|
||||||
|
// immediate(only the lower 5 bits are used), and then copies the upper bits of y to the emptied lower bits of the shifted x.
|
||||||
|
- go: ShiftAllRightConcat
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element of x to the right by the number of bits specified by the
|
||||||
|
// immediate(only the lower 5 bits are used), and then copies the lower bits of y to the emptied upper bits of the shifted x.
|
||||||
|
- go: ShiftLeftConcat
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element of x to the left by the number of bits specified by the
|
||||||
|
// corresponding elements in y(only the lower 5 bits are used), and then copies the upper bits of z to the emptied lower bits of the shifted x.
|
||||||
|
- go: ShiftRightConcat
|
||||||
|
nameAndSizeCheck: true
|
||||||
|
commutative: false
|
||||||
|
documentation: !string |-
|
||||||
|
// NAME shifts each element of x to the right by the number of bits specified by the
|
||||||
|
// corresponding elements in y(only the lower 5 bits are used), and then copies the lower bits of z to the emptied upper bits of the shifted x.
|
||||||
172
src/simd/_gen/simdgen/ops/ShiftRotate/go.yaml
Normal file
172
src/simd/_gen/simdgen/ops/ShiftRotate/go.yaml
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
!sum
|
||||||
|
# Integers
|
||||||
|
# ShiftAll*
|
||||||
|
- go: ShiftAllLeft
|
||||||
|
asm: "VPSLL[WDQ]"
|
||||||
|
in:
|
||||||
|
- &any
|
||||||
|
go: $t
|
||||||
|
- &vecAsScalar64
|
||||||
|
go: "Uint.*"
|
||||||
|
treatLikeAScalarOfSize: 64
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: ShiftAllRight
|
||||||
|
signed: false
|
||||||
|
asm: "VPSRL[WDQ]"
|
||||||
|
in:
|
||||||
|
- &uint
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
- *vecAsScalar64
|
||||||
|
out:
|
||||||
|
- *uint
|
||||||
|
- go: ShiftAllRight
|
||||||
|
signed: true
|
||||||
|
asm: "VPSRA[WDQ]"
|
||||||
|
in:
|
||||||
|
- &int
|
||||||
|
go: $t
|
||||||
|
base: int
|
||||||
|
- *vecAsScalar64
|
||||||
|
out:
|
||||||
|
- *int
|
||||||
|
|
||||||
|
- go: shiftAllLeftConst
|
||||||
|
asm: "VPSLL[WDQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- &imm
|
||||||
|
class: immediate
|
||||||
|
immOffset: 0
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: shiftAllRightConst
|
||||||
|
asm: "VPSRL[WDQ]"
|
||||||
|
in:
|
||||||
|
- *int
|
||||||
|
- *imm
|
||||||
|
out:
|
||||||
|
- *int
|
||||||
|
- go: shiftAllRightConst
|
||||||
|
asm: "VPSRA[WDQ]"
|
||||||
|
in:
|
||||||
|
- *uint
|
||||||
|
- *imm
|
||||||
|
out:
|
||||||
|
- *uint
|
||||||
|
|
||||||
|
# Shift* (variable)
|
||||||
|
- go: ShiftLeft
|
||||||
|
asm: "VPSLLV[WD]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
# XED data of VPSLLVQ marks the element bits 32 which is off to the actual semantic, we need to overwrite
|
||||||
|
# it to 64.
|
||||||
|
- go: ShiftLeft
|
||||||
|
asm: "VPSLLVQ"
|
||||||
|
in:
|
||||||
|
- &anyOverwriteElemBits
|
||||||
|
go: $t
|
||||||
|
overwriteElementBits: 64
|
||||||
|
- *anyOverwriteElemBits
|
||||||
|
out:
|
||||||
|
- *anyOverwriteElemBits
|
||||||
|
- go: ShiftRight
|
||||||
|
signed: false
|
||||||
|
asm: "VPSRLV[WD]"
|
||||||
|
in:
|
||||||
|
- *uint
|
||||||
|
- *uint
|
||||||
|
out:
|
||||||
|
- *uint
|
||||||
|
# XED data of VPSRLVQ needs the same overwrite as VPSLLVQ.
|
||||||
|
- go: ShiftRight
|
||||||
|
signed: false
|
||||||
|
asm: "VPSRLVQ"
|
||||||
|
in:
|
||||||
|
- &uintOverwriteElemBits
|
||||||
|
go: $t
|
||||||
|
base: uint
|
||||||
|
overwriteElementBits: 64
|
||||||
|
- *uintOverwriteElemBits
|
||||||
|
out:
|
||||||
|
- *uintOverwriteElemBits
|
||||||
|
- go: ShiftRight
|
||||||
|
signed: true
|
||||||
|
asm: "VPSRAV[WDQ]"
|
||||||
|
in:
|
||||||
|
- *int
|
||||||
|
- *int
|
||||||
|
out:
|
||||||
|
- *int
|
||||||
|
|
||||||
|
# Rotate
|
||||||
|
- go: RotateAllLeft
|
||||||
|
asm: "VPROL[DQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- &pureImm
|
||||||
|
class: immediate
|
||||||
|
immOffset: 0
|
||||||
|
name: shift
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: RotateAllRight
|
||||||
|
asm: "VPROR[DQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *pureImm
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: RotateLeft
|
||||||
|
asm: "VPROLV[DQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: RotateRight
|
||||||
|
asm: "VPRORV[DQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
|
||||||
|
# Bizzare shifts.
|
||||||
|
- go: ShiftAllLeftConcat
|
||||||
|
asm: "VPSHLD[WDQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
- *pureImm
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: ShiftAllRightConcat
|
||||||
|
asm: "VPSHRD[WDQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
- *pureImm
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: ShiftLeftConcat
|
||||||
|
asm: "VPSHLDV[WDQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
|
- go: ShiftRightConcat
|
||||||
|
asm: "VPSHRDV[WDQ]"
|
||||||
|
in:
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
- *any
|
||||||
|
out:
|
||||||
|
- *any
|
||||||
73
src/simd/_gen/simdgen/pprint.go
Normal file
73
src/simd/_gen/simdgen/pprint.go
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pprints(v any) string {
|
||||||
|
var pp pprinter
|
||||||
|
pp.val(reflect.ValueOf(v), 0)
|
||||||
|
return string(pp.buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
type pprinter struct {
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pprinter) indent(by int) {
|
||||||
|
for range by {
|
||||||
|
p.buf = append(p.buf, '\t')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pprinter) val(v reflect.Value, indent int) {
|
||||||
|
switch v.Kind() {
|
||||||
|
default:
|
||||||
|
p.buf = fmt.Appendf(p.buf, "unsupported kind %v", v.Kind())
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
p.buf = strconv.AppendBool(p.buf, v.Bool())
|
||||||
|
|
||||||
|
case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
p.buf = strconv.AppendInt(p.buf, v.Int(), 10)
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
p.buf = strconv.AppendQuote(p.buf, v.String())
|
||||||
|
|
||||||
|
case reflect.Pointer:
|
||||||
|
if v.IsNil() {
|
||||||
|
p.buf = append(p.buf, "nil"...)
|
||||||
|
} else {
|
||||||
|
p.buf = append(p.buf, "&"...)
|
||||||
|
p.val(v.Elem(), indent)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
p.buf = append(p.buf, "[\n"...)
|
||||||
|
for i := range v.Len() {
|
||||||
|
p.indent(indent + 1)
|
||||||
|
p.val(v.Index(i), indent+1)
|
||||||
|
p.buf = append(p.buf, ",\n"...)
|
||||||
|
}
|
||||||
|
p.indent(indent)
|
||||||
|
p.buf = append(p.buf, ']')
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
vt := v.Type()
|
||||||
|
p.buf = append(append(p.buf, vt.String()...), "{\n"...)
|
||||||
|
for f := range v.NumField() {
|
||||||
|
p.indent(indent + 1)
|
||||||
|
p.buf = append(append(p.buf, vt.Field(f).Name...), ": "...)
|
||||||
|
p.val(v.Field(f), indent+1)
|
||||||
|
p.buf = append(p.buf, ",\n"...)
|
||||||
|
}
|
||||||
|
p.indent(indent)
|
||||||
|
p.buf = append(p.buf, '}')
|
||||||
|
}
|
||||||
|
}
|
||||||
41
src/simd/_gen/simdgen/sort_test.go
Normal file
41
src/simd/_gen/simdgen/sort_test.go
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSort(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
s1, s2 string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"a1", "a2", -1},
|
||||||
|
{"a11a", "a11b", -1},
|
||||||
|
{"a01a1", "a1a01", -1},
|
||||||
|
{"a2", "a1", 1},
|
||||||
|
{"a10", "a2", 1},
|
||||||
|
{"a1", "a10", -1},
|
||||||
|
{"z11", "z2", 1},
|
||||||
|
{"z2", "z11", -1},
|
||||||
|
{"abc", "abd", -1},
|
||||||
|
{"123", "45", 1},
|
||||||
|
{"file1", "file1", 0},
|
||||||
|
{"file", "file1", -1},
|
||||||
|
{"file1", "file", 1},
|
||||||
|
{"a01", "a1", -1},
|
||||||
|
{"a1a", "a1b", -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
got := compareNatural(tc.s1, tc.s2)
|
||||||
|
result := "✅"
|
||||||
|
if got != tc.want {
|
||||||
|
result = "❌"
|
||||||
|
t.Errorf("%s CompareNatural(\"%s\", \"%s\") -> got %2d, want %2d\n", result, tc.s1, tc.s2, got, tc.want)
|
||||||
|
} else {
|
||||||
|
t.Logf("%s CompareNatural(\"%s\", \"%s\") -> got %2d, want %2d\n", result, tc.s1, tc.s2, got, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
90
src/simd/_gen/simdgen/types.yaml
Normal file
90
src/simd/_gen/simdgen/types.yaml
Normal file
|
|
@ -0,0 +1,90 @@
|
||||||
|
# This file defines the possible types of each operand and result.
|
||||||
|
#
|
||||||
|
# In general, we're able to narrow this down on some attributes directly from
|
||||||
|
# the machine instruction descriptions, but the Go mappings need to further
|
||||||
|
# constrain them and how they relate. For example, on x86 we can't distinguish
|
||||||
|
# int and uint, though we can distinguish these from float.
|
||||||
|
|
||||||
|
in: !repeat
|
||||||
|
- !sum &types
|
||||||
|
- {class: vreg, go: Int8x16, base: "int", elemBits: 8, bits: 128, lanes: 16}
|
||||||
|
- {class: vreg, go: Uint8x16, base: "uint", elemBits: 8, bits: 128, lanes: 16}
|
||||||
|
- {class: vreg, go: Int16x8, base: "int", elemBits: 16, bits: 128, lanes: 8}
|
||||||
|
- {class: vreg, go: Uint16x8, base: "uint", elemBits: 16, bits: 128, lanes: 8}
|
||||||
|
- {class: vreg, go: Int32x4, base: "int", elemBits: 32, bits: 128, lanes: 4}
|
||||||
|
- {class: vreg, go: Uint32x4, base: "uint", elemBits: 32, bits: 128, lanes: 4}
|
||||||
|
- {class: vreg, go: Int64x2, base: "int", elemBits: 64, bits: 128, lanes: 2}
|
||||||
|
- {class: vreg, go: Uint64x2, base: "uint", elemBits: 64, bits: 128, lanes: 2}
|
||||||
|
- {class: vreg, go: Float32x4, base: "float", elemBits: 32, bits: 128, lanes: 4}
|
||||||
|
- {class: vreg, go: Float64x2, base: "float", elemBits: 64, bits: 128, lanes: 2}
|
||||||
|
- {class: vreg, go: Int8x32, base: "int", elemBits: 8, bits: 256, lanes: 32}
|
||||||
|
- {class: vreg, go: Uint8x32, base: "uint", elemBits: 8, bits: 256, lanes: 32}
|
||||||
|
- {class: vreg, go: Int16x16, base: "int", elemBits: 16, bits: 256, lanes: 16}
|
||||||
|
- {class: vreg, go: Uint16x16, base: "uint", elemBits: 16, bits: 256, lanes: 16}
|
||||||
|
- {class: vreg, go: Int32x8, base: "int", elemBits: 32, bits: 256, lanes: 8}
|
||||||
|
- {class: vreg, go: Uint32x8, base: "uint", elemBits: 32, bits: 256, lanes: 8}
|
||||||
|
- {class: vreg, go: Int64x4, base: "int", elemBits: 64, bits: 256, lanes: 4}
|
||||||
|
- {class: vreg, go: Uint64x4, base: "uint", elemBits: 64, bits: 256, lanes: 4}
|
||||||
|
- {class: vreg, go: Float32x8, base: "float", elemBits: 32, bits: 256, lanes: 8}
|
||||||
|
- {class: vreg, go: Float64x4, base: "float", elemBits: 64, bits: 256, lanes: 4}
|
||||||
|
- {class: vreg, go: Int8x64, base: "int", elemBits: 8, bits: 512, lanes: 64}
|
||||||
|
- {class: vreg, go: Uint8x64, base: "uint", elemBits: 8, bits: 512, lanes: 64}
|
||||||
|
- {class: vreg, go: Int16x32, base: "int", elemBits: 16, bits: 512, lanes: 32}
|
||||||
|
- {class: vreg, go: Uint16x32, base: "uint", elemBits: 16, bits: 512, lanes: 32}
|
||||||
|
- {class: vreg, go: Int32x16, base: "int", elemBits: 32, bits: 512, lanes: 16}
|
||||||
|
- {class: vreg, go: Uint32x16, base: "uint", elemBits: 32, bits: 512, lanes: 16}
|
||||||
|
- {class: vreg, go: Int64x8, base: "int", elemBits: 64, bits: 512, lanes: 8}
|
||||||
|
- {class: vreg, go: Uint64x8, base: "uint", elemBits: 64, bits: 512, lanes: 8}
|
||||||
|
- {class: vreg, go: Float32x16, base: "float", elemBits: 32, bits: 512, lanes: 16}
|
||||||
|
- {class: vreg, go: Float64x8, base: "float", elemBits: 64, bits: 512, lanes: 8}
|
||||||
|
|
||||||
|
- {class: mask, go: Mask8x16, base: "int", elemBits: 8, bits: 128, lanes: 16}
|
||||||
|
- {class: mask, go: Mask16x8, base: "int", elemBits: 16, bits: 128, lanes: 8}
|
||||||
|
- {class: mask, go: Mask32x4, base: "int", elemBits: 32, bits: 128, lanes: 4}
|
||||||
|
- {class: mask, go: Mask64x2, base: "int", elemBits: 64, bits: 128, lanes: 2}
|
||||||
|
- {class: mask, go: Mask8x32, base: "int", elemBits: 8, bits: 256, lanes: 32}
|
||||||
|
- {class: mask, go: Mask16x16, base: "int", elemBits: 16, bits: 256, lanes: 16}
|
||||||
|
- {class: mask, go: Mask32x8, base: "int", elemBits: 32, bits: 256, lanes: 8}
|
||||||
|
- {class: mask, go: Mask64x4, base: "int", elemBits: 64, bits: 256, lanes: 4}
|
||||||
|
- {class: mask, go: Mask8x64, base: "int", elemBits: 8, bits: 512, lanes: 64}
|
||||||
|
- {class: mask, go: Mask16x32, base: "int", elemBits: 16, bits: 512, lanes: 32}
|
||||||
|
- {class: mask, go: Mask32x16, base: "int", elemBits: 32, bits: 512, lanes: 16}
|
||||||
|
- {class: mask, go: Mask64x8, base: "int", elemBits: 64, bits: 512, lanes: 8}
|
||||||
|
|
||||||
|
|
||||||
|
- {class: greg, go: float64, base: "float", bits: 64, lanes: 1}
|
||||||
|
- {class: greg, go: float32, base: "float", bits: 32, lanes: 1}
|
||||||
|
- {class: greg, go: int64, base: "int", bits: 64, lanes: 1}
|
||||||
|
- {class: greg, go: int32, base: "int", bits: 32, lanes: 1}
|
||||||
|
- {class: greg, go: int16, base: "int", bits: 16, lanes: 1}
|
||||||
|
- {class: greg, go: int8, base: "int", bits: 8, lanes: 1}
|
||||||
|
- {class: greg, go: uint64, base: "uint", bits: 64, lanes: 1}
|
||||||
|
- {class: greg, go: uint32, base: "uint", bits: 32, lanes: 1}
|
||||||
|
- {class: greg, go: uint16, base: "uint", bits: 16, lanes: 1}
|
||||||
|
- {class: greg, go: uint8, base: "uint", bits: 8, lanes: 1}
|
||||||
|
|
||||||
|
# Special shapes just to make INSERT[IF]128 work.
|
||||||
|
# The elemBits field of these shapes are wrong, it would be overwritten by overwriteElemBits.
|
||||||
|
- {class: vreg, go: Int8x16, base: "int", elemBits: 128, bits: 128, lanes: 16}
|
||||||
|
- {class: vreg, go: Uint8x16, base: "uint", elemBits: 128, bits: 128, lanes: 16}
|
||||||
|
- {class: vreg, go: Int16x8, base: "int", elemBits: 128, bits: 128, lanes: 8}
|
||||||
|
- {class: vreg, go: Uint16x8, base: "uint", elemBits: 128, bits: 128, lanes: 8}
|
||||||
|
- {class: vreg, go: Int32x4, base: "int", elemBits: 128, bits: 128, lanes: 4}
|
||||||
|
- {class: vreg, go: Uint32x4, base: "uint", elemBits: 128, bits: 128, lanes: 4}
|
||||||
|
- {class: vreg, go: Int64x2, base: "int", elemBits: 128, bits: 128, lanes: 2}
|
||||||
|
- {class: vreg, go: Uint64x2, base: "uint", elemBits: 128, bits: 128, lanes: 2}
|
||||||
|
|
||||||
|
- {class: vreg, go: Int8x32, base: "int", elemBits: 128, bits: 256, lanes: 32}
|
||||||
|
- {class: vreg, go: Uint8x32, base: "uint", elemBits: 128, bits: 256, lanes: 32}
|
||||||
|
- {class: vreg, go: Int16x16, base: "int", elemBits: 128, bits: 256, lanes: 16}
|
||||||
|
- {class: vreg, go: Uint16x16, base: "uint", elemBits: 128, bits: 256, lanes: 16}
|
||||||
|
- {class: vreg, go: Int32x8, base: "int", elemBits: 128, bits: 256, lanes: 8}
|
||||||
|
- {class: vreg, go: Uint32x8, base: "uint", elemBits: 128, bits: 256, lanes: 8}
|
||||||
|
- {class: vreg, go: Int64x4, base: "int", elemBits: 128, bits: 256, lanes: 4}
|
||||||
|
- {class: vreg, go: Uint64x4, base: "uint", elemBits: 128, bits: 256, lanes: 4}
|
||||||
|
|
||||||
|
- {class: immediate, go: Immediate} # TODO: we only support imms that are not used as value -- usually as instruction semantic predicate like VPCMP as of now.
|
||||||
|
inVariant: !repeat
|
||||||
|
- *types
|
||||||
|
out: !repeat
|
||||||
|
- *types
|
||||||
780
src/simd/_gen/simdgen/xed.go
Normal file
780
src/simd/_gen/simdgen/xed.go
Normal file
|
|
@ -0,0 +1,780 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"maps"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/arch/x86/xeddata"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
"simd/_gen/unify"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NOT_REG_CLASS = 0 // not a register
|
||||||
|
VREG_CLASS = 1 // classify as a vector register; see
|
||||||
|
GREG_CLASS = 2 // classify as a general register
|
||||||
|
)
|
||||||
|
|
||||||
|
// instVariant is a bitmap indicating a variant of an instruction that has
|
||||||
|
// optional parameters.
|
||||||
|
type instVariant uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
instVariantNone instVariant = 0
|
||||||
|
|
||||||
|
// instVariantMasked indicates that this is the masked variant of an
|
||||||
|
// optionally-masked instruction.
|
||||||
|
instVariantMasked instVariant = 1 << iota
|
||||||
|
)
|
||||||
|
|
||||||
|
var operandRemarks int
|
||||||
|
|
||||||
|
// TODO: Doc. Returns Values with Def domains.
|
||||||
|
func loadXED(xedPath string) []*unify.Value {
|
||||||
|
// TODO: Obviously a bunch more to do here.
|
||||||
|
|
||||||
|
db, err := xeddata.NewDatabase(xedPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("open database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var defs []*unify.Value
|
||||||
|
err = xeddata.WalkInsts(xedPath, func(inst *xeddata.Inst) {
|
||||||
|
inst.Pattern = xeddata.ExpandStates(db, inst.Pattern)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case inst.RealOpcode == "N":
|
||||||
|
return // Skip unstable instructions
|
||||||
|
case !strings.HasPrefix(inst.Extension, "AVX"):
|
||||||
|
// We're only interested in AVX instructions.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagDebugXED {
|
||||||
|
fmt.Printf("%s:\n%+v\n", inst.Pos, inst)
|
||||||
|
}
|
||||||
|
|
||||||
|
ops, err := decodeOperands(db, strings.Fields(inst.Operands))
|
||||||
|
if err != nil {
|
||||||
|
operandRemarks++
|
||||||
|
if *Verbose {
|
||||||
|
log.Printf("%s: [%s] %s", inst.Pos, inst.Opcode(), err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
applyQuirks(inst, ops)
|
||||||
|
|
||||||
|
defsPos := len(defs)
|
||||||
|
defs = append(defs, instToUVal(inst, ops)...)
|
||||||
|
|
||||||
|
if *flagDebugXED {
|
||||||
|
for i := defsPos; i < len(defs); i++ {
|
||||||
|
y, _ := yaml.Marshal(defs[i])
|
||||||
|
fmt.Printf("==>\n%s\n", y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("walk insts: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unknownFeatures) > 0 {
|
||||||
|
if !*Verbose {
|
||||||
|
nInst := 0
|
||||||
|
for _, insts := range unknownFeatures {
|
||||||
|
nInst += len(insts)
|
||||||
|
}
|
||||||
|
log.Printf("%d unhandled CPU features for %d instructions (use -v for details)", len(unknownFeatures), nInst)
|
||||||
|
} else {
|
||||||
|
keys := slices.SortedFunc(maps.Keys(unknownFeatures), func(a, b cpuFeatureKey) int {
|
||||||
|
return cmp.Or(cmp.Compare(a.Extension, b.Extension),
|
||||||
|
cmp.Compare(a.ISASet, b.ISASet))
|
||||||
|
})
|
||||||
|
for _, key := range keys {
|
||||||
|
if key.ISASet == "" || key.ISASet == key.Extension {
|
||||||
|
log.Printf("unhandled Extension %s", key.Extension)
|
||||||
|
} else {
|
||||||
|
log.Printf("unhandled Extension %s and ISASet %s", key.Extension, key.ISASet)
|
||||||
|
}
|
||||||
|
log.Printf(" opcodes: %s", slices.Sorted(maps.Keys(unknownFeatures[key])))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return defs
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
maskRequiredRe = regexp.MustCompile(`VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]|VPEXPAND[BWDQ]|VEXPANDP[SD]`)
|
||||||
|
maskOptionalRe = regexp.MustCompile(`VPCMP(EQ|GT|U)?[BWDQ]|VCMPP[SD]`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func applyQuirks(inst *xeddata.Inst, ops []operand) {
|
||||||
|
opc := inst.Opcode()
|
||||||
|
switch {
|
||||||
|
case maskRequiredRe.MatchString(opc):
|
||||||
|
// The mask on these instructions is marked optional, but the
|
||||||
|
// instruction is pointless without the mask.
|
||||||
|
for i, op := range ops {
|
||||||
|
if op, ok := op.(operandMask); ok {
|
||||||
|
op.optional = false
|
||||||
|
ops[i] = op
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case maskOptionalRe.MatchString(opc):
|
||||||
|
// Conversely, these masks should be marked optional and aren't.
|
||||||
|
for i, op := range ops {
|
||||||
|
if op, ok := op.(operandMask); ok && op.action.r {
|
||||||
|
op.optional = true
|
||||||
|
ops[i] = op
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type operandCommon struct {
|
||||||
|
action operandAction
|
||||||
|
}
|
||||||
|
|
||||||
|
// operandAction defines whether this operand is read and/or written.
|
||||||
|
//
|
||||||
|
// TODO: Should this live in [xeddata.Operand]?
|
||||||
|
type operandAction struct {
|
||||||
|
r bool // Read
|
||||||
|
w bool // Written
|
||||||
|
cr bool // Read is conditional (implies r==true)
|
||||||
|
cw bool // Write is conditional (implies w==true)
|
||||||
|
}
|
||||||
|
|
||||||
|
type operandMem struct {
|
||||||
|
operandCommon
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
type vecShape struct {
|
||||||
|
elemBits int // Element size in bits
|
||||||
|
bits int // Register width in bits (total vector bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
type operandVReg struct { // Vector register
|
||||||
|
operandCommon
|
||||||
|
vecShape
|
||||||
|
elemBaseType scalarBaseType
|
||||||
|
}
|
||||||
|
|
||||||
|
type operandGReg struct { // Vector register
|
||||||
|
operandCommon
|
||||||
|
vecShape
|
||||||
|
elemBaseType scalarBaseType
|
||||||
|
}
|
||||||
|
|
||||||
|
// operandMask is a vector mask.
|
||||||
|
//
|
||||||
|
// Regardless of the actual mask representation, the [vecShape] of this operand
|
||||||
|
// corresponds to the "bit for bit" type of mask. That is, elemBits gives the
|
||||||
|
// element width covered by each mask element, and bits/elemBits gives the total
|
||||||
|
// number of mask elements. (bits gives the total number of bits as if this were
|
||||||
|
// a bit-for-bit mask, which may be meaningless on its own.)
|
||||||
|
type operandMask struct {
|
||||||
|
operandCommon
|
||||||
|
vecShape
|
||||||
|
// Bits in the mask is w/bits.
|
||||||
|
|
||||||
|
allMasks bool // If set, size cannot be inferred because all operands are masks.
|
||||||
|
|
||||||
|
// Mask can be omitted, in which case it defaults to K0/"no mask"
|
||||||
|
optional bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type operandImm struct {
|
||||||
|
operandCommon
|
||||||
|
bits int // Immediate size in bits
|
||||||
|
}
|
||||||
|
|
||||||
|
type operand interface {
|
||||||
|
common() operandCommon
|
||||||
|
addToDef(b *unify.DefBuilder)
|
||||||
|
}
|
||||||
|
|
||||||
|
func strVal(s any) *unify.Value {
|
||||||
|
return unify.NewValue(unify.NewStringExact(fmt.Sprint(s)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o operandCommon) common() operandCommon {
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o operandMem) addToDef(b *unify.DefBuilder) {
|
||||||
|
// TODO: w, base
|
||||||
|
b.Add("class", strVal("memory"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o operandVReg) addToDef(b *unify.DefBuilder) {
|
||||||
|
baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
|
||||||
|
if err != nil {
|
||||||
|
panic("parsing baseRe: " + err.Error())
|
||||||
|
}
|
||||||
|
b.Add("class", strVal("vreg"))
|
||||||
|
b.Add("bits", strVal(o.bits))
|
||||||
|
b.Add("base", unify.NewValue(baseDomain))
|
||||||
|
// If elemBits == bits, then the vector can be ANY shape. This happens with,
|
||||||
|
// for example, logical ops.
|
||||||
|
if o.elemBits != o.bits {
|
||||||
|
b.Add("elemBits", strVal(o.elemBits))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o operandGReg) addToDef(b *unify.DefBuilder) {
|
||||||
|
baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
|
||||||
|
if err != nil {
|
||||||
|
panic("parsing baseRe: " + err.Error())
|
||||||
|
}
|
||||||
|
b.Add("class", strVal("greg"))
|
||||||
|
b.Add("bits", strVal(o.bits))
|
||||||
|
b.Add("base", unify.NewValue(baseDomain))
|
||||||
|
if o.elemBits != o.bits {
|
||||||
|
b.Add("elemBits", strVal(o.elemBits))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o operandMask) addToDef(b *unify.DefBuilder) {
|
||||||
|
b.Add("class", strVal("mask"))
|
||||||
|
if o.allMasks {
|
||||||
|
// If all operands are masks, omit sizes and let unification determine mask sizes.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.Add("elemBits", strVal(o.elemBits))
|
||||||
|
b.Add("bits", strVal(o.bits))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o operandImm) addToDef(b *unify.DefBuilder) {
|
||||||
|
b.Add("class", strVal("immediate"))
|
||||||
|
b.Add("bits", strVal(o.bits))
|
||||||
|
}
|
||||||
|
|
||||||
|
var actionEncoding = map[string]operandAction{
|
||||||
|
"r": {r: true},
|
||||||
|
"cr": {r: true, cr: true},
|
||||||
|
"w": {w: true},
|
||||||
|
"cw": {w: true, cw: true},
|
||||||
|
"rw": {r: true, w: true},
|
||||||
|
"crw": {r: true, w: true, cr: true},
|
||||||
|
"rcw": {r: true, w: true, cw: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeOperand(db *xeddata.Database, operand string) (operand, error) {
|
||||||
|
op, err := xeddata.NewOperand(db, operand)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("parsing operand %q: %v", operand, err)
|
||||||
|
}
|
||||||
|
if *flagDebugXED {
|
||||||
|
fmt.Printf(" %+v\n", op)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(op.Name, "EMX_BROADCAST") {
|
||||||
|
// This refers to a set of macros defined in all-state.txt that set a
|
||||||
|
// BCAST operand to various fixed values. But the BCAST operand is
|
||||||
|
// itself suppressed and "internal", so I think we can just ignore this
|
||||||
|
// operand.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: See xed_decoded_inst_operand_action. This might need to be more
|
||||||
|
// complicated.
|
||||||
|
action, ok := actionEncoding[op.Action]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown action %q", op.Action)
|
||||||
|
}
|
||||||
|
common := operandCommon{action: action}
|
||||||
|
|
||||||
|
lhs := op.NameLHS()
|
||||||
|
if strings.HasPrefix(lhs, "MEM") {
|
||||||
|
// TODO: Width, base type
|
||||||
|
return operandMem{
|
||||||
|
operandCommon: common,
|
||||||
|
}, nil
|
||||||
|
} else if strings.HasPrefix(lhs, "REG") {
|
||||||
|
if op.Width == "mskw" {
|
||||||
|
// The mask operand doesn't specify a width. We have to infer it.
|
||||||
|
//
|
||||||
|
// XED uses the marker ZEROSTR to indicate that a mask operand is
|
||||||
|
// optional and, if omitted, implies K0, aka "no mask".
|
||||||
|
return operandMask{
|
||||||
|
operandCommon: common,
|
||||||
|
optional: op.Attributes["TXT=ZEROSTR"],
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
|
class, regBits := decodeReg(op)
|
||||||
|
if class == NOT_REG_CLASS {
|
||||||
|
return nil, fmt.Errorf("failed to decode register %q", operand)
|
||||||
|
}
|
||||||
|
baseType, elemBits, ok := decodeType(op)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to decode register width %q", operand)
|
||||||
|
}
|
||||||
|
shape := vecShape{elemBits: elemBits, bits: regBits}
|
||||||
|
if class == VREG_CLASS {
|
||||||
|
return operandVReg{
|
||||||
|
operandCommon: common,
|
||||||
|
vecShape: shape,
|
||||||
|
elemBaseType: baseType,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// general register
|
||||||
|
m := min(shape.bits, shape.elemBits)
|
||||||
|
shape.bits, shape.elemBits = m, m
|
||||||
|
return operandGReg{
|
||||||
|
operandCommon: common,
|
||||||
|
vecShape: shape,
|
||||||
|
elemBaseType: baseType,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(lhs, "IMM") {
|
||||||
|
_, bits, ok := decodeType(op)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to decode register width %q", operand)
|
||||||
|
}
|
||||||
|
return operandImm{
|
||||||
|
operandCommon: common,
|
||||||
|
bits: bits,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: BASE and SEG
|
||||||
|
return nil, fmt.Errorf("unknown operand LHS %q in %q", lhs, operand)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeOperands(db *xeddata.Database, operands []string) (ops []operand, err error) {
|
||||||
|
// Decode the XED operand descriptions.
|
||||||
|
for _, o := range operands {
|
||||||
|
op, err := decodeOperand(db, o)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if op != nil {
|
||||||
|
ops = append(ops, op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// XED doesn't encode the size of mask operands. If there are mask operands,
|
||||||
|
// try to infer their sizes from other operands.
|
||||||
|
if err := inferMaskSizes(ops); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w in operands %+v", err, operands)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ops, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferMaskSizes(ops []operand) error {
|
||||||
|
// This is a heuristic and it falls apart in some cases:
|
||||||
|
//
|
||||||
|
// - Mask operations like KAND[BWDQ] have *nothing* in the XED to indicate
|
||||||
|
// mask size.
|
||||||
|
//
|
||||||
|
// - VINSERT*, VPSLL*, VPSRA*, and VPSRL* and some others naturally have
|
||||||
|
// mixed input sizes and the XED doesn't indicate which operands the mask
|
||||||
|
// applies to.
|
||||||
|
//
|
||||||
|
// - VPDP* and VP4DP* have really complex mixed operand patterns.
|
||||||
|
//
|
||||||
|
// I think for these we may just have to hand-write a table of which
|
||||||
|
// operands each mask applies to.
|
||||||
|
inferMask := func(r, w bool) error {
|
||||||
|
var masks []int
|
||||||
|
var rSizes, wSizes, sizes []vecShape
|
||||||
|
allMasks := true
|
||||||
|
hasWMask := false
|
||||||
|
for i, op := range ops {
|
||||||
|
action := op.common().action
|
||||||
|
if _, ok := op.(operandMask); ok {
|
||||||
|
if action.r && action.w {
|
||||||
|
return fmt.Errorf("unexpected rw mask")
|
||||||
|
}
|
||||||
|
if action.r == r || action.w == w {
|
||||||
|
masks = append(masks, i)
|
||||||
|
}
|
||||||
|
if action.w {
|
||||||
|
hasWMask = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
allMasks = false
|
||||||
|
if reg, ok := op.(operandVReg); ok {
|
||||||
|
if action.r {
|
||||||
|
rSizes = append(rSizes, reg.vecShape)
|
||||||
|
}
|
||||||
|
if action.w {
|
||||||
|
wSizes = append(wSizes, reg.vecShape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(masks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r {
|
||||||
|
sizes = rSizes
|
||||||
|
if len(sizes) == 0 {
|
||||||
|
sizes = wSizes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if w {
|
||||||
|
sizes = wSizes
|
||||||
|
if len(sizes) == 0 {
|
||||||
|
sizes = rSizes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sizes) == 0 {
|
||||||
|
// If all operands are masks, leave the mask inferrence to the users.
|
||||||
|
if allMasks {
|
||||||
|
for _, i := range masks {
|
||||||
|
m := ops[i].(operandMask)
|
||||||
|
m.allMasks = true
|
||||||
|
ops[i] = m
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot infer mask size: no register operands")
|
||||||
|
}
|
||||||
|
shape, ok := singular(sizes)
|
||||||
|
if !ok {
|
||||||
|
if !hasWMask && len(wSizes) == 1 && len(masks) == 1 {
|
||||||
|
// This pattern looks like predicate mask, so its shape should align with the
|
||||||
|
// output. TODO: verify this is a safe assumption.
|
||||||
|
shape = wSizes[0]
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("cannot infer mask size: multiple register sizes %v", sizes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, i := range masks {
|
||||||
|
m := ops[i].(operandMask)
|
||||||
|
m.vecShape = shape
|
||||||
|
ops[i] = m
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := inferMask(true, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := inferMask(false, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addOperandstoDef adds "in", "inVariant", and "out" to an instruction Def.
|
||||||
|
//
|
||||||
|
// Optional mask input operands are added to the inVariant field if
|
||||||
|
// variant&instVariantMasked, and omitted otherwise.
|
||||||
|
func addOperandsToDef(ops []operand, instDB *unify.DefBuilder, variant instVariant) {
|
||||||
|
var inVals, inVar, outVals []*unify.Value
|
||||||
|
asmPos := 0
|
||||||
|
for _, op := range ops {
|
||||||
|
var db unify.DefBuilder
|
||||||
|
op.addToDef(&db)
|
||||||
|
db.Add("asmPos", unify.NewValue(unify.NewStringExact(fmt.Sprint(asmPos))))
|
||||||
|
|
||||||
|
action := op.common().action
|
||||||
|
asmCount := 1 // # of assembly operands; 0 or 1
|
||||||
|
if action.r {
|
||||||
|
inVal := unify.NewValue(db.Build())
|
||||||
|
// If this is an optional mask, put it in the input variant tuple.
|
||||||
|
if mask, ok := op.(operandMask); ok && mask.optional {
|
||||||
|
if variant&instVariantMasked != 0 {
|
||||||
|
inVar = append(inVar, inVal)
|
||||||
|
} else {
|
||||||
|
// This operand doesn't appear in the assembly at all.
|
||||||
|
asmCount = 0
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Just a regular input operand.
|
||||||
|
inVals = append(inVals, inVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if action.w {
|
||||||
|
outVal := unify.NewValue(db.Build())
|
||||||
|
outVals = append(outVals, outVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
asmPos += asmCount
|
||||||
|
}
|
||||||
|
|
||||||
|
instDB.Add("in", unify.NewValue(unify.NewTuple(inVals...)))
|
||||||
|
instDB.Add("inVariant", unify.NewValue(unify.NewTuple(inVar...)))
|
||||||
|
instDB.Add("out", unify.NewValue(unify.NewTuple(outVals...)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func instToUVal(inst *xeddata.Inst, ops []operand) []*unify.Value {
|
||||||
|
feature, ok := decodeCPUFeature(inst)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var vals []*unify.Value
|
||||||
|
vals = append(vals, instToUVal1(inst, ops, feature, instVariantNone))
|
||||||
|
if hasOptionalMask(ops) {
|
||||||
|
vals = append(vals, instToUVal1(inst, ops, feature, instVariantMasked))
|
||||||
|
}
|
||||||
|
return vals
|
||||||
|
}
|
||||||
|
|
||||||
|
func instToUVal1(inst *xeddata.Inst, ops []operand, feature string, variant instVariant) *unify.Value {
|
||||||
|
var db unify.DefBuilder
|
||||||
|
db.Add("goarch", unify.NewValue(unify.NewStringExact("amd64")))
|
||||||
|
db.Add("asm", unify.NewValue(unify.NewStringExact(inst.Opcode())))
|
||||||
|
addOperandsToDef(ops, &db, variant)
|
||||||
|
db.Add("cpuFeature", unify.NewValue(unify.NewStringExact(feature)))
|
||||||
|
|
||||||
|
if strings.Contains(inst.Pattern, "ZEROING=0") {
|
||||||
|
// This is an EVEX instruction, but the ".Z" (zero-merging)
|
||||||
|
// instruction flag is NOT valid. EVEX.z must be zero.
|
||||||
|
//
|
||||||
|
// This can mean a few things:
|
||||||
|
//
|
||||||
|
// - The output of an instruction is a mask, so merging modes don't
|
||||||
|
// make any sense. E.g., VCMPPS.
|
||||||
|
//
|
||||||
|
// - There are no masks involved anywhere. (Maybe MASK=0 is also set
|
||||||
|
// in this case?) E.g., VINSERTPS.
|
||||||
|
//
|
||||||
|
// - The operation inherently performs merging. E.g., VCOMPRESSPS
|
||||||
|
// with a mem operand.
|
||||||
|
//
|
||||||
|
// There may be other reasons.
|
||||||
|
db.Add("zeroing", unify.NewValue(unify.NewStringExact("false")))
|
||||||
|
}
|
||||||
|
pos := unify.Pos{Path: inst.Pos.Path, Line: inst.Pos.Line}
|
||||||
|
return unify.NewValuePos(db.Build(), pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeCPUFeature returns the CPU feature name required by inst. These match
|
||||||
|
// the names of the "Has*" feature checks in the simd package.
|
||||||
|
func decodeCPUFeature(inst *xeddata.Inst) (string, bool) {
|
||||||
|
key := cpuFeatureKey{
|
||||||
|
Extension: inst.Extension,
|
||||||
|
ISASet: isaSetStrip.ReplaceAllLiteralString(inst.ISASet, ""),
|
||||||
|
}
|
||||||
|
feat, ok := cpuFeatureMap[key]
|
||||||
|
if !ok {
|
||||||
|
imap := unknownFeatures[key]
|
||||||
|
if imap == nil {
|
||||||
|
imap = make(map[string]struct{})
|
||||||
|
unknownFeatures[key] = imap
|
||||||
|
}
|
||||||
|
imap[inst.Opcode()] = struct{}{}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if feat == "ignore" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return feat, true
|
||||||
|
}
|
||||||
|
|
||||||
|
var isaSetStrip = regexp.MustCompile("_(128N?|256N?|512)$")
|
||||||
|
|
||||||
|
type cpuFeatureKey struct {
|
||||||
|
Extension, ISASet string
|
||||||
|
}
|
||||||
|
|
||||||
|
// cpuFeatureMap maps from XED's "EXTENSION" and "ISA_SET" to a CPU feature name
|
||||||
|
// that can be used in the SIMD API.
|
||||||
|
var cpuFeatureMap = map[cpuFeatureKey]string{
|
||||||
|
{"AVX", ""}: "AVX",
|
||||||
|
{"AVX_VNNI", "AVX_VNNI"}: "AVXVNNI",
|
||||||
|
{"AVX2", ""}: "AVX2",
|
||||||
|
|
||||||
|
// AVX-512 foundational features. We combine all of these into one "AVX512" feature.
|
||||||
|
{"AVX512EVEX", "AVX512F"}: "AVX512",
|
||||||
|
{"AVX512EVEX", "AVX512CD"}: "AVX512",
|
||||||
|
{"AVX512EVEX", "AVX512BW"}: "AVX512",
|
||||||
|
{"AVX512EVEX", "AVX512DQ"}: "AVX512",
|
||||||
|
// AVX512VL doesn't appear explicitly in the ISASet. I guess it's implied by
|
||||||
|
// the vector length suffix.
|
||||||
|
|
||||||
|
// AVX-512 extension features
|
||||||
|
{"AVX512EVEX", "AVX512_BITALG"}: "AVX512BITALG",
|
||||||
|
{"AVX512EVEX", "AVX512_GFNI"}: "AVX512GFNI",
|
||||||
|
{"AVX512EVEX", "AVX512_VBMI2"}: "AVX512VBMI2",
|
||||||
|
{"AVX512EVEX", "AVX512_VBMI"}: "AVX512VBMI",
|
||||||
|
{"AVX512EVEX", "AVX512_VNNI"}: "AVX512VNNI",
|
||||||
|
{"AVX512EVEX", "AVX512_VPOPCNTDQ"}: "AVX512VPOPCNTDQ",
|
||||||
|
|
||||||
|
// AVX 10.2 (not yet supported)
|
||||||
|
{"AVX512EVEX", "AVX10_2_RC"}: "ignore",
|
||||||
|
}
|
||||||
|
|
||||||
|
var unknownFeatures = map[cpuFeatureKey]map[string]struct{}{}
|
||||||
|
|
||||||
|
// hasOptionalMask returns whether there is an optional mask operand in ops.
|
||||||
|
func hasOptionalMask(ops []operand) bool {
|
||||||
|
for _, op := range ops {
|
||||||
|
if op, ok := op.(operandMask); ok && op.optional {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func singular[T comparable](xs []T) (T, bool) {
|
||||||
|
if len(xs) == 0 {
|
||||||
|
return *new(T), false
|
||||||
|
}
|
||||||
|
for _, x := range xs[1:] {
|
||||||
|
if x != xs[0] {
|
||||||
|
return *new(T), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return xs[0], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeReg returns class (NOT_REG_CLASS, VREG_CLASS, GREG_CLASS),
|
||||||
|
// and width in bits. If the operand cannot be decided as a register,
|
||||||
|
// then the clas is NOT_REG_CLASS.
|
||||||
|
func decodeReg(op *xeddata.Operand) (class, width int) {
|
||||||
|
// op.Width tells us the total width, e.g.,:
|
||||||
|
//
|
||||||
|
// dq => 128 bits (XMM)
|
||||||
|
// qq => 256 bits (YMM)
|
||||||
|
// mskw => K
|
||||||
|
// z[iuf?](8|16|32|...) => 512 bits (ZMM)
|
||||||
|
//
|
||||||
|
// But the encoding is really weird and it's not clear if these *always*
|
||||||
|
// mean XMM/YMM/ZMM or if other irregular things can use these large widths.
|
||||||
|
// Hence, we dig into the register sets themselves.
|
||||||
|
|
||||||
|
if !strings.HasPrefix(op.NameLHS(), "REG") {
|
||||||
|
return NOT_REG_CLASS, 0
|
||||||
|
}
|
||||||
|
// TODO: We shouldn't be relying on the macro naming conventions. We should
|
||||||
|
// use all-dec-patterns.txt, but xeddata doesn't support that table right now.
|
||||||
|
rhs := op.NameRHS()
|
||||||
|
if !strings.HasSuffix(rhs, "()") {
|
||||||
|
return NOT_REG_CLASS, 0
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(rhs, "XMM_"):
|
||||||
|
return VREG_CLASS, 128
|
||||||
|
case strings.HasPrefix(rhs, "YMM_"):
|
||||||
|
return VREG_CLASS, 256
|
||||||
|
case strings.HasPrefix(rhs, "ZMM_"):
|
||||||
|
return VREG_CLASS, 512
|
||||||
|
case strings.HasPrefix(rhs, "GPR64_"), strings.HasPrefix(rhs, "VGPR64_"):
|
||||||
|
return GREG_CLASS, 64
|
||||||
|
case strings.HasPrefix(rhs, "GPR32_"), strings.HasPrefix(rhs, "VGPR32_"):
|
||||||
|
return GREG_CLASS, 32
|
||||||
|
}
|
||||||
|
return NOT_REG_CLASS, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var xtypeRe = regexp.MustCompile(`^([iuf])([0-9]+)$`)
|
||||||
|
|
||||||
|
// scalarBaseType describes the base type of a scalar element. This is a Go
|
||||||
|
// type, but without the bit width suffix (with the exception of
|
||||||
|
// scalarBaseIntOrUint).
|
||||||
|
type scalarBaseType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
scalarBaseInt scalarBaseType = iota
|
||||||
|
scalarBaseUint
|
||||||
|
scalarBaseIntOrUint // Signed or unsigned is unspecified
|
||||||
|
scalarBaseFloat
|
||||||
|
scalarBaseComplex
|
||||||
|
scalarBaseBFloat
|
||||||
|
scalarBaseHFloat
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s scalarBaseType) regex() string {
|
||||||
|
switch s {
|
||||||
|
case scalarBaseInt:
|
||||||
|
return "int"
|
||||||
|
case scalarBaseUint:
|
||||||
|
return "uint"
|
||||||
|
case scalarBaseIntOrUint:
|
||||||
|
return "int|uint"
|
||||||
|
case scalarBaseFloat:
|
||||||
|
return "float"
|
||||||
|
case scalarBaseComplex:
|
||||||
|
return "complex"
|
||||||
|
case scalarBaseBFloat:
|
||||||
|
return "BFloat"
|
||||||
|
case scalarBaseHFloat:
|
||||||
|
return "HFloat"
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("unknown scalar base type %d", s))
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeType(op *xeddata.Operand) (base scalarBaseType, bits int, ok bool) {
|
||||||
|
// The xtype tells you the element type. i8, i16, i32, i64, f32, etc.
|
||||||
|
//
|
||||||
|
// TODO: Things like AVX2 VPAND have an xtype of u256 because they're
|
||||||
|
// element-width agnostic. Do I map that to all widths, or just omit the
|
||||||
|
// element width and let unification flesh it out? There's no u512
|
||||||
|
// (presumably those are all masked, so elem width matters). These are all
|
||||||
|
// Category: LOGICAL, so maybe we could use that info?
|
||||||
|
|
||||||
|
// Handle some weird ones.
|
||||||
|
switch op.Xtype {
|
||||||
|
// 8-bit float formats as defined by Open Compute Project "OCP 8-bit
|
||||||
|
// Floating Point Specification (OFP8)".
|
||||||
|
case "bf8": // E5M2 float
|
||||||
|
return scalarBaseBFloat, 8, true
|
||||||
|
case "hf8": // E4M3 float
|
||||||
|
return scalarBaseHFloat, 8, true
|
||||||
|
case "bf16": // bfloat16 float
|
||||||
|
return scalarBaseBFloat, 16, true
|
||||||
|
case "2f16":
|
||||||
|
// Complex consisting of 2 float16s. Doesn't exist in Go, but we can say
|
||||||
|
// what it would be.
|
||||||
|
return scalarBaseComplex, 32, true
|
||||||
|
case "2i8", "2I8":
|
||||||
|
// These just use the lower INT8 in each 16 bit field.
|
||||||
|
// As far as I can tell, "2I8" is a typo.
|
||||||
|
return scalarBaseInt, 8, true
|
||||||
|
case "2u16", "2U16":
|
||||||
|
// some VPDP* has it
|
||||||
|
// TODO: does "z" means it has zeroing?
|
||||||
|
return scalarBaseUint, 16, true
|
||||||
|
case "2i16", "2I16":
|
||||||
|
// some VPDP* has it
|
||||||
|
return scalarBaseInt, 16, true
|
||||||
|
case "4u8", "4U8":
|
||||||
|
// some VPDP* has it
|
||||||
|
return scalarBaseUint, 8, true
|
||||||
|
case "4i8", "4I8":
|
||||||
|
// some VPDP* has it
|
||||||
|
return scalarBaseInt, 8, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// The rest follow a simple pattern.
|
||||||
|
m := xtypeRe.FindStringSubmatch(op.Xtype)
|
||||||
|
if m == nil {
|
||||||
|
// TODO: Report unrecognized xtype
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
bits, _ = strconv.Atoi(m[2])
|
||||||
|
switch m[1] {
|
||||||
|
case "i", "u":
|
||||||
|
// XED is rather inconsistent about what's signed, unsigned, or doesn't
|
||||||
|
// matter, so merge them together and let the Go definitions narrow as
|
||||||
|
// appropriate. Maybe there's a better way to do this.
|
||||||
|
return scalarBaseIntOrUint, bits, true
|
||||||
|
case "f":
|
||||||
|
return scalarBaseFloat, bits, true
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
154
src/simd/_gen/unify/closure.go
Normal file
154
src/simd/_gen/unify/closure.go
Normal file
|
|
@ -0,0 +1,154 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Closure struct {
|
||||||
|
val *Value
|
||||||
|
env envSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSum(vs ...*Value) Closure {
|
||||||
|
id := &ident{name: "sum"}
|
||||||
|
return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBottom returns whether c consists of no values.
|
||||||
|
func (c Closure) IsBottom() bool {
|
||||||
|
return c.val.Domain == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summands returns the top-level Values of c. This assumes the top-level of c
|
||||||
|
// was constructed as a sum, and is mostly useful for debugging.
|
||||||
|
func (c Closure) Summands() iter.Seq[*Value] {
|
||||||
|
return func(yield func(*Value) bool) {
|
||||||
|
var rec func(v *Value, env envSet) bool
|
||||||
|
rec = func(v *Value, env envSet) bool {
|
||||||
|
switch d := v.Domain.(type) {
|
||||||
|
case Var:
|
||||||
|
parts := env.partitionBy(d.id)
|
||||||
|
for _, part := range parts {
|
||||||
|
// It may be a sum of sums. Walk into this value.
|
||||||
|
if !rec(part.value, part.env) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return yield(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rec(c.val, c.env)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All enumerates all possible concrete values of c by substituting variables
|
||||||
|
// from the environment.
|
||||||
|
//
|
||||||
|
// E.g., enumerating this Value
|
||||||
|
//
|
||||||
|
// a: !sum [1, 2]
|
||||||
|
// b: !sum [3, 4]
|
||||||
|
//
|
||||||
|
// results in
|
||||||
|
//
|
||||||
|
// - {a: 1, b: 3}
|
||||||
|
// - {a: 1, b: 4}
|
||||||
|
// - {a: 2, b: 3}
|
||||||
|
// - {a: 2, b: 4}
|
||||||
|
func (c Closure) All() iter.Seq[*Value] {
|
||||||
|
// In order to enumerate all concrete values under all possible variable
|
||||||
|
// bindings, we use a "non-deterministic continuation passing style" to
|
||||||
|
// implement this. We use CPS to traverse the Value tree, threading the
|
||||||
|
// (possibly narrowing) environment through that CPS following an Euler
|
||||||
|
// tour. Where the environment permits multiple choices, we invoke the same
|
||||||
|
// continuation for each choice. Similar to a yield function, the
|
||||||
|
// continuation can return false to stop the non-deterministic walk.
|
||||||
|
return func(yield func(*Value) bool) {
|
||||||
|
c.val.all1(c.env, func(v *Value, e envSet) bool {
|
||||||
|
return yield(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) all1(e envSet, cont func(*Value, envSet) bool) bool {
|
||||||
|
switch d := v.Domain.(type) {
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown domain type %T", d))
|
||||||
|
|
||||||
|
case nil:
|
||||||
|
return true
|
||||||
|
|
||||||
|
case Top, String:
|
||||||
|
return cont(v, e)
|
||||||
|
|
||||||
|
case Def:
|
||||||
|
fields := d.keys()
|
||||||
|
// We can reuse this parts slice because we're doing a DFS through the
|
||||||
|
// state space. (Otherwise, we'd have to do some messy threading of an
|
||||||
|
// immutable slice-like value through allElt.)
|
||||||
|
parts := make(map[string]*Value, len(fields))
|
||||||
|
|
||||||
|
// TODO: If there are no Vars or Sums under this Def, then nothing can
|
||||||
|
// change the Value or env, so we could just cont(v, e).
|
||||||
|
var allElt func(elt int, e envSet) bool
|
||||||
|
allElt = func(elt int, e envSet) bool {
|
||||||
|
if elt == len(fields) {
|
||||||
|
// Build a new Def from the concrete parts. Clone parts because
|
||||||
|
// we may reuse it on other non-deterministic branches.
|
||||||
|
nVal := newValueFrom(Def{maps.Clone(parts)}, v)
|
||||||
|
return cont(nVal, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.fields[fields[elt]].all1(e, func(v *Value, e envSet) bool {
|
||||||
|
parts[fields[elt]] = v
|
||||||
|
return allElt(elt+1, e)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return allElt(0, e)
|
||||||
|
|
||||||
|
case Tuple:
|
||||||
|
// Essentially the same as Def.
|
||||||
|
if d.repeat != nil {
|
||||||
|
// There's nothing we can do with this.
|
||||||
|
return cont(v, e)
|
||||||
|
}
|
||||||
|
parts := make([]*Value, len(d.vs))
|
||||||
|
var allElt func(elt int, e envSet) bool
|
||||||
|
allElt = func(elt int, e envSet) bool {
|
||||||
|
if elt == len(d.vs) {
|
||||||
|
// Build a new tuple from the concrete parts. Clone parts because
|
||||||
|
// we may reuse it on other non-deterministic branches.
|
||||||
|
nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v)
|
||||||
|
return cont(nVal, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.vs[elt].all1(e, func(v *Value, e envSet) bool {
|
||||||
|
parts[elt] = v
|
||||||
|
return allElt(elt+1, e)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return allElt(0, e)
|
||||||
|
|
||||||
|
case Var:
|
||||||
|
// Go each way this variable can be bound.
|
||||||
|
for _, ePart := range e.partitionBy(d.id) {
|
||||||
|
// d.id is no longer bound in this environment partition. We'll may
|
||||||
|
// need it later in the Euler tour, so bind it back to this single
|
||||||
|
// value.
|
||||||
|
env := ePart.env.bind(d.id, ePart.value)
|
||||||
|
if !ePart.value.all1(env, cont) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
359
src/simd/_gen/unify/domain.go
Normal file
359
src/simd/_gen/unify/domain.go
Normal file
|
|
@ -0,0 +1,359 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"maps"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Domain is a non-empty set of values, all of the same kind.
|
||||||
|
//
|
||||||
|
// Domain may be a scalar:
|
||||||
|
//
|
||||||
|
// - [String] - Represents string-typed values.
|
||||||
|
//
|
||||||
|
// Or a composite:
|
||||||
|
//
|
||||||
|
// - [Def] - A mapping from fixed keys to [Domain]s.
|
||||||
|
//
|
||||||
|
// - [Tuple] - A fixed-length sequence of [Domain]s or
|
||||||
|
// all possible lengths repeating a [Domain].
|
||||||
|
//
|
||||||
|
// Or top or bottom:
|
||||||
|
//
|
||||||
|
// - [Top] - Represents all possible values of all kinds.
|
||||||
|
//
|
||||||
|
// - nil - Represents no values.
|
||||||
|
//
|
||||||
|
// Or a variable:
|
||||||
|
//
|
||||||
|
// - [Var] - A value captured in the environment.
|
||||||
|
type Domain interface {
|
||||||
|
Exact() bool
|
||||||
|
WhyNotExact() string
|
||||||
|
|
||||||
|
// decode stores this value in a Go value. If this value is not exact, this
|
||||||
|
// returns a potentially wrapped *inexactError.
|
||||||
|
decode(reflect.Value) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type inexactError struct {
|
||||||
|
valueType string
|
||||||
|
goType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *inexactError) Error() string {
|
||||||
|
return fmt.Sprintf("cannot store inexact %s value in %s", e.valueType, e.goType)
|
||||||
|
}
|
||||||
|
|
||||||
|
type decodeError struct {
|
||||||
|
path string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDecodeError(path string, err error) *decodeError {
|
||||||
|
if err, ok := err.(*decodeError); ok {
|
||||||
|
return &decodeError{path: path + "." + err.path, err: err.err}
|
||||||
|
}
|
||||||
|
return &decodeError{path: path, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *decodeError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *decodeError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %s", e.path, e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top represents all possible values of all possible types.
|
||||||
|
type Top struct{}
|
||||||
|
|
||||||
|
func (t Top) Exact() bool { return false }
|
||||||
|
func (t Top) WhyNotExact() string { return "is top" }
|
||||||
|
|
||||||
|
func (t Top) decode(rv reflect.Value) error {
|
||||||
|
// We can decode Top into a pointer-typed value as nil.
|
||||||
|
if rv.Kind() != reflect.Pointer {
|
||||||
|
return &inexactError{"top", rv.Type().String()}
|
||||||
|
}
|
||||||
|
rv.SetZero()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Def is a mapping from field names to [Value]s. Any fields not explicitly
|
||||||
|
// listed have [Value] [Top].
|
||||||
|
type Def struct {
|
||||||
|
fields map[string]*Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// A DefBuilder builds a [Def] one field at a time. The zero value is an empty
|
||||||
|
// [Def].
|
||||||
|
type DefBuilder struct {
|
||||||
|
fields map[string]*Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *DefBuilder) Add(name string, v *Value) {
|
||||||
|
if b.fields == nil {
|
||||||
|
b.fields = make(map[string]*Value)
|
||||||
|
}
|
||||||
|
if _, ok := b.fields[name]; ok {
|
||||||
|
panic(fmt.Sprintf("duplicate field %q", name))
|
||||||
|
}
|
||||||
|
b.fields[name] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build constructs a [Def] from the fields added to this builder.
|
||||||
|
func (b *DefBuilder) Build() Def {
|
||||||
|
return Def{maps.Clone(b.fields)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact returns true if all field Values are exact.
|
||||||
|
func (d Def) Exact() bool {
|
||||||
|
for _, v := range d.fields {
|
||||||
|
if !v.Exact() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// WhyNotExact returns why the value is not exact
|
||||||
|
func (d Def) WhyNotExact() string {
|
||||||
|
for s, v := range d.fields {
|
||||||
|
if !v.Exact() {
|
||||||
|
w := v.WhyNotExact()
|
||||||
|
return "field " + s + ": " + w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Def) decode(rv reflect.Value) error {
|
||||||
|
if rv.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("cannot decode Def into %s", rv.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
var lowered map[string]string // Lower case -> canonical for d.fields.
|
||||||
|
rt := rv.Type()
|
||||||
|
for fi := range rv.NumField() {
|
||||||
|
fType := rt.Field(fi)
|
||||||
|
if fType.PkgPath != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
v := d.fields[fType.Name]
|
||||||
|
if v == nil {
|
||||||
|
v = topValue
|
||||||
|
|
||||||
|
// Try a case-insensitive match
|
||||||
|
canon, ok := d.fields[strings.ToLower(fType.Name)]
|
||||||
|
if ok {
|
||||||
|
v = canon
|
||||||
|
} else {
|
||||||
|
if lowered == nil {
|
||||||
|
lowered = make(map[string]string, len(d.fields))
|
||||||
|
for k := range d.fields {
|
||||||
|
l := strings.ToLower(k)
|
||||||
|
if k != l {
|
||||||
|
lowered[l] = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
canon, ok := lowered[strings.ToLower(fType.Name)]
|
||||||
|
if ok {
|
||||||
|
v = d.fields[canon]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := decodeReflect(v, rv.Field(fi)); err != nil {
|
||||||
|
return newDecodeError(fType.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Def) keys() []string {
|
||||||
|
return slices.Sorted(maps.Keys(d.fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Def) All() iter.Seq2[string, *Value] {
|
||||||
|
// TODO: We call All fairly often. It's probably bad to sort this every
|
||||||
|
// time.
|
||||||
|
keys := slices.Sorted(maps.Keys(d.fields))
|
||||||
|
return func(yield func(string, *Value) bool) {
|
||||||
|
for _, k := range keys {
|
||||||
|
if !yield(k, d.fields[k]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Tuple is a sequence of Values in one of two forms: 1. a fixed-length tuple,
|
||||||
|
// where each Value can be different or 2. a "repeated tuple", which is a Value
|
||||||
|
// repeated 0 or more times.
|
||||||
|
type Tuple struct {
|
||||||
|
vs []*Value
|
||||||
|
|
||||||
|
// repeat, if non-nil, means this Tuple consists of an element repeated 0 or
|
||||||
|
// more times. If repeat is non-nil, vs must be nil. This is a generator
|
||||||
|
// function because we don't necessarily want *exactly* the same Value
|
||||||
|
// repeated. For example, in YAML encoding, a !sum in a repeated tuple needs
|
||||||
|
// a fresh variable in each instance.
|
||||||
|
repeat []func(envSet) (*Value, envSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTuple(vs ...*Value) Tuple {
|
||||||
|
return Tuple{vs: vs}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRepeat(gens ...func(envSet) (*Value, envSet)) Tuple {
|
||||||
|
return Tuple{repeat: gens}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Tuple) Exact() bool {
|
||||||
|
if d.repeat != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, v := range d.vs {
|
||||||
|
if !v.Exact() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Tuple) WhyNotExact() string {
|
||||||
|
if d.repeat != nil {
|
||||||
|
return "d.repeat is not nil"
|
||||||
|
}
|
||||||
|
for i, v := range d.vs {
|
||||||
|
if !v.Exact() {
|
||||||
|
w := v.WhyNotExact()
|
||||||
|
return "index " + strconv.FormatInt(int64(i), 10) + ": " + w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Tuple) decode(rv reflect.Value) error {
|
||||||
|
if d.repeat != nil {
|
||||||
|
return &inexactError{"repeated tuple", rv.Type().String()}
|
||||||
|
}
|
||||||
|
// TODO: We could also do arrays.
|
||||||
|
if rv.Kind() != reflect.Slice {
|
||||||
|
return fmt.Errorf("cannot decode Tuple into %s", rv.Type())
|
||||||
|
}
|
||||||
|
if rv.IsNil() || rv.Cap() < len(d.vs) {
|
||||||
|
rv.Set(reflect.MakeSlice(rv.Type(), len(d.vs), len(d.vs)))
|
||||||
|
} else {
|
||||||
|
rv.SetLen(len(d.vs))
|
||||||
|
}
|
||||||
|
for i, v := range d.vs {
|
||||||
|
if err := decodeReflect(v, rv.Index(i)); err != nil {
|
||||||
|
return newDecodeError(fmt.Sprintf("%d", i), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A String represents a set of strings. It can represent the intersection of a
|
||||||
|
// set of regexps, or a single exact string. In general, the domain of a String
|
||||||
|
// is non-empty, but we do not attempt to prove emptiness of a regexp value.
|
||||||
|
type String struct {
|
||||||
|
kind stringKind
|
||||||
|
re []*regexp.Regexp // Intersection of regexps
|
||||||
|
exact string
|
||||||
|
}
|
||||||
|
|
||||||
|
type stringKind int
|
||||||
|
|
||||||
|
const (
|
||||||
|
stringRegex stringKind = iota
|
||||||
|
stringExact
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewStringRegex(exprs ...string) (String, error) {
|
||||||
|
if len(exprs) == 0 {
|
||||||
|
exprs = []string{""}
|
||||||
|
}
|
||||||
|
v := String{kind: -1}
|
||||||
|
for _, expr := range exprs {
|
||||||
|
if expr == "" {
|
||||||
|
// Skip constructing the regexp. It won't have a "literal prefix"
|
||||||
|
// and so we wind up thinking this is a regexp instead of an exact
|
||||||
|
// (empty) string.
|
||||||
|
v = String{kind: stringExact, exact: ""}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
re, err := regexp.Compile(`\A(?:` + expr + `)\z`)
|
||||||
|
if err != nil {
|
||||||
|
return String{}, fmt.Errorf("parsing value: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// An exact value narrows the whole domain to exact, so we're done, but
|
||||||
|
// should keep parsing.
|
||||||
|
if v.kind == stringExact {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if exact, complete := re.LiteralPrefix(); complete {
|
||||||
|
v = String{kind: stringExact, exact: exact}
|
||||||
|
} else {
|
||||||
|
v.kind = stringRegex
|
||||||
|
v.re = append(v.re, re)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStringExact(s string) String {
|
||||||
|
return String{kind: stringExact, exact: s}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact returns whether this Value is known to consist of a single string.
|
||||||
|
func (d String) Exact() bool {
|
||||||
|
return d.kind == stringExact
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d String) WhyNotExact() string {
|
||||||
|
if d.kind == stringExact {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "string is not exact"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d String) decode(rv reflect.Value) error {
|
||||||
|
if d.kind != stringExact {
|
||||||
|
return &inexactError{"regex", rv.Type().String()}
|
||||||
|
}
|
||||||
|
switch rv.Kind() {
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot decode String into %s", rv.Type())
|
||||||
|
case reflect.String:
|
||||||
|
rv.SetString(d.exact)
|
||||||
|
case reflect.Int:
|
||||||
|
i, err := strconv.Atoi(d.exact)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot decode String into %s: %s", rv.Type(), err)
|
||||||
|
}
|
||||||
|
rv.SetInt(int64(i))
|
||||||
|
case reflect.Bool:
|
||||||
|
b, err := strconv.ParseBool(d.exact)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot decode String into %s: %s", rv.Type(), err)
|
||||||
|
}
|
||||||
|
rv.SetBool(b)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
221
src/simd/_gen/unify/dot.go
Normal file
221
src/simd/_gen/unify/dot.go
Normal file
|
|
@ -0,0 +1,221 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxNodes = 30
|
||||||
|
|
||||||
|
type dotEncoder struct {
|
||||||
|
w *bytes.Buffer
|
||||||
|
|
||||||
|
idGen int // Node name generation
|
||||||
|
valLimit int // Limit the number of Values in a subgraph
|
||||||
|
|
||||||
|
idp identPrinter
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDotEncoder() *dotEncoder {
|
||||||
|
return &dotEncoder{
|
||||||
|
w: new(bytes.Buffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) clear() {
|
||||||
|
enc.w.Reset()
|
||||||
|
enc.idGen = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) writeTo(w io.Writer) {
|
||||||
|
fmt.Fprintln(w, "digraph {")
|
||||||
|
// Use the "new" ranking algorithm, which lets us put nodes from different
|
||||||
|
// clusters in the same rank.
|
||||||
|
fmt.Fprintln(w, "newrank=true;")
|
||||||
|
fmt.Fprintln(w, "node [shape=box, ordering=out];")
|
||||||
|
|
||||||
|
w.Write(enc.w.Bytes())
|
||||||
|
fmt.Fprintln(w, "}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) writeSvg(w io.Writer) error {
|
||||||
|
cmd := exec.Command("dot", "-Tsvg")
|
||||||
|
in, err := cmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var out bytes.Buffer
|
||||||
|
cmd.Stdout = &out
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
enc.writeTo(in)
|
||||||
|
in.Close()
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Trim SVG header so the result can be embedded
|
||||||
|
//
|
||||||
|
// TODO: In Graphviz 10.0.1, we could use -Tsvg_inline.
|
||||||
|
svg := out.Bytes()
|
||||||
|
if i := bytes.Index(svg, []byte("<svg ")); i >= 0 {
|
||||||
|
svg = svg[i:]
|
||||||
|
}
|
||||||
|
_, err = w.Write(svg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) newID(f string) string {
|
||||||
|
id := fmt.Sprintf(f, enc.idGen)
|
||||||
|
enc.idGen++
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) node(label, sublabel string) string {
|
||||||
|
id := enc.newID("n%d")
|
||||||
|
l := html.EscapeString(label)
|
||||||
|
if sublabel != "" {
|
||||||
|
l += fmt.Sprintf("<BR ALIGN=\"CENTER\"/><FONT POINT-SIZE=\"10\">%s</FONT>", html.EscapeString(sublabel))
|
||||||
|
}
|
||||||
|
fmt.Fprintf(enc.w, "%s [label=<%s>];\n", id, l)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) edge(from, to string, label string, args ...any) {
|
||||||
|
l := fmt.Sprintf(label, args...)
|
||||||
|
fmt.Fprintf(enc.w, "%s -> %s [label=%q];\n", from, to, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) valueSubgraph(v *Value) {
|
||||||
|
enc.valLimit = maxNodes
|
||||||
|
cID := enc.newID("cluster_%d")
|
||||||
|
fmt.Fprintf(enc.w, "subgraph %s {\n", cID)
|
||||||
|
fmt.Fprintf(enc.w, "style=invis;")
|
||||||
|
vID := enc.value(v)
|
||||||
|
fmt.Fprintf(enc.w, "}\n")
|
||||||
|
// We don't need the IDs right now.
|
||||||
|
_, _ = cID, vID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) value(v *Value) string {
|
||||||
|
if enc.valLimit <= 0 {
|
||||||
|
id := enc.newID("n%d")
|
||||||
|
fmt.Fprintf(enc.w, "%s [label=\"...\", shape=triangle];\n", id)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
enc.valLimit--
|
||||||
|
|
||||||
|
switch vd := v.Domain.(type) {
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown domain type %T", vd))
|
||||||
|
|
||||||
|
case nil:
|
||||||
|
return enc.node("_|_", "")
|
||||||
|
|
||||||
|
case Top:
|
||||||
|
return enc.node("_", "")
|
||||||
|
|
||||||
|
// TODO: Like in YAML, figure out if this is just a sum. In dot, we
|
||||||
|
// could say any unentangled variable is a sum, and if it has more than
|
||||||
|
// one reference just share the node.
|
||||||
|
|
||||||
|
// case Sum:
|
||||||
|
// node := enc.node("Sum", "")
|
||||||
|
// for i, elt := range vd.vs {
|
||||||
|
// enc.edge(node, enc.value(elt), "%d", i)
|
||||||
|
// if enc.valLimit <= 0 {
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return node
|
||||||
|
|
||||||
|
case Def:
|
||||||
|
node := enc.node("Def", "")
|
||||||
|
for k, v := range vd.All() {
|
||||||
|
enc.edge(node, enc.value(v), "%s", k)
|
||||||
|
if enc.valLimit <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
|
||||||
|
case Tuple:
|
||||||
|
if vd.repeat == nil {
|
||||||
|
label := "Tuple"
|
||||||
|
node := enc.node(label, "")
|
||||||
|
for i, elt := range vd.vs {
|
||||||
|
enc.edge(node, enc.value(elt), "%d", i)
|
||||||
|
if enc.valLimit <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
} else {
|
||||||
|
// TODO
|
||||||
|
return enc.node("TODO: Repeat", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
case String:
|
||||||
|
switch vd.kind {
|
||||||
|
case stringExact:
|
||||||
|
return enc.node(fmt.Sprintf("%q", vd.exact), "")
|
||||||
|
case stringRegex:
|
||||||
|
var parts []string
|
||||||
|
for _, re := range vd.re {
|
||||||
|
parts = append(parts, fmt.Sprintf("%q", re))
|
||||||
|
}
|
||||||
|
return enc.node(strings.Join(parts, "&"), "")
|
||||||
|
}
|
||||||
|
panic("bad String kind")
|
||||||
|
|
||||||
|
case Var:
|
||||||
|
return enc.node(fmt.Sprintf("Var %s", enc.idp.unique(vd.id)), "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) envSubgraph(e envSet) {
|
||||||
|
enc.valLimit = maxNodes
|
||||||
|
cID := enc.newID("cluster_%d")
|
||||||
|
fmt.Fprintf(enc.w, "subgraph %s {\n", cID)
|
||||||
|
fmt.Fprintf(enc.w, "style=invis;")
|
||||||
|
vID := enc.env(e.root)
|
||||||
|
fmt.Fprintf(enc.w, "}\n")
|
||||||
|
_, _ = cID, vID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *dotEncoder) env(e *envExpr) string {
|
||||||
|
switch e.kind {
|
||||||
|
default:
|
||||||
|
panic("bad kind")
|
||||||
|
case envZero:
|
||||||
|
return enc.node("0", "")
|
||||||
|
case envUnit:
|
||||||
|
return enc.node("1", "")
|
||||||
|
case envBinding:
|
||||||
|
node := enc.node(fmt.Sprintf("%q :", enc.idp.unique(e.id)), "")
|
||||||
|
enc.edge(node, enc.value(e.val), "")
|
||||||
|
return node
|
||||||
|
case envProduct:
|
||||||
|
node := enc.node("⨯", "")
|
||||||
|
for _, op := range e.operands {
|
||||||
|
enc.edge(node, enc.env(op), "")
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
case envSum:
|
||||||
|
node := enc.node("+", "")
|
||||||
|
for _, op := range e.operands {
|
||||||
|
enc.edge(node, enc.env(op), "")
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
}
|
||||||
480
src/simd/_gen/unify/env.go
Normal file
480
src/simd/_gen/unify/env.go
Normal file
|
|
@ -0,0 +1,480 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An envSet is an immutable set of environments, where each environment is a
|
||||||
|
// mapping from [ident]s to [Value]s.
|
||||||
|
//
|
||||||
|
// To keep this compact, we use an algebraic representation similar to
|
||||||
|
// relational algebra. The atoms are zero, unit, or a singular binding:
|
||||||
|
//
|
||||||
|
// - A singular binding is an environment set consisting of a single environment
|
||||||
|
// that binds a single ident to a single value.
|
||||||
|
//
|
||||||
|
// - Zero is the empty set.
|
||||||
|
//
|
||||||
|
// - Unit is an environment set consisting of a single, empty environment (no
|
||||||
|
// bindings).
|
||||||
|
//
|
||||||
|
// From these, we build up more complex sets of environments using sums and
|
||||||
|
// cross products:
|
||||||
|
//
|
||||||
|
// - A sum is simply the union of the two environment sets.
|
||||||
|
//
|
||||||
|
// - A cross product is the Cartesian product of the two environment sets,
|
||||||
|
// followed by combining each pair of environments. Combining simply merges the
|
||||||
|
// two mappings, but fails if the mappings overlap.
|
||||||
|
//
|
||||||
|
// For example, to represent {{x: 1, y: 1}, {x: 2, y: 2}}, we build the two
|
||||||
|
// environments and sum them:
|
||||||
|
//
|
||||||
|
// ({x: 1} ⨯ {y: 1}) + ({x: 2} ⨯ {y: 2})
|
||||||
|
//
|
||||||
|
// If we add a third variable z that can be 1 or 2, independent of x and y, we
|
||||||
|
// get four logical environments:
|
||||||
|
//
|
||||||
|
// {x: 1, y: 1, z: 1}
|
||||||
|
// {x: 2, y: 2, z: 1}
|
||||||
|
// {x: 1, y: 1, z: 2}
|
||||||
|
// {x: 2, y: 2, z: 2}
|
||||||
|
//
|
||||||
|
// This could be represented as a sum of all four environments, but because z is
|
||||||
|
// independent, we can use a more compact representation:
|
||||||
|
//
|
||||||
|
// (({x: 1} ⨯ {y: 1}) + ({x: 2} ⨯ {y: 2})) ⨯ ({z: 1} + {z: 2})
|
||||||
|
//
|
||||||
|
// Environment sets obey commutative algebra rules:
|
||||||
|
//
|
||||||
|
// e + 0 = e
|
||||||
|
// e ⨯ 0 = 0
|
||||||
|
// e ⨯ 1 = e
|
||||||
|
// e + f = f + e
|
||||||
|
// e ⨯ f = f ⨯ e
|
||||||
|
type envSet struct {
|
||||||
|
root *envExpr
|
||||||
|
}
|
||||||
|
|
||||||
|
type envExpr struct {
|
||||||
|
// TODO: A tree-based data structure for this may not be ideal, since it
|
||||||
|
// involves a lot of walking to find things and we often have to do deep
|
||||||
|
// rewrites anyway for partitioning. Would some flattened array-style
|
||||||
|
// representation be better, possibly combined with an index of ident uses?
|
||||||
|
// We could even combine that with an immutable array abstraction (ala
|
||||||
|
// Clojure) that could enable more efficient construction operations.
|
||||||
|
|
||||||
|
kind envExprKind
|
||||||
|
|
||||||
|
// For envBinding
|
||||||
|
id *ident
|
||||||
|
val *Value
|
||||||
|
|
||||||
|
// For sum or product. Len must be >= 2 and none of the elements can have
|
||||||
|
// the same kind as this node.
|
||||||
|
operands []*envExpr
|
||||||
|
}
|
||||||
|
|
||||||
|
type envExprKind byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
envZero envExprKind = iota
|
||||||
|
envUnit
|
||||||
|
envProduct
|
||||||
|
envSum
|
||||||
|
envBinding
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// topEnv is the unit value (multiplicative identity) of a [envSet].
|
||||||
|
topEnv = envSet{envExprUnit}
|
||||||
|
// bottomEnv is the zero value (additive identity) of a [envSet].
|
||||||
|
bottomEnv = envSet{envExprZero}
|
||||||
|
|
||||||
|
envExprZero = &envExpr{kind: envZero}
|
||||||
|
envExprUnit = &envExpr{kind: envUnit}
|
||||||
|
)
|
||||||
|
|
||||||
|
// bind binds id to each of vals in e.
|
||||||
|
//
|
||||||
|
// Its panics if id is already bound in e.
|
||||||
|
//
|
||||||
|
// Environments are typically initially constructed by starting with [topEnv]
|
||||||
|
// and calling bind one or more times.
|
||||||
|
func (e envSet) bind(id *ident, vals ...*Value) envSet {
|
||||||
|
if e.isEmpty() {
|
||||||
|
return bottomEnv
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: If any of vals are _, should we just drop that val? We're kind of
|
||||||
|
// inconsistent about whether an id missing from e means id is invalid or
|
||||||
|
// means id is _.
|
||||||
|
|
||||||
|
// Check that id isn't present in e.
|
||||||
|
for range e.root.bindings(id) {
|
||||||
|
panic("id " + id.name + " already present in environment")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a sum of all the values.
|
||||||
|
bindings := make([]*envExpr, 0, 1)
|
||||||
|
for _, val := range vals {
|
||||||
|
bindings = append(bindings, &envExpr{kind: envBinding, id: id, val: val})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply it in.
|
||||||
|
return envSet{newEnvExprProduct(e.root, newEnvExprSum(bindings...))}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e envSet) isEmpty() bool {
|
||||||
|
return e.root.kind == envZero
|
||||||
|
}
|
||||||
|
|
||||||
|
// bindings yields all [envBinding] nodes in e with the given id. If id is nil,
|
||||||
|
// it yields all binding nodes.
|
||||||
|
func (e *envExpr) bindings(id *ident) iter.Seq[*envExpr] {
|
||||||
|
// This is just a pre-order walk and it happens this is the only thing we
|
||||||
|
// need a pre-order walk for.
|
||||||
|
return func(yield func(*envExpr) bool) {
|
||||||
|
var rec func(e *envExpr) bool
|
||||||
|
rec = func(e *envExpr) bool {
|
||||||
|
if e.kind == envBinding && (id == nil || e.id == id) {
|
||||||
|
if !yield(e) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, o := range e.operands {
|
||||||
|
if !rec(o) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
rec(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newEnvExprProduct constructs a product node from exprs, performing
|
||||||
|
// simplifications. It does NOT check that bindings are disjoint.
|
||||||
|
func newEnvExprProduct(exprs ...*envExpr) *envExpr {
|
||||||
|
factors := make([]*envExpr, 0, 2)
|
||||||
|
for _, expr := range exprs {
|
||||||
|
switch expr.kind {
|
||||||
|
case envZero:
|
||||||
|
return envExprZero
|
||||||
|
case envUnit:
|
||||||
|
// No effect on product
|
||||||
|
case envProduct:
|
||||||
|
factors = append(factors, expr.operands...)
|
||||||
|
default:
|
||||||
|
factors = append(factors, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(factors) == 0 {
|
||||||
|
return envExprUnit
|
||||||
|
} else if len(factors) == 1 {
|
||||||
|
return factors[0]
|
||||||
|
}
|
||||||
|
return &envExpr{kind: envProduct, operands: factors}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newEnvExprSum constructs a sum node from exprs, performing simplifications.
|
||||||
|
func newEnvExprSum(exprs ...*envExpr) *envExpr {
|
||||||
|
// TODO: If all of envs are products (or bindings), factor any common terms.
|
||||||
|
// E.g., x * y + x * z ==> x * (y + z). This is easy to do for binding
|
||||||
|
// terms, but harder to do for more general terms.
|
||||||
|
|
||||||
|
var have smallSet[*envExpr]
|
||||||
|
terms := make([]*envExpr, 0, 2)
|
||||||
|
for _, expr := range exprs {
|
||||||
|
switch expr.kind {
|
||||||
|
case envZero:
|
||||||
|
// No effect on sum
|
||||||
|
case envSum:
|
||||||
|
for _, expr1 := range expr.operands {
|
||||||
|
if have.Add(expr1) {
|
||||||
|
terms = append(terms, expr1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if have.Add(expr) {
|
||||||
|
terms = append(terms, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(terms) == 0 {
|
||||||
|
return envExprZero
|
||||||
|
} else if len(terms) == 1 {
|
||||||
|
return terms[0]
|
||||||
|
}
|
||||||
|
return &envExpr{kind: envSum, operands: terms}
|
||||||
|
}
|
||||||
|
|
||||||
|
func crossEnvs(env1, env2 envSet) envSet {
|
||||||
|
// Confirm that envs have disjoint idents.
|
||||||
|
var ids1 smallSet[*ident]
|
||||||
|
for e := range env1.root.bindings(nil) {
|
||||||
|
ids1.Add(e.id)
|
||||||
|
}
|
||||||
|
for e := range env2.root.bindings(nil) {
|
||||||
|
if ids1.Has(e.id) {
|
||||||
|
panic(fmt.Sprintf("%s bound on both sides of cross-product", e.id.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return envSet{newEnvExprProduct(env1.root, env2.root)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unionEnvs(envs ...envSet) envSet {
|
||||||
|
exprs := make([]*envExpr, len(envs))
|
||||||
|
for i := range envs {
|
||||||
|
exprs[i] = envs[i].root
|
||||||
|
}
|
||||||
|
return envSet{newEnvExprSum(exprs...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// envPartition is a subset of an env where id is bound to value in all
|
||||||
|
// deterministic environments.
|
||||||
|
type envPartition struct {
|
||||||
|
id *ident
|
||||||
|
value *Value
|
||||||
|
env envSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// partitionBy splits e by distinct bindings of id and removes id from each
|
||||||
|
// partition.
|
||||||
|
//
|
||||||
|
// If there are environments in e where id is not bound, they will not be
|
||||||
|
// reflected in any partition.
|
||||||
|
//
|
||||||
|
// It panics if e is bottom, since attempting to partition an empty environment
|
||||||
|
// set almost certainly indicates a bug.
|
||||||
|
func (e envSet) partitionBy(id *ident) []envPartition {
|
||||||
|
if e.isEmpty() {
|
||||||
|
// We could return zero partitions, but getting here at all almost
|
||||||
|
// certainly indicates a bug.
|
||||||
|
panic("cannot partition empty environment set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit a partition for each value of id.
|
||||||
|
var seen smallSet[*Value]
|
||||||
|
var parts []envPartition
|
||||||
|
for n := range e.root.bindings(id) {
|
||||||
|
if !seen.Add(n.val) {
|
||||||
|
// Already emitted a partition for this value.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts = append(parts, envPartition{
|
||||||
|
id: id,
|
||||||
|
value: n.val,
|
||||||
|
env: envSet{e.root.substitute(id, n.val)},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
// substitute replaces bindings of id to val with 1 and bindings of id to any
|
||||||
|
// other value with 0 and simplifies the result.
|
||||||
|
func (e *envExpr) substitute(id *ident, val *Value) *envExpr {
|
||||||
|
switch e.kind {
|
||||||
|
default:
|
||||||
|
panic("bad kind")
|
||||||
|
|
||||||
|
case envZero, envUnit:
|
||||||
|
return e
|
||||||
|
|
||||||
|
case envBinding:
|
||||||
|
if e.id != id {
|
||||||
|
return e
|
||||||
|
} else if e.val != val {
|
||||||
|
return envExprZero
|
||||||
|
} else {
|
||||||
|
return envExprUnit
|
||||||
|
}
|
||||||
|
|
||||||
|
case envProduct, envSum:
|
||||||
|
// Substitute each operand. Sometimes, this won't change anything, so we
|
||||||
|
// build the new operands list lazily.
|
||||||
|
var nOperands []*envExpr
|
||||||
|
for i, op := range e.operands {
|
||||||
|
nOp := op.substitute(id, val)
|
||||||
|
if nOperands == nil && op != nOp {
|
||||||
|
// Operand diverged; initialize nOperands.
|
||||||
|
nOperands = make([]*envExpr, 0, len(e.operands))
|
||||||
|
nOperands = append(nOperands, e.operands[:i]...)
|
||||||
|
}
|
||||||
|
if nOperands != nil {
|
||||||
|
nOperands = append(nOperands, nOp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nOperands == nil {
|
||||||
|
// Nothing changed.
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
if e.kind == envProduct {
|
||||||
|
return newEnvExprProduct(nOperands...)
|
||||||
|
} else {
|
||||||
|
return newEnvExprSum(nOperands...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A smallSet is a set optimized for stack allocation when small.
|
||||||
|
type smallSet[T comparable] struct {
|
||||||
|
array [32]T
|
||||||
|
n int
|
||||||
|
|
||||||
|
m map[T]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has returns whether val is in set.
|
||||||
|
func (s *smallSet[T]) Has(val T) bool {
|
||||||
|
arr := s.array[:s.n]
|
||||||
|
for i := range arr {
|
||||||
|
if arr[i] == val {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, ok := s.m[val]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds val to the set and returns true if it was added (not already
|
||||||
|
// present).
|
||||||
|
func (s *smallSet[T]) Add(val T) bool {
|
||||||
|
// Test for presence.
|
||||||
|
if s.Has(val) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add it
|
||||||
|
if s.n < len(s.array) {
|
||||||
|
s.array[s.n] = val
|
||||||
|
s.n++
|
||||||
|
} else {
|
||||||
|
if s.m == nil {
|
||||||
|
s.m = make(map[T]struct{})
|
||||||
|
}
|
||||||
|
s.m[val] = struct{}{}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type ident struct {
|
||||||
|
_ [0]func() // Not comparable (only compare *ident)
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Var struct {
|
||||||
|
id *ident
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Var) Exact() bool {
|
||||||
|
// These can't appear in concrete Values.
|
||||||
|
panic("Exact called on non-concrete Value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Var) WhyNotExact() string {
|
||||||
|
// These can't appear in concrete Values.
|
||||||
|
return "WhyNotExact called on non-concrete Value"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Var) decode(rv reflect.Value) error {
|
||||||
|
return &inexactError{"var", rv.Type().String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Var) unify(w *Value, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
|
||||||
|
// TODO: Vars from !sums in the input can have a huge number of values.
|
||||||
|
// Unifying these could be way more efficient with some indexes over any
|
||||||
|
// exact values we can pull out, like Def fields that are exact Strings.
|
||||||
|
// Maybe we try to produce an array of yes/no/maybe matches and then we only
|
||||||
|
// have to do deeper evaluation of the maybes. We could probably cache this
|
||||||
|
// on an envTerm. It may also help to special-case Var/Var unification to
|
||||||
|
// pick which one to index versus enumerate.
|
||||||
|
|
||||||
|
if vd, ok := w.Domain.(Var); ok && d.id == vd.id {
|
||||||
|
// Unifying $x with $x results in $x. If we descend into this we'll have
|
||||||
|
// problems because we strip $x out of the environment to keep ourselves
|
||||||
|
// honest and then can't find it on the other side.
|
||||||
|
//
|
||||||
|
// TODO: I'm not positive this is the right fix.
|
||||||
|
return vd, e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to unify w with the value of d in each possible environment. We
|
||||||
|
// can save some work by grouping environments by the value of d, since
|
||||||
|
// there will be a lot of redundancy here.
|
||||||
|
var nEnvs []envSet
|
||||||
|
envParts := e.partitionBy(d.id)
|
||||||
|
for i, envPart := range envParts {
|
||||||
|
exit := uf.enterVar(d.id, i)
|
||||||
|
// Each branch logically gets its own copy of the initial environment
|
||||||
|
// (narrowed down to just this binding of the variable), and each branch
|
||||||
|
// may result in different changes to that starting environment.
|
||||||
|
res, e2, err := w.unify(envPart.value, envPart.env, swap, uf)
|
||||||
|
exit.exit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, envSet{}, err
|
||||||
|
}
|
||||||
|
if res.Domain == nil {
|
||||||
|
// This branch entirely failed to unify, so it's gone.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nEnv := e2.bind(d.id, res)
|
||||||
|
nEnvs = append(nEnvs, nEnv)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nEnvs) == 0 {
|
||||||
|
// All branches failed
|
||||||
|
return nil, bottomEnv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The effect of this is entirely captured in the environment. We can return
|
||||||
|
// back the same Bind node.
|
||||||
|
return d, unionEnvs(nEnvs...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// An identPrinter maps [ident]s to unique string names.
|
||||||
|
type identPrinter struct {
|
||||||
|
ids map[*ident]string
|
||||||
|
idGen map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *identPrinter) unique(id *ident) string {
|
||||||
|
if p.ids == nil {
|
||||||
|
p.ids = make(map[*ident]string)
|
||||||
|
p.idGen = make(map[string]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
name, ok := p.ids[id]
|
||||||
|
if !ok {
|
||||||
|
gen := p.idGen[id.name]
|
||||||
|
p.idGen[id.name]++
|
||||||
|
if gen == 0 {
|
||||||
|
name = id.name
|
||||||
|
} else {
|
||||||
|
name = fmt.Sprintf("%s#%d", id.name, gen)
|
||||||
|
}
|
||||||
|
p.ids[id] = name
|
||||||
|
}
|
||||||
|
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *identPrinter) slice(ids []*ident) string {
|
||||||
|
var strs []string
|
||||||
|
for _, id := range ids {
|
||||||
|
strs = append(strs, p.unique(id))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[%s]", strings.Join(strs, ", "))
|
||||||
|
}
|
||||||
123
src/simd/_gen/unify/html.go
Normal file
123
src/simd/_gen/unify/html.go
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (t *tracer) writeHTML(w io.Writer) {
|
||||||
|
if !t.saveTree {
|
||||||
|
panic("writeHTML called without tracer.saveTree")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "<html><head><style>%s</style></head>", htmlCSS)
|
||||||
|
for _, root := range t.trees {
|
||||||
|
dot := newDotEncoder()
|
||||||
|
html := htmlTracer{w: w, dot: dot}
|
||||||
|
html.writeTree(root)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "</html>\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
const htmlCSS = `
|
||||||
|
.unify {
|
||||||
|
display: grid;
|
||||||
|
grid-auto-columns: min-content;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header {
|
||||||
|
grid-row: 1;
|
||||||
|
font-weight: bold;
|
||||||
|
padding: 0.25em;
|
||||||
|
position: sticky;
|
||||||
|
top: 0;
|
||||||
|
background: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.envFactor {
|
||||||
|
display: grid;
|
||||||
|
grid-auto-rows: min-content;
|
||||||
|
grid-template-columns: subgrid;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
type htmlTracer struct {
|
||||||
|
w io.Writer
|
||||||
|
dot *dotEncoder
|
||||||
|
svgs map[any]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *htmlTracer) writeTree(node *traceTree) {
|
||||||
|
// TODO: This could be really nice.
|
||||||
|
//
|
||||||
|
// - Put nodes that were unified on the same rank with {rank=same; a; b}
|
||||||
|
//
|
||||||
|
// - On hover, highlight nodes that node was unified with and the result. If
|
||||||
|
// it's a variable, highlight it in the environment, too.
|
||||||
|
//
|
||||||
|
// - On click, show the details of unifying that node.
|
||||||
|
//
|
||||||
|
// This could be the only way to navigate, without necessarily needing the
|
||||||
|
// whole nest of <detail> nodes.
|
||||||
|
|
||||||
|
// TODO: It might be possible to write this out on the fly.
|
||||||
|
|
||||||
|
t.emit([]*Value{node.v, node.w}, []string{"v", "w"}, node.envIn)
|
||||||
|
|
||||||
|
// Render children.
|
||||||
|
for i, child := range node.children {
|
||||||
|
if i >= 10 {
|
||||||
|
fmt.Fprintf(t.w, `<div style="margin-left: 4em">...</div>`)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Fprintf(t.w, `<details style="margin-left: 4em"><summary>%s</summary>`, html.EscapeString(child.label))
|
||||||
|
t.writeTree(child)
|
||||||
|
fmt.Fprintf(t.w, "</details>\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render result.
|
||||||
|
if node.err != nil {
|
||||||
|
fmt.Fprintf(t.w, "Error: %s\n", html.EscapeString(node.err.Error()))
|
||||||
|
} else {
|
||||||
|
t.emit([]*Value{node.res}, []string{"res"}, node.env)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func htmlSVG[Key comparable](t *htmlTracer, f func(Key), arg Key) string {
|
||||||
|
if s, ok := t.svgs[arg]; ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
var buf strings.Builder
|
||||||
|
f(arg)
|
||||||
|
t.dot.writeSvg(&buf)
|
||||||
|
t.dot.clear()
|
||||||
|
svg := buf.String()
|
||||||
|
if t.svgs == nil {
|
||||||
|
t.svgs = make(map[any]string)
|
||||||
|
}
|
||||||
|
t.svgs[arg] = svg
|
||||||
|
buf.Reset()
|
||||||
|
return svg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *htmlTracer) emit(vs []*Value, labels []string, env envSet) {
|
||||||
|
fmt.Fprintf(t.w, `<div class="unify">`)
|
||||||
|
for i, v := range vs {
|
||||||
|
fmt.Fprintf(t.w, `<div class="header" style="grid-column: %d">%s</div>`, i+1, html.EscapeString(labels[i]))
|
||||||
|
fmt.Fprintf(t.w, `<div style="grid-area: 2 / %d">%s</div>`, i+1, htmlSVG(t, t.dot.valueSubgraph, v))
|
||||||
|
}
|
||||||
|
col := len(vs)
|
||||||
|
|
||||||
|
fmt.Fprintf(t.w, `<div class="header" style="grid-column: %d">in</div>`, col+1)
|
||||||
|
fmt.Fprintf(t.w, `<div style="grid-area: 2 / %d">%s</div>`, col+1, htmlSVG(t, t.dot.envSubgraph, env))
|
||||||
|
|
||||||
|
fmt.Fprintf(t.w, `</div>`)
|
||||||
|
}
|
||||||
33
src/simd/_gen/unify/pos.go
Normal file
33
src/simd/_gen/unify/pos.go
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Pos struct {
|
||||||
|
Path string
|
||||||
|
Line int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Pos) String() string {
|
||||||
|
var b []byte
|
||||||
|
b, _ = p.AppendText(b)
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Pos) AppendText(b []byte) ([]byte, error) {
|
||||||
|
if p.Line == 0 {
|
||||||
|
if p.Path == "" {
|
||||||
|
return append(b, "?:?"...), nil
|
||||||
|
} else {
|
||||||
|
return append(b, p.Path...), nil
|
||||||
|
}
|
||||||
|
} else if p.Path == "" {
|
||||||
|
return fmt.Appendf(b, "?:%d", p.Line), nil
|
||||||
|
}
|
||||||
|
return fmt.Appendf(b, "%s:%d", p.Path, p.Line), nil
|
||||||
|
}
|
||||||
33
src/simd/_gen/unify/testdata/stress.yaml
vendored
Normal file
33
src/simd/_gen/unify/testdata/stress.yaml
vendored
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
# In the original representation of environments, this caused an exponential
|
||||||
|
# blowup in time and allocation. With that representation, this took about 20
|
||||||
|
# seconds on my laptop and had a max RSS of ~12 GB. Big enough to be really
|
||||||
|
# noticeable, but not so big it's likely to crash a developer machine. With the
|
||||||
|
# better environment representation, it runs almost instantly and has an RSS of
|
||||||
|
# ~90 MB.
|
||||||
|
unify:
|
||||||
|
- !sum
|
||||||
|
- !sum [1, 2]
|
||||||
|
- !sum [3, 4]
|
||||||
|
- !sum [5, 6]
|
||||||
|
- !sum [7, 8]
|
||||||
|
- !sum [9, 10]
|
||||||
|
- !sum [11, 12]
|
||||||
|
- !sum [13, 14]
|
||||||
|
- !sum [15, 16]
|
||||||
|
- !sum [17, 18]
|
||||||
|
- !sum [19, 20]
|
||||||
|
- !sum [21, 22]
|
||||||
|
- !sum
|
||||||
|
- !sum [1, 2]
|
||||||
|
- !sum [3, 4]
|
||||||
|
- !sum [5, 6]
|
||||||
|
- !sum [7, 8]
|
||||||
|
- !sum [9, 10]
|
||||||
|
- !sum [11, 12]
|
||||||
|
- !sum [13, 14]
|
||||||
|
- !sum [15, 16]
|
||||||
|
- !sum [17, 18]
|
||||||
|
- !sum [19, 20]
|
||||||
|
- !sum [21, 22]
|
||||||
|
all:
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
|
||||||
174
src/simd/_gen/unify/testdata/unify.yaml
vendored
Normal file
174
src/simd/_gen/unify/testdata/unify.yaml
vendored
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
# Basic tests of unification
|
||||||
|
|
||||||
|
#
|
||||||
|
# Terminals
|
||||||
|
#
|
||||||
|
|
||||||
|
unify:
|
||||||
|
- _
|
||||||
|
- _
|
||||||
|
want:
|
||||||
|
_
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- _
|
||||||
|
- test
|
||||||
|
want:
|
||||||
|
test
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- test
|
||||||
|
- t?est
|
||||||
|
want:
|
||||||
|
test
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
want:
|
||||||
|
1
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- test
|
||||||
|
- foo
|
||||||
|
want:
|
||||||
|
_|_
|
||||||
|
|
||||||
|
#
|
||||||
|
# Tuple
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- [a, b]
|
||||||
|
- [a, b]
|
||||||
|
want:
|
||||||
|
[a, b]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- [a, _]
|
||||||
|
- [_, b]
|
||||||
|
want:
|
||||||
|
[a, b]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- ["ab?c", "de?f"]
|
||||||
|
- [ac, def]
|
||||||
|
want:
|
||||||
|
[ac, def]
|
||||||
|
|
||||||
|
#
|
||||||
|
# Repeats
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [a]
|
||||||
|
- [_]
|
||||||
|
want:
|
||||||
|
[a]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [a]
|
||||||
|
- [_, _]
|
||||||
|
want:
|
||||||
|
[a, a]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [a]
|
||||||
|
- [b]
|
||||||
|
want:
|
||||||
|
_|_
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [xy*]
|
||||||
|
- [x, xy, xyy]
|
||||||
|
want:
|
||||||
|
[x, xy, xyy]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [xy*]
|
||||||
|
- !repeat ["xz?y*"]
|
||||||
|
- [x, xy, xyy]
|
||||||
|
want:
|
||||||
|
[x, xy, xyy]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [!sum [a, b]]
|
||||||
|
- [a, b, a]
|
||||||
|
all:
|
||||||
|
- [a, b, a]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [!sum [a, b]]
|
||||||
|
- !repeat [!sum [b, c]]
|
||||||
|
- [b, b, b]
|
||||||
|
all:
|
||||||
|
- [b, b, b]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !repeat [!sum [a, b]]
|
||||||
|
- !repeat [!sum [b, c]]
|
||||||
|
- [a]
|
||||||
|
all: []
|
||||||
|
|
||||||
|
#
|
||||||
|
# Def
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- {a: a, b: b}
|
||||||
|
- {a: a, b: b}
|
||||||
|
want:
|
||||||
|
{a: a, b: b}
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- {a: a}
|
||||||
|
- {b: b}
|
||||||
|
want:
|
||||||
|
{a: a, b: b}
|
||||||
|
|
||||||
|
#
|
||||||
|
# Sum
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !sum [1, 2]
|
||||||
|
- !sum [2, 3]
|
||||||
|
all:
|
||||||
|
- 2
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !sum [{label: a, value: abc}, {label: b, value: def}]
|
||||||
|
- !sum [{value: "ab?c", extra: d}, {value: "def?", extra: g}]
|
||||||
|
all:
|
||||||
|
- {extra: d, label: a, value: abc}
|
||||||
|
- {extra: g, label: b, value: def}
|
||||||
|
---
|
||||||
|
# A sum of repeats must deal with different dynamically-created variables in
|
||||||
|
# each branch.
|
||||||
|
unify:
|
||||||
|
- !sum [!repeat [a], !repeat [b]]
|
||||||
|
- [a, a, a]
|
||||||
|
all:
|
||||||
|
- [a, a, a]
|
||||||
|
---
|
||||||
|
unify:
|
||||||
|
- !sum [!repeat [a], !repeat [b]]
|
||||||
|
- [a, a, b]
|
||||||
|
all: []
|
||||||
|
---
|
||||||
|
# Exercise sumEnvs with more than one result
|
||||||
|
unify:
|
||||||
|
- !sum
|
||||||
|
- [a|b, c|d]
|
||||||
|
- [e, g]
|
||||||
|
- [!sum [a, b, e, f], !sum [c, d, g, h]]
|
||||||
|
all:
|
||||||
|
- [a, c]
|
||||||
|
- [a, d]
|
||||||
|
- [b, c]
|
||||||
|
- [b, d]
|
||||||
|
- [e, g]
|
||||||
175
src/simd/_gen/unify/testdata/vars.yaml
vendored
Normal file
175
src/simd/_gen/unify/testdata/vars.yaml
vendored
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
#
|
||||||
|
# Basic tests
|
||||||
|
#
|
||||||
|
|
||||||
|
name: "basic string"
|
||||||
|
unify:
|
||||||
|
- $x
|
||||||
|
- test
|
||||||
|
all:
|
||||||
|
- test
|
||||||
|
---
|
||||||
|
name: "basic tuple"
|
||||||
|
unify:
|
||||||
|
- [$x, $x]
|
||||||
|
- [test, test]
|
||||||
|
all:
|
||||||
|
- [test, test]
|
||||||
|
---
|
||||||
|
name: "three tuples"
|
||||||
|
unify:
|
||||||
|
- [$x, $x]
|
||||||
|
- [test, _]
|
||||||
|
- [_, test]
|
||||||
|
all:
|
||||||
|
- [test, test]
|
||||||
|
---
|
||||||
|
name: "basic def"
|
||||||
|
unify:
|
||||||
|
- {a: $x, b: $x}
|
||||||
|
- {a: test, b: test}
|
||||||
|
all:
|
||||||
|
- {a: test, b: test}
|
||||||
|
---
|
||||||
|
name: "three defs"
|
||||||
|
unify:
|
||||||
|
- {a: $x, b: $x}
|
||||||
|
- {a: test}
|
||||||
|
- {b: test}
|
||||||
|
all:
|
||||||
|
- {a: test, b: test}
|
||||||
|
|
||||||
|
#
|
||||||
|
# Bottom tests
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
name: "basic bottom"
|
||||||
|
unify:
|
||||||
|
- [$x, $x]
|
||||||
|
- [test, foo]
|
||||||
|
all: []
|
||||||
|
---
|
||||||
|
name: "three-way bottom"
|
||||||
|
unify:
|
||||||
|
- [$x, $x]
|
||||||
|
- [test, _]
|
||||||
|
- [_, foo]
|
||||||
|
all: []
|
||||||
|
|
||||||
|
#
|
||||||
|
# Basic sum tests
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
name: "basic sum"
|
||||||
|
unify:
|
||||||
|
- $x
|
||||||
|
- !sum [a, b]
|
||||||
|
all:
|
||||||
|
- a
|
||||||
|
- b
|
||||||
|
---
|
||||||
|
name: "sum of tuples"
|
||||||
|
unify:
|
||||||
|
- [$x]
|
||||||
|
- !sum [[a], [b]]
|
||||||
|
all:
|
||||||
|
- [a]
|
||||||
|
- [b]
|
||||||
|
---
|
||||||
|
name: "acausal sum"
|
||||||
|
unify:
|
||||||
|
- [_, !sum [a, b]]
|
||||||
|
- [$x, $x]
|
||||||
|
all:
|
||||||
|
- [a, a]
|
||||||
|
- [b, b]
|
||||||
|
|
||||||
|
#
|
||||||
|
# Transitivity tests
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
name: "transitivity"
|
||||||
|
unify:
|
||||||
|
- [_, _, _, test]
|
||||||
|
- [$x, $x, _, _]
|
||||||
|
- [ _, $x, $x, _]
|
||||||
|
- [ _, _, $x, $x]
|
||||||
|
all:
|
||||||
|
- [test, test, test, test]
|
||||||
|
|
||||||
|
#
|
||||||
|
# Multiple vars
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
name: "basic uncorrelated vars"
|
||||||
|
unify:
|
||||||
|
- - !sum [1, 2]
|
||||||
|
- !sum [3, 4]
|
||||||
|
- - $a
|
||||||
|
- $b
|
||||||
|
all:
|
||||||
|
- [1, 3]
|
||||||
|
- [1, 4]
|
||||||
|
- [2, 3]
|
||||||
|
- [2, 4]
|
||||||
|
---
|
||||||
|
name: "uncorrelated vars"
|
||||||
|
unify:
|
||||||
|
- - !sum [1, 2]
|
||||||
|
- !sum [3, 4]
|
||||||
|
- !sum [1, 2]
|
||||||
|
- - $a
|
||||||
|
- $b
|
||||||
|
- $a
|
||||||
|
all:
|
||||||
|
- [1, 3, 1]
|
||||||
|
- [1, 4, 1]
|
||||||
|
- [2, 3, 2]
|
||||||
|
- [2, 4, 2]
|
||||||
|
---
|
||||||
|
name: "entangled vars"
|
||||||
|
unify:
|
||||||
|
- - !sum [[1,2],[3,4]]
|
||||||
|
- !sum [[2,1],[3,4],[4,3]]
|
||||||
|
- - [$a, $b]
|
||||||
|
- [$b, $a]
|
||||||
|
all:
|
||||||
|
- - [1, 2]
|
||||||
|
- [2, 1]
|
||||||
|
- - [3, 4]
|
||||||
|
- [4, 3]
|
||||||
|
|
||||||
|
#
|
||||||
|
# End-to-end examples
|
||||||
|
#
|
||||||
|
|
||||||
|
---
|
||||||
|
name: "end-to-end"
|
||||||
|
unify:
|
||||||
|
- go: Add
|
||||||
|
in:
|
||||||
|
- go: $t
|
||||||
|
- go: $t
|
||||||
|
- in: !repeat
|
||||||
|
- !sum
|
||||||
|
- go: Int32x4
|
||||||
|
base: int
|
||||||
|
- go: Uint32x4
|
||||||
|
base: uint
|
||||||
|
all:
|
||||||
|
- go: Add
|
||||||
|
in:
|
||||||
|
- base: int
|
||||||
|
go: Int32x4
|
||||||
|
- base: int
|
||||||
|
go: Int32x4
|
||||||
|
- go: Add
|
||||||
|
in:
|
||||||
|
- base: uint
|
||||||
|
go: Uint32x4
|
||||||
|
- base: uint
|
||||||
|
go: Uint32x4
|
||||||
168
src/simd/_gen/unify/trace.go
Normal file
168
src/simd/_gen/unify/trace.go
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// debugDotInHTML, if true, includes dot code for all graphs in the HTML. Useful
|
||||||
|
// for debugging the dot output itself.
|
||||||
|
const debugDotInHTML = false
|
||||||
|
|
||||||
|
var Debug struct {
|
||||||
|
// UnifyLog, if non-nil, receives a streaming text trace of unification.
|
||||||
|
UnifyLog io.Writer
|
||||||
|
|
||||||
|
// HTML, if non-nil, writes an HTML trace of unification to HTML.
|
||||||
|
HTML io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
type tracer struct {
|
||||||
|
logw io.Writer
|
||||||
|
|
||||||
|
enc yamlEncoder // Print consistent idents throughout
|
||||||
|
|
||||||
|
saveTree bool // if set, record tree; required for HTML output
|
||||||
|
|
||||||
|
path []string
|
||||||
|
|
||||||
|
node *traceTree
|
||||||
|
trees []*traceTree
|
||||||
|
}
|
||||||
|
|
||||||
|
type traceTree struct {
|
||||||
|
label string // Identifies this node as a child of parent
|
||||||
|
v, w *Value // Unification inputs
|
||||||
|
envIn envSet
|
||||||
|
res *Value // Unification result
|
||||||
|
env envSet
|
||||||
|
err error // or error
|
||||||
|
|
||||||
|
parent *traceTree
|
||||||
|
children []*traceTree
|
||||||
|
}
|
||||||
|
|
||||||
|
type tracerExit struct {
|
||||||
|
t *tracer
|
||||||
|
len int
|
||||||
|
node *traceTree
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracer) enter(pat string, vals ...any) tracerExit {
|
||||||
|
if t == nil {
|
||||||
|
return tracerExit{}
|
||||||
|
}
|
||||||
|
|
||||||
|
label := fmt.Sprintf(pat, vals...)
|
||||||
|
|
||||||
|
var p *traceTree
|
||||||
|
if t.saveTree {
|
||||||
|
p = t.node
|
||||||
|
if p != nil {
|
||||||
|
t.node = &traceTree{label: label, parent: p}
|
||||||
|
p.children = append(p.children, t.node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.path = append(t.path, label)
|
||||||
|
return tracerExit{t, len(t.path) - 1, p}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracer) enterVar(id *ident, branch int) tracerExit {
|
||||||
|
if t == nil {
|
||||||
|
return tracerExit{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the tracer's ident printer
|
||||||
|
return t.enter("Var %s br %d", t.enc.idp.unique(id), branch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (te tracerExit) exit() {
|
||||||
|
if te.t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
te.t.path = te.t.path[:te.len]
|
||||||
|
te.t.node = te.node
|
||||||
|
}
|
||||||
|
|
||||||
|
func indentf(prefix string, pat string, vals ...any) string {
|
||||||
|
s := fmt.Sprintf(pat, vals...)
|
||||||
|
if len(prefix) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
if !strings.Contains(s, "\n") {
|
||||||
|
return prefix + s
|
||||||
|
}
|
||||||
|
|
||||||
|
indent := prefix
|
||||||
|
if strings.TrimLeft(prefix, " ") != "" {
|
||||||
|
// Prefix has non-space characters in it. Construct an all space-indent.
|
||||||
|
indent = strings.Repeat(" ", len(prefix))
|
||||||
|
}
|
||||||
|
return prefix + strings.ReplaceAll(s, "\n", "\n"+indent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func yamlf(prefix string, node *yaml.Node) string {
|
||||||
|
b, err := yaml.Marshal(node)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("<marshal failed: %s>", err)
|
||||||
|
}
|
||||||
|
return strings.TrimRight(indentf(prefix, "%s", b), " \n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracer) logf(pat string, vals ...any) {
|
||||||
|
if t == nil || t.logw == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
prefix := fmt.Sprintf("[%s] ", strings.Join(t.path, "/"))
|
||||||
|
s := indentf(prefix, pat, vals...)
|
||||||
|
s = strings.TrimRight(s, " \n")
|
||||||
|
fmt.Fprintf(t.logw, "%s\n", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracer) traceUnify(v, w *Value, e envSet) {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.logf("Unify\n%s\nwith\n%s\nin\n%s",
|
||||||
|
yamlf(" ", t.enc.value(v)),
|
||||||
|
yamlf(" ", t.enc.value(w)),
|
||||||
|
yamlf(" ", t.enc.env(e)))
|
||||||
|
|
||||||
|
if t.saveTree {
|
||||||
|
if t.node == nil {
|
||||||
|
t.node = &traceTree{}
|
||||||
|
t.trees = append(t.trees, t.node)
|
||||||
|
}
|
||||||
|
t.node.v, t.node.w, t.node.envIn = v, w, e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracer) traceDone(res *Value, e envSet, err error) {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.logf("==> %s", err)
|
||||||
|
} else {
|
||||||
|
t.logf("==>\n%s", yamlf(" ", t.enc.closure(Closure{res, e})))
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.saveTree {
|
||||||
|
node := t.node
|
||||||
|
if node == nil {
|
||||||
|
panic("popped top of trace stack")
|
||||||
|
}
|
||||||
|
node.res, node.err = res, err
|
||||||
|
node.env = e
|
||||||
|
}
|
||||||
|
}
|
||||||
322
src/simd/_gen/unify/unify.go
Normal file
322
src/simd/_gen/unify/unify.go
Normal file
|
|
@ -0,0 +1,322 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package unify implements unification of structured values.
|
||||||
|
//
|
||||||
|
// A [Value] represents a possibly infinite set of concrete values, where a
|
||||||
|
// value is either a string ([String]), a tuple of values ([Tuple]), or a
|
||||||
|
// string-keyed map of values called a "def" ([Def]). These sets can be further
|
||||||
|
// constrained by variables ([Var]). A [Value] combined with bindings of
|
||||||
|
// variables is a [Closure].
|
||||||
|
//
|
||||||
|
// [Unify] finds a [Closure] that satisfies two or more other [Closure]s. This
|
||||||
|
// can be thought of as intersecting the sets represented by these Closures'
|
||||||
|
// values, or as the greatest lower bound/infimum of these Closures. If no such
|
||||||
|
// Closure exists, the result of unification is "bottom", or the empty set.
|
||||||
|
//
|
||||||
|
// # Examples
|
||||||
|
//
|
||||||
|
// The regular expression "a*" is the infinite set of strings of zero or more
|
||||||
|
// "a"s. "a*" can be unified with "a" or "aa" or "aaa", and the result is just
|
||||||
|
// "a", "aa", or "aaa", respectively. However, unifying "a*" with "b" fails
|
||||||
|
// because there are no values that satisfy both.
|
||||||
|
//
|
||||||
|
// Sums express sets directly. For example, !sum [a, b] is the set consisting of
|
||||||
|
// "a" and "b". Unifying this with !sum [b, c] results in just "b". This also
|
||||||
|
// makes it easy to demonstrate that unification isn't necessarily a single
|
||||||
|
// concrete value. For example, unifying !sum [a, b, c] with !sum [b, c, d]
|
||||||
|
// results in two concrete values: "b" and "c".
|
||||||
|
//
|
||||||
|
// The special value _ or "top" represents all possible values. Unifying _ with
|
||||||
|
// any value x results in x.
|
||||||
|
//
|
||||||
|
// Unifying composite values—tuples and defs—unifies their elements.
|
||||||
|
//
|
||||||
|
// The value [a*, aa] is an infinite set of tuples. If we unify that with the
|
||||||
|
// value [aaa, a*], the only possible value that satisfies both is [aaa, aa].
|
||||||
|
// Likewise, this is the intersection of the sets described by these two values.
|
||||||
|
//
|
||||||
|
// Defs are similar to tuples, but they are indexed by strings and don't have a
|
||||||
|
// fixed length. For example, {x: a, y: b} is a def with two fields. Any field
|
||||||
|
// not mentioned in a def is implicitly top. Thus, unifying this with {y: b, z:
|
||||||
|
// c} results in {x: a, y: b, z: c}.
|
||||||
|
//
|
||||||
|
// Variables constrain values. For example, the value [$x, $x] represents all
|
||||||
|
// tuples whose first and second values are the same, but doesn't otherwise
|
||||||
|
// constrain that value. Thus, this set includes [a, a] as well as [[b, c, d],
|
||||||
|
// [b, c, d]], but it doesn't include [a, b].
|
||||||
|
//
|
||||||
|
// Sums are internally implemented as fresh variables that are simultaneously
|
||||||
|
// bound to all values of the sum. That is !sum [a, b] is actually $var (where
|
||||||
|
// var is some fresh name), closed under the environment $var=a | $var=b.
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Unify computes a Closure that satisfies each input Closure. If no such
|
||||||
|
// Closure exists, it returns bottom.
|
||||||
|
func Unify(closures ...Closure) (Closure, error) {
|
||||||
|
if len(closures) == 0 {
|
||||||
|
return Closure{topValue, topEnv}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var trace *tracer
|
||||||
|
if Debug.UnifyLog != nil || Debug.HTML != nil {
|
||||||
|
trace = &tracer{
|
||||||
|
logw: Debug.UnifyLog,
|
||||||
|
saveTree: Debug.HTML != nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unified := closures[0]
|
||||||
|
for _, c := range closures[1:] {
|
||||||
|
var err error
|
||||||
|
uf := newUnifier()
|
||||||
|
uf.tracer = trace
|
||||||
|
e := crossEnvs(unified.env, c.env)
|
||||||
|
unified.val, unified.env, err = unified.val.unify(c.val, e, false, uf)
|
||||||
|
if Debug.HTML != nil {
|
||||||
|
uf.writeHTML(Debug.HTML)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return Closure{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return unified, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type unifier struct {
|
||||||
|
*tracer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUnifier() *unifier {
|
||||||
|
return &unifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errDomains is a sentinel error used between unify and unify1 to indicate that
|
||||||
|
// unify1 could not unify the domains of the two values.
|
||||||
|
var errDomains = errors.New("cannot unify domains")
|
||||||
|
|
||||||
|
func (v *Value) unify(w *Value, e envSet, swap bool, uf *unifier) (*Value, envSet, error) {
|
||||||
|
if swap {
|
||||||
|
// Put the values in order. This just happens to be a handy choke-point
|
||||||
|
// to do this at.
|
||||||
|
v, w = w, v
|
||||||
|
}
|
||||||
|
|
||||||
|
uf.traceUnify(v, w, e)
|
||||||
|
|
||||||
|
d, e2, err := v.unify1(w, e, false, uf)
|
||||||
|
if err == errDomains {
|
||||||
|
// Try the other order.
|
||||||
|
d, e2, err = w.unify1(v, e, true, uf)
|
||||||
|
if err == errDomains {
|
||||||
|
// Okay, we really can't unify these.
|
||||||
|
err = fmt.Errorf("cannot unify %T (%s) and %T (%s): kind mismatch", v.Domain, v.PosString(), w.Domain, w.PosString())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
uf.traceDone(nil, envSet{}, err)
|
||||||
|
return nil, envSet{}, err
|
||||||
|
}
|
||||||
|
res := unified(d, v, w)
|
||||||
|
uf.traceDone(res, e2, nil)
|
||||||
|
if d == nil {
|
||||||
|
// Double check that a bottom Value also has a bottom env.
|
||||||
|
if !e2.isEmpty() {
|
||||||
|
panic("bottom Value has non-bottom environment")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, e2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) unify1(w *Value, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
|
||||||
|
// TODO: If there's an error, attach position information to it.
|
||||||
|
|
||||||
|
vd, wd := v.Domain, w.Domain
|
||||||
|
|
||||||
|
// Bottom returns bottom, and eliminates all possible environments.
|
||||||
|
if vd == nil || wd == nil {
|
||||||
|
return nil, bottomEnv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top always returns the other.
|
||||||
|
if _, ok := vd.(Top); ok {
|
||||||
|
return wd, e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variables
|
||||||
|
if vd, ok := vd.(Var); ok {
|
||||||
|
return vd.unify(w, e, swap, uf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Composite values
|
||||||
|
if vd, ok := vd.(Def); ok {
|
||||||
|
if wd, ok := wd.(Def); ok {
|
||||||
|
return vd.unify(wd, e, swap, uf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if vd, ok := vd.(Tuple); ok {
|
||||||
|
if wd, ok := wd.(Tuple); ok {
|
||||||
|
return vd.unify(wd, e, swap, uf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scalar values
|
||||||
|
if vd, ok := vd.(String); ok {
|
||||||
|
if wd, ok := wd.(String); ok {
|
||||||
|
res := vd.unify(wd)
|
||||||
|
if res == nil {
|
||||||
|
e = bottomEnv
|
||||||
|
}
|
||||||
|
return res, e, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, envSet{}, errDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Def) unify(o Def, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
|
||||||
|
out := Def{fields: make(map[string]*Value)}
|
||||||
|
|
||||||
|
// Check keys of d against o.
|
||||||
|
for key, dv := range d.All() {
|
||||||
|
ov, ok := o.fields[key]
|
||||||
|
if !ok {
|
||||||
|
// ov is implicitly Top. Bypass unification.
|
||||||
|
out.fields[key] = dv
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
exit := uf.enter("%s", key)
|
||||||
|
res, e2, err := dv.unify(ov, e, swap, uf)
|
||||||
|
exit.exit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, envSet{}, err
|
||||||
|
} else if res.Domain == nil {
|
||||||
|
// No match.
|
||||||
|
return nil, bottomEnv, nil
|
||||||
|
}
|
||||||
|
out.fields[key] = res
|
||||||
|
e = e2
|
||||||
|
}
|
||||||
|
// Check keys of o that we didn't already check. These all implicitly match
|
||||||
|
// because we know the corresponding fields in d are all Top.
|
||||||
|
for key, dv := range o.All() {
|
||||||
|
if _, ok := d.fields[key]; !ok {
|
||||||
|
out.fields[key] = dv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v Tuple) unify(w Tuple, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
|
||||||
|
if v.repeat != nil && w.repeat != nil {
|
||||||
|
// Since we generate the content of these lazily, there's not much we
|
||||||
|
// can do but just stick them on a list to unify later.
|
||||||
|
return Tuple{repeat: concat(v.repeat, w.repeat)}, e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand any repeated tuples.
|
||||||
|
tuples := make([]Tuple, 0, 2)
|
||||||
|
if v.repeat == nil {
|
||||||
|
tuples = append(tuples, v)
|
||||||
|
} else {
|
||||||
|
v2, e2 := v.doRepeat(e, len(w.vs))
|
||||||
|
tuples = append(tuples, v2...)
|
||||||
|
e = e2
|
||||||
|
}
|
||||||
|
if w.repeat == nil {
|
||||||
|
tuples = append(tuples, w)
|
||||||
|
} else {
|
||||||
|
w2, e2 := w.doRepeat(e, len(v.vs))
|
||||||
|
tuples = append(tuples, w2...)
|
||||||
|
e = e2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now unify all of the tuples (usually this will be just 2 tuples)
|
||||||
|
out := tuples[0]
|
||||||
|
for _, t := range tuples[1:] {
|
||||||
|
if len(out.vs) != len(t.vs) {
|
||||||
|
uf.logf("tuple length mismatch")
|
||||||
|
return nil, bottomEnv, nil
|
||||||
|
}
|
||||||
|
zs := make([]*Value, len(out.vs))
|
||||||
|
for i, v1 := range out.vs {
|
||||||
|
exit := uf.enter("%d", i)
|
||||||
|
z, e2, err := v1.unify(t.vs[i], e, swap, uf)
|
||||||
|
exit.exit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, envSet{}, err
|
||||||
|
} else if z.Domain == nil {
|
||||||
|
return nil, bottomEnv, nil
|
||||||
|
}
|
||||||
|
zs[i] = z
|
||||||
|
e = e2
|
||||||
|
}
|
||||||
|
out = Tuple{vs: zs}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doRepeat creates a fixed-length tuple from a repeated tuple. The caller is
|
||||||
|
// expected to unify the returned tuples.
|
||||||
|
func (v Tuple) doRepeat(e envSet, n int) ([]Tuple, envSet) {
|
||||||
|
res := make([]Tuple, len(v.repeat))
|
||||||
|
for i, gen := range v.repeat {
|
||||||
|
res[i].vs = make([]*Value, n)
|
||||||
|
for j := range n {
|
||||||
|
res[i].vs[j], e = gen(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, e
|
||||||
|
}
|
||||||
|
|
||||||
|
// unify intersects the domains of two [String]s. If it can prove that this
|
||||||
|
// domain is empty, it returns nil (bottom).
|
||||||
|
//
|
||||||
|
// TODO: Consider splitting literals and regexps into two domains.
|
||||||
|
func (v String) unify(w String) Domain {
|
||||||
|
// Unification is symmetric, so put them in order of string kind so we only
|
||||||
|
// have to deal with half the cases.
|
||||||
|
if v.kind > w.kind {
|
||||||
|
v, w = w, v
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v.kind {
|
||||||
|
case stringRegex:
|
||||||
|
switch w.kind {
|
||||||
|
case stringRegex:
|
||||||
|
// Construct a match against all of the regexps
|
||||||
|
return String{kind: stringRegex, re: slices.Concat(v.re, w.re)}
|
||||||
|
case stringExact:
|
||||||
|
for _, re := range v.re {
|
||||||
|
if !re.MatchString(w.exact) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
case stringExact:
|
||||||
|
if v.exact != w.exact {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
panic("bad string kind")
|
||||||
|
}
|
||||||
|
|
||||||
|
func concat[T any](s1, s2 []T) []T {
|
||||||
|
// Reuse s1 or s2 if possible.
|
||||||
|
if len(s1) == 0 {
|
||||||
|
return s2
|
||||||
|
}
|
||||||
|
return append(s1[:len(s1):len(s1)], s2...)
|
||||||
|
}
|
||||||
154
src/simd/_gen/unify/unify_test.go
Normal file
154
src/simd/_gen/unify/unify_test.go
Normal file
|
|
@ -0,0 +1,154 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUnify(t *testing.T) {
|
||||||
|
paths, err := filepath.Glob("testdata/*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(paths) == 0 {
|
||||||
|
t.Fatal("no testdata found")
|
||||||
|
}
|
||||||
|
for _, path := range paths {
|
||||||
|
// Skip paths starting with _ so experimental files can be added.
|
||||||
|
base := filepath.Base(path)
|
||||||
|
if base[0] == '_' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(base, ".yaml") {
|
||||||
|
t.Errorf("non-.yaml file in testdata: %s", base)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
base = strings.TrimSuffix(base, ".yaml")
|
||||||
|
|
||||||
|
t.Run(base, func(t *testing.T) {
|
||||||
|
testUnify(t, path)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnify(t *testing.T, path string) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
Skip bool
|
||||||
|
Name string
|
||||||
|
Unify []Closure
|
||||||
|
Want yaml.Node
|
||||||
|
All yaml.Node
|
||||||
|
}
|
||||||
|
dec := yaml.NewDecoder(f)
|
||||||
|
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
var tc testCase
|
||||||
|
err := dec.Decode(&tc)
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
name := tc.Name
|
||||||
|
if name == "" {
|
||||||
|
name = fmt.Sprint(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
if tc.Skip {
|
||||||
|
t.Skip("skip: true set in test case")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
p := recover()
|
||||||
|
if p != nil || t.Failed() {
|
||||||
|
// Redo with a trace
|
||||||
|
//
|
||||||
|
// TODO: Use t.Output() in Go 1.25.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
Debug.UnifyLog = &buf
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
// If the original unify panicked, the second one
|
||||||
|
// probably will, too. Ignore it and let the first panic
|
||||||
|
// bubble.
|
||||||
|
recover()
|
||||||
|
}()
|
||||||
|
Unify(tc.Unify...)
|
||||||
|
}()
|
||||||
|
Debug.UnifyLog = nil
|
||||||
|
t.Logf("Trace:\n%s", buf.String())
|
||||||
|
}
|
||||||
|
if p != nil {
|
||||||
|
panic(p)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Unify the test cases
|
||||||
|
//
|
||||||
|
// TODO: Try reordering the inputs also
|
||||||
|
c, err := Unify(tc.Unify...)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: Tests of errors
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the result back to YAML so we can check if it's structurally
|
||||||
|
// equal.
|
||||||
|
clean := func(val any) *yaml.Node {
|
||||||
|
var node yaml.Node
|
||||||
|
node.Encode(val)
|
||||||
|
for n := range allYamlNodes(&node) {
|
||||||
|
// Canonicalize the style. There may be other style flags we need to
|
||||||
|
// muck with.
|
||||||
|
n.Style &^= yaml.FlowStyle
|
||||||
|
n.HeadComment = ""
|
||||||
|
n.LineComment = ""
|
||||||
|
n.FootComment = ""
|
||||||
|
}
|
||||||
|
return &node
|
||||||
|
}
|
||||||
|
check := func(gotVal any, wantNode *yaml.Node) {
|
||||||
|
got, err := yaml.Marshal(clean(gotVal))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encoding Value back to yaml failed: %s", err)
|
||||||
|
}
|
||||||
|
want, err := yaml.Marshal(clean(wantNode))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encoding Want back to yaml failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("%s:%d:\nwant:\n%sgot\n%s", f.Name(), wantNode.Line, want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tc.Want.Kind != 0 {
|
||||||
|
check(c.val, &tc.Want)
|
||||||
|
}
|
||||||
|
if tc.All.Kind != 0 {
|
||||||
|
fVal := slices.Collect(c.All())
|
||||||
|
check(fVal, &tc.All)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
167
src/simd/_gen/unify/value.go
Normal file
167
src/simd/_gen/unify/value.go
Normal file
|
|
@ -0,0 +1,167 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Value represents a structured, non-deterministic value consisting of
|
||||||
|
// strings, tuples of Values, and string-keyed maps of Values. A
|
||||||
|
// non-deterministic Value will also contain variables, which are resolved via
|
||||||
|
// an environment as part of a [Closure].
|
||||||
|
//
|
||||||
|
// For debugging, a Value can also track the source position it was read from in
|
||||||
|
// an input file, and its provenance from other Values.
|
||||||
|
type Value struct {
|
||||||
|
Domain Domain
|
||||||
|
|
||||||
|
// A Value has either a pos or parents (or neither).
|
||||||
|
pos *Pos
|
||||||
|
parents *[2]*Value
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
topValue = &Value{Domain: Top{}}
|
||||||
|
bottomValue = &Value{Domain: nil}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewValue returns a new [Value] with the given domain and no position
|
||||||
|
// information.
|
||||||
|
func NewValue(d Domain) *Value {
|
||||||
|
return &Value{Domain: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewValuePos returns a new [Value] with the given domain at position p.
|
||||||
|
func NewValuePos(d Domain, p Pos) *Value {
|
||||||
|
return &Value{Domain: d, pos: &p}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newValueFrom returns a new [Value] with the given domain that copies the
|
||||||
|
// position information of p.
|
||||||
|
func newValueFrom(d Domain, p *Value) *Value {
|
||||||
|
return &Value{Domain: d, pos: p.pos, parents: p.parents}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unified(d Domain, p1, p2 *Value) *Value {
|
||||||
|
return &Value{Domain: d, parents: &[2]*Value{p1, p2}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) Pos() Pos {
|
||||||
|
if v.pos == nil {
|
||||||
|
return Pos{}
|
||||||
|
}
|
||||||
|
return *v.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) PosString() string {
|
||||||
|
var b []byte
|
||||||
|
for root := range v.Provenance() {
|
||||||
|
if len(b) > 0 {
|
||||||
|
b = append(b, ' ')
|
||||||
|
}
|
||||||
|
b, _ = root.pos.AppendText(b)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) WhyNotExact() string {
|
||||||
|
if v.Domain == nil {
|
||||||
|
return "v.Domain is nil"
|
||||||
|
}
|
||||||
|
return v.Domain.WhyNotExact()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) Exact() bool {
|
||||||
|
if v.Domain == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return v.Domain.Exact()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode decodes v into a Go value.
|
||||||
|
//
|
||||||
|
// v must be exact, except that it can include Top. into must be a pointer.
|
||||||
|
// [Def]s are decoded into structs. [Tuple]s are decoded into slices. [String]s
|
||||||
|
// are decoded into strings or ints. Any field can itself be a pointer to one of
|
||||||
|
// these types. Top can be decoded into a pointer-typed field and will set the
|
||||||
|
// field to nil. Anything else will allocate a value if necessary.
|
||||||
|
//
|
||||||
|
// Any type may implement [Decoder], in which case its DecodeUnified method will
|
||||||
|
// be called instead of using the default decoding scheme.
|
||||||
|
func (v *Value) Decode(into any) error {
|
||||||
|
rv := reflect.ValueOf(into)
|
||||||
|
if rv.Kind() != reflect.Pointer {
|
||||||
|
return fmt.Errorf("cannot decode into non-pointer %T", into)
|
||||||
|
}
|
||||||
|
return decodeReflect(v, rv.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeReflect(v *Value, rv reflect.Value) error {
|
||||||
|
var ptr reflect.Value
|
||||||
|
if rv.Kind() == reflect.Pointer {
|
||||||
|
if rv.IsNil() {
|
||||||
|
// Transparently allocate through pointers, *except* for Top, which
|
||||||
|
// wants to set the pointer to nil.
|
||||||
|
//
|
||||||
|
// TODO: Drop this condition if I switch to an explicit Optional[T]
|
||||||
|
// or move the Top logic into Def.
|
||||||
|
if _, ok := v.Domain.(Top); !ok {
|
||||||
|
// Allocate the value to fill in, but don't actually store it in
|
||||||
|
// the pointer until we successfully decode.
|
||||||
|
ptr = rv
|
||||||
|
rv = reflect.New(rv.Type().Elem()).Elem()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if reflect.PointerTo(rv.Type()).Implements(decoderType) {
|
||||||
|
// Use the custom decoder.
|
||||||
|
err = rv.Addr().Interface().(Decoder).DecodeUnified(v)
|
||||||
|
} else {
|
||||||
|
err = v.Domain.decode(rv)
|
||||||
|
}
|
||||||
|
if err == nil && ptr.IsValid() {
|
||||||
|
ptr.Set(rv.Addr())
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decoder can be implemented by types as a custom implementation of [Decode]
|
||||||
|
// for that type.
|
||||||
|
type Decoder interface {
|
||||||
|
DecodeUnified(v *Value) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoderType = reflect.TypeOf((*Decoder)(nil)).Elem()
|
||||||
|
|
||||||
|
// Provenance iterates over all of the source Values that have contributed to
|
||||||
|
// this Value.
|
||||||
|
func (v *Value) Provenance() iter.Seq[*Value] {
|
||||||
|
return func(yield func(*Value) bool) {
|
||||||
|
var rec func(d *Value) bool
|
||||||
|
rec = func(d *Value) bool {
|
||||||
|
if d.pos != nil {
|
||||||
|
if !yield(d) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.parents != nil {
|
||||||
|
for _, p := range d.parents {
|
||||||
|
if !rec(p) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
rec(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
50
src/simd/_gen/unify/value_test.go
Normal file
50
src/simd/_gen/unify/value_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleClosure_All_tuple() {
|
||||||
|
v := mustParse(`
|
||||||
|
- !sum [1, 2]
|
||||||
|
- !sum [3, 4]
|
||||||
|
`)
|
||||||
|
printYaml(slices.Collect(v.All()))
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// - [1, 3]
|
||||||
|
// - [1, 4]
|
||||||
|
// - [2, 3]
|
||||||
|
// - [2, 4]
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleClosure_All_def() {
|
||||||
|
v := mustParse(`
|
||||||
|
a: !sum [1, 2]
|
||||||
|
b: !sum [3, 4]
|
||||||
|
c: 5
|
||||||
|
`)
|
||||||
|
printYaml(slices.Collect(v.All()))
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// - {a: 1, b: 3, c: 5}
|
||||||
|
// - {a: 1, b: 4, c: 5}
|
||||||
|
// - {a: 2, b: 3, c: 5}
|
||||||
|
// - {a: 2, b: 4, c: 5}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDecode[T any](t *testing.T, got *Value, want T) {
|
||||||
|
var gotT T
|
||||||
|
if err := got.Decode(&gotT); err != nil {
|
||||||
|
t.Fatalf("Decode failed: %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(&gotT, &want) {
|
||||||
|
t.Fatalf("got:\n%s\nwant:\n%s", prettyYaml(gotT), prettyYaml(want))
|
||||||
|
}
|
||||||
|
}
|
||||||
619
src/simd/_gen/unify/yaml.go
Normal file
619
src/simd/_gen/unify/yaml.go
Normal file
|
|
@ -0,0 +1,619 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadOpts provides options to [Read] and related functions. The zero value is
|
||||||
|
// the default options.
|
||||||
|
type ReadOpts struct {
|
||||||
|
// FS, if non-nil, is the file system from which to resolve !import file
|
||||||
|
// names.
|
||||||
|
FS fs.FS
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads a [Closure] in YAML format from r, using path for error messages.
|
||||||
|
//
|
||||||
|
// It maps YAML nodes into terminal Values as follows:
|
||||||
|
//
|
||||||
|
// - "_" or !top _ is the top value ([Top]).
|
||||||
|
//
|
||||||
|
// - "_|_" or !bottom _ is the bottom value. This is an error during
|
||||||
|
// unmarshaling, but can appear in marshaled values.
|
||||||
|
//
|
||||||
|
// - "$<name>" or !var <name> is a variable ([Var]). Everywhere the same name
|
||||||
|
// appears within a single unmarshal operation, it is mapped to the same
|
||||||
|
// variable. Different unmarshal operations get different variables, even if
|
||||||
|
// they have the same string name.
|
||||||
|
//
|
||||||
|
// - !regex "x" is a regular expression ([String]), as is any string that
|
||||||
|
// doesn't match "_", "_|_", or "$...". Regular expressions are implicitly
|
||||||
|
// anchored at the beginning and end. If the string doesn't contain any
|
||||||
|
// meta-characters (that is, it's a "literal" regular expression), then it's
|
||||||
|
// treated as an exact string.
|
||||||
|
//
|
||||||
|
// - !string "x", or any int, float, bool, or binary value is an exact string
|
||||||
|
// ([String]).
|
||||||
|
//
|
||||||
|
// - !regex [x, y, ...] is an intersection of regular expressions ([String]).
|
||||||
|
//
|
||||||
|
// It maps YAML nodes into non-terminal Values as follows:
|
||||||
|
//
|
||||||
|
// - Sequence nodes like [x, y, z] are tuples ([Tuple]).
|
||||||
|
//
|
||||||
|
// - !repeat [x] is a repeated tuple ([Tuple]), which is 0 or more instances of
|
||||||
|
// x. There must be exactly one element in the list.
|
||||||
|
//
|
||||||
|
// - Mapping nodes like {a: x, b: y} are defs ([Def]). Any fields not listed are
|
||||||
|
// implicitly top.
|
||||||
|
//
|
||||||
|
// - !sum [x, y, z] is a sum of its children. This can be thought of as a union
|
||||||
|
// of the values x, y, and z, or as a non-deterministic choice between x, y, and
|
||||||
|
// z. If a variable appears both inside the sum and outside of it, only the
|
||||||
|
// non-deterministic choice view really works. The unifier does not directly
|
||||||
|
// implement sums; instead, this is decoded as a fresh variable that's
|
||||||
|
// simultaneously bound to x, y, and z.
|
||||||
|
//
|
||||||
|
// - !import glob is like a !sum, but its children are read from all files
|
||||||
|
// matching the given glob pattern, which is interpreted relative to the current
|
||||||
|
// file path. Each file gets its own variable scope.
|
||||||
|
func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) {
|
||||||
|
dec := yamlDecoder{opts: opts, path: path, env: topEnv}
|
||||||
|
v, err := dec.read(r)
|
||||||
|
if err != nil {
|
||||||
|
return Closure{}, err
|
||||||
|
}
|
||||||
|
return dec.close(v), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFile reads a [Closure] in YAML format from a file.
|
||||||
|
//
|
||||||
|
// The file must consist of a single YAML document.
|
||||||
|
//
|
||||||
|
// If opts.FS is not set, this sets it to a FS rooted at path's directory.
|
||||||
|
//
|
||||||
|
// See [Read] for details.
|
||||||
|
func ReadFile(path string, opts ReadOpts) (Closure, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return Closure{}, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if opts.FS == nil {
|
||||||
|
opts.FS = os.DirFS(filepath.Dir(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
return Read(f, path, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalYAML implements [yaml.Unmarshaler].
|
||||||
|
//
|
||||||
|
// Since there is no way to pass [ReadOpts] to this function, it assumes default
|
||||||
|
// options.
|
||||||
|
func (c *Closure) UnmarshalYAML(node *yaml.Node) error {
|
||||||
|
dec := yamlDecoder{path: "<yaml.Node>", env: topEnv}
|
||||||
|
v, err := dec.root(node)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*c = dec.close(v)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlDecoder struct {
|
||||||
|
opts ReadOpts
|
||||||
|
path string
|
||||||
|
|
||||||
|
vars map[string]*ident
|
||||||
|
nSums int
|
||||||
|
|
||||||
|
env envSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *yamlDecoder) read(r io.Reader) (*Value, error) {
|
||||||
|
n, err := readOneNode(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%s: %w", dec.path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode YAML node to a Value
|
||||||
|
v, err := dec.root(n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%s: %w", dec.path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readOneNode reads a single YAML document from r and returns an error if there
|
||||||
|
// are more documents in r.
|
||||||
|
func readOneNode(r io.Reader) (*yaml.Node, error) {
|
||||||
|
yd := yaml.NewDecoder(r)
|
||||||
|
|
||||||
|
// Decode as a YAML node
|
||||||
|
var node yaml.Node
|
||||||
|
if err := yd.Decode(&node); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
np := &node
|
||||||
|
if np.Kind == yaml.DocumentNode {
|
||||||
|
np = node.Content[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure there are no more YAML docs in this file
|
||||||
|
if err := yd.Decode(nil); err == nil {
|
||||||
|
return nil, fmt.Errorf("must not contain multiple documents")
|
||||||
|
} else if err != io.EOF {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return np, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// root parses the root of a file.
|
||||||
|
func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) {
|
||||||
|
// Prepare for variable name resolution in this file. This may be a nested
|
||||||
|
// root, so restore the current values when we're done.
|
||||||
|
oldVars, oldNSums := dec.vars, dec.nSums
|
||||||
|
defer func() {
|
||||||
|
dec.vars, dec.nSums = oldVars, oldNSums
|
||||||
|
}()
|
||||||
|
dec.vars = make(map[string]*ident, 0)
|
||||||
|
dec.nSums = 0
|
||||||
|
|
||||||
|
return dec.value(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
// close wraps a decoded [Value] into a [Closure].
|
||||||
|
func (dec *yamlDecoder) close(v *Value) Closure {
|
||||||
|
return Closure{v, dec.env}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) {
|
||||||
|
pos := &Pos{Path: dec.path, Line: node.Line}
|
||||||
|
|
||||||
|
// Resolve alias nodes.
|
||||||
|
if node.Kind == yaml.AliasNode {
|
||||||
|
node = node.Alias
|
||||||
|
}
|
||||||
|
|
||||||
|
mk := func(d Domain) (*Value, error) {
|
||||||
|
v := &Value{Domain: d, pos: pos}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
mk2 := func(d Domain, err error) (*Value, error) {
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mk(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// is tests the kind and long tag of node.
|
||||||
|
is := func(kind yaml.Kind, tag string) bool {
|
||||||
|
return node.Kind == kind && node.LongTag() == tag
|
||||||
|
}
|
||||||
|
isExact := func() bool {
|
||||||
|
if node.Kind != yaml.ScalarNode {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// We treat any string-ish YAML node as a string.
|
||||||
|
switch node.LongTag() {
|
||||||
|
case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// !!str nodes provide a short-hand syntax for several leaf domains that are
|
||||||
|
// also available under explicit tags. To simplify checking below, we set
|
||||||
|
// strVal to non-"" only for !!str nodes.
|
||||||
|
strVal := ""
|
||||||
|
isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str")
|
||||||
|
if isStr {
|
||||||
|
strVal = node.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case is(yaml.ScalarNode, "!var"):
|
||||||
|
strVal = "$" + node.Value
|
||||||
|
fallthrough
|
||||||
|
case strings.HasPrefix(strVal, "$"):
|
||||||
|
id, ok := dec.vars[strVal]
|
||||||
|
if !ok {
|
||||||
|
// We encode different idents with the same string name by adding a
|
||||||
|
// #N suffix. Strip that off so it doesn't accumulate. This isn't
|
||||||
|
// meant to be used in user-written input, though nothing stops that.
|
||||||
|
name, _, _ := strings.Cut(strVal, "#")
|
||||||
|
id = &ident{name: name}
|
||||||
|
dec.vars[strVal] = id
|
||||||
|
dec.env = dec.env.bind(id, topValue)
|
||||||
|
}
|
||||||
|
return mk(Var{id: id})
|
||||||
|
|
||||||
|
case strVal == "_" || is(yaml.ScalarNode, "!top"):
|
||||||
|
return mk(Top{})
|
||||||
|
|
||||||
|
case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"):
|
||||||
|
return nil, errors.New("found bottom")
|
||||||
|
|
||||||
|
case isExact():
|
||||||
|
val := node.Value
|
||||||
|
return mk(NewStringExact(val))
|
||||||
|
|
||||||
|
case isStr || is(yaml.ScalarNode, "!regex"):
|
||||||
|
// Any other string we treat as a regex. This will produce an exact
|
||||||
|
// string anyway if the regex is literal.
|
||||||
|
val := node.Value
|
||||||
|
return mk2(NewStringRegex(val))
|
||||||
|
|
||||||
|
case is(yaml.SequenceNode, "!regex"):
|
||||||
|
var vals []string
|
||||||
|
if err := node.Decode(&vals); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mk2(NewStringRegex(vals...))
|
||||||
|
|
||||||
|
case is(yaml.MappingNode, "tag:yaml.org,2002:map"):
|
||||||
|
var db DefBuilder
|
||||||
|
for i := 0; i < len(node.Content); i += 2 {
|
||||||
|
key := node.Content[i]
|
||||||
|
if key.Kind != yaml.ScalarNode {
|
||||||
|
return nil, fmt.Errorf("non-scalar key %q", key.Value)
|
||||||
|
}
|
||||||
|
val, err := dec.value(node.Content[i+1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db.Add(key.Value, val)
|
||||||
|
}
|
||||||
|
return mk(db.Build())
|
||||||
|
|
||||||
|
case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"):
|
||||||
|
elts := node.Content
|
||||||
|
vs := make([]*Value, 0, len(elts))
|
||||||
|
for _, elt := range elts {
|
||||||
|
v, err := dec.value(elt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
vs = append(vs, v)
|
||||||
|
}
|
||||||
|
return mk(NewTuple(vs...))
|
||||||
|
|
||||||
|
case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"):
|
||||||
|
// !repeat must have one child. !repeat-unify is used internally for
|
||||||
|
// delayed unification, and is the same, it's just allowed to have more
|
||||||
|
// than one child.
|
||||||
|
if node.LongTag() == "!repeat" && len(node.Content) != 1 {
|
||||||
|
return nil, fmt.Errorf("!repeat must have exactly one child")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the children to make sure they're well-formed, but otherwise
|
||||||
|
// discard that decoding and do it again every time we need a new
|
||||||
|
// element.
|
||||||
|
var gen []func(e envSet) (*Value, envSet)
|
||||||
|
origEnv := dec.env
|
||||||
|
elts := node.Content
|
||||||
|
for i, elt := range elts {
|
||||||
|
_, err := dec.value(elt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Undo any effects on the environment. We *do* keep any named
|
||||||
|
// variables that were added to the vars map in case they were
|
||||||
|
// introduced within the element.
|
||||||
|
dec.env = origEnv
|
||||||
|
// Add a generator function
|
||||||
|
gen = append(gen, func(e envSet) (*Value, envSet) {
|
||||||
|
dec.env = e
|
||||||
|
// TODO: If this is in a sum, this tends to generate a ton of
|
||||||
|
// fresh variables that are different on each branch of the
|
||||||
|
// parent sum. Does it make sense to hold on to the i'th value
|
||||||
|
// of the tuple after we've generated it?
|
||||||
|
v, err := dec.value(elts[i])
|
||||||
|
if err != nil {
|
||||||
|
// It worked the first time, so this really shouldn't hapen.
|
||||||
|
panic("decoding repeat element failed")
|
||||||
|
}
|
||||||
|
return v, dec.env
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return mk(NewRepeat(gen...))
|
||||||
|
|
||||||
|
case is(yaml.SequenceNode, "!sum"):
|
||||||
|
vs := make([]*Value, 0, len(node.Content))
|
||||||
|
for _, elt := range node.Content {
|
||||||
|
v, err := dec.value(elt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
vs = append(vs, v)
|
||||||
|
}
|
||||||
|
if len(vs) == 1 {
|
||||||
|
return vs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A sum is implemented as a fresh variable that's simultaneously bound
|
||||||
|
// to each of the descendants.
|
||||||
|
id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)}
|
||||||
|
dec.nSums++
|
||||||
|
dec.env = dec.env.bind(id, vs...)
|
||||||
|
return mk(Var{id: id})
|
||||||
|
|
||||||
|
case is(yaml.ScalarNode, "!import"):
|
||||||
|
if dec.opts.FS == nil {
|
||||||
|
return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)")
|
||||||
|
}
|
||||||
|
pat := node.Value
|
||||||
|
|
||||||
|
if !fs.ValidPath(pat) {
|
||||||
|
// This will result in Glob returning no results. Give a more useful
|
||||||
|
// error message for this case.
|
||||||
|
return nil, fmt.Errorf("!import path must not contain '.' or '..'")
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, err := fs.Glob(dec.opts.FS, pat)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolving !import: %w", err)
|
||||||
|
}
|
||||||
|
if len(ms) == 0 {
|
||||||
|
return nil, fmt.Errorf("!import did not match any files")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse each file
|
||||||
|
vs := make([]*Value, 0, len(ms))
|
||||||
|
for _, m := range ms {
|
||||||
|
v, err := dec.import1(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
vs = append(vs, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a sum.
|
||||||
|
if len(vs) == 1 {
|
||||||
|
return vs[0], nil
|
||||||
|
}
|
||||||
|
id := &ident{name: "import"}
|
||||||
|
dec.env = dec.env.bind(id, vs...)
|
||||||
|
return mk(Var{id: id})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *yamlDecoder) import1(path string) (*Value, error) {
|
||||||
|
// Make sure we can open the path first.
|
||||||
|
f, err := dec.opts.FS.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("!import failed: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// Prepare the enter path.
|
||||||
|
oldFS, oldPath := dec.opts.FS, dec.path
|
||||||
|
defer func() {
|
||||||
|
dec.opts.FS, dec.path = oldFS, oldPath
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Enter path, which is relative to the current path's directory.
|
||||||
|
newPath := filepath.Join(filepath.Dir(dec.path), path)
|
||||||
|
subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dec.opts.FS, dec.path = subFS, newPath
|
||||||
|
|
||||||
|
// Parse the file.
|
||||||
|
return dec.read(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlEncoder struct {
|
||||||
|
idp identPrinter
|
||||||
|
e envSet // We track the environment for !repeat nodes.
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Switch some Value marshaling to Closure?
|
||||||
|
|
||||||
|
func (c Closure) MarshalYAML() (any, error) {
|
||||||
|
// TODO: If the environment is trivial, just marshal the value.
|
||||||
|
enc := &yamlEncoder{}
|
||||||
|
return enc.closure(c), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Closure) String() string {
|
||||||
|
b, err := yaml.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("marshal failed: %s", err)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) MarshalYAML() (any, error) {
|
||||||
|
enc := &yamlEncoder{}
|
||||||
|
return enc.value(v), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Value) String() string {
|
||||||
|
b, err := yaml.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("marshal failed: %s", err)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *yamlEncoder) closure(c Closure) *yaml.Node {
|
||||||
|
enc.e = c.env
|
||||||
|
var n yaml.Node
|
||||||
|
n.Kind = yaml.MappingNode
|
||||||
|
n.Tag = "!closure"
|
||||||
|
n.Content = make([]*yaml.Node, 4)
|
||||||
|
n.Content[0] = new(yaml.Node)
|
||||||
|
n.Content[0].SetString("env")
|
||||||
|
n.Content[2] = new(yaml.Node)
|
||||||
|
n.Content[2].SetString("in")
|
||||||
|
n.Content[3] = enc.value(c.val)
|
||||||
|
// Fill in the env after we've written the value in case value encoding
|
||||||
|
// affects the env.
|
||||||
|
n.Content[1] = enc.env(enc.e)
|
||||||
|
enc.e = envSet{} // Allow GC'ing the env
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (enc *yamlEncoder) env(e envSet) *yaml.Node {
|
||||||
|
var encode func(e *envExpr) *yaml.Node
|
||||||
|
encode = func(e *envExpr) *yaml.Node {
|
||||||
|
var n yaml.Node
|
||||||
|
switch e.kind {
|
||||||
|
default:
|
||||||
|
panic("bad kind")
|
||||||
|
case envZero:
|
||||||
|
n.SetString("0")
|
||||||
|
case envUnit:
|
||||||
|
n.SetString("1")
|
||||||
|
case envBinding:
|
||||||
|
var id yaml.Node
|
||||||
|
id.SetString(enc.idp.unique(e.id))
|
||||||
|
n.Kind = yaml.MappingNode
|
||||||
|
n.Content = []*yaml.Node{&id, enc.value(e.val)}
|
||||||
|
case envProduct, envSum:
|
||||||
|
n.Kind = yaml.SequenceNode
|
||||||
|
if e.kind == envProduct {
|
||||||
|
n.Tag = "!product"
|
||||||
|
} else {
|
||||||
|
n.Tag = "!sum"
|
||||||
|
}
|
||||||
|
for _, e2 := range e.operands {
|
||||||
|
n.Content = append(n.Content, encode(e2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
return encode(e.root)
|
||||||
|
}
|
||||||
|
|
||||||
|
var yamlIntRe = regexp.MustCompile(`^-?[0-9]+$`)
|
||||||
|
|
||||||
|
func (enc *yamlEncoder) value(v *Value) *yaml.Node {
|
||||||
|
var n yaml.Node
|
||||||
|
switch d := v.Domain.(type) {
|
||||||
|
case nil:
|
||||||
|
// Not allowed by unmarshaler, but useful for understanding when
|
||||||
|
// something goes horribly wrong.
|
||||||
|
//
|
||||||
|
// TODO: We might be able to track useful provenance for this, which
|
||||||
|
// would really help with debugging unexpected bottoms.
|
||||||
|
n.SetString("_|_")
|
||||||
|
return &n
|
||||||
|
|
||||||
|
case Top:
|
||||||
|
n.SetString("_")
|
||||||
|
return &n
|
||||||
|
|
||||||
|
case Def:
|
||||||
|
n.Kind = yaml.MappingNode
|
||||||
|
for k, elt := range d.All() {
|
||||||
|
var kn yaml.Node
|
||||||
|
kn.SetString(k)
|
||||||
|
n.Content = append(n.Content, &kn, enc.value(elt))
|
||||||
|
}
|
||||||
|
n.HeadComment = v.PosString()
|
||||||
|
return &n
|
||||||
|
|
||||||
|
case Tuple:
|
||||||
|
n.Kind = yaml.SequenceNode
|
||||||
|
if d.repeat == nil {
|
||||||
|
for _, elt := range d.vs {
|
||||||
|
n.Content = append(n.Content, enc.value(elt))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(d.repeat) == 1 {
|
||||||
|
n.Tag = "!repeat"
|
||||||
|
} else {
|
||||||
|
n.Tag = "!repeat-unify"
|
||||||
|
}
|
||||||
|
// TODO: I'm not positive this will round-trip everything correctly.
|
||||||
|
for _, gen := range d.repeat {
|
||||||
|
v, e := gen(enc.e)
|
||||||
|
enc.e = e
|
||||||
|
n.Content = append(n.Content, enc.value(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &n
|
||||||
|
|
||||||
|
case String:
|
||||||
|
switch d.kind {
|
||||||
|
case stringExact:
|
||||||
|
n.SetString(d.exact)
|
||||||
|
switch {
|
||||||
|
// Make this into a "nice" !!int node if I can.
|
||||||
|
case yamlIntRe.MatchString(d.exact):
|
||||||
|
n.Tag = "tag:yaml.org,2002:int"
|
||||||
|
|
||||||
|
// Or a "nice" !!bool node.
|
||||||
|
case d.exact == "false" || d.exact == "true":
|
||||||
|
n.Tag = "tag:yaml.org,2002:bool"
|
||||||
|
|
||||||
|
// If this doesn't require escaping, leave it as a str node to avoid
|
||||||
|
// the annoying YAML tags. Otherwise, mark it as an exact string.
|
||||||
|
// Alternatively, we could always emit a str node with regexp
|
||||||
|
// quoting.
|
||||||
|
case d.exact != regexp.QuoteMeta(d.exact):
|
||||||
|
n.Tag = "!string"
|
||||||
|
}
|
||||||
|
return &n
|
||||||
|
case stringRegex:
|
||||||
|
o := make([]string, 0, 1)
|
||||||
|
for _, re := range d.re {
|
||||||
|
s := re.String()
|
||||||
|
s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`)
|
||||||
|
o = append(o, s)
|
||||||
|
}
|
||||||
|
if len(o) == 1 {
|
||||||
|
n.SetString(o[0])
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
n.Encode(o)
|
||||||
|
n.Tag = "!regex"
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
panic("bad String kind")
|
||||||
|
|
||||||
|
case Var:
|
||||||
|
// TODO: If Var only appears once in the whole Value and is independent
|
||||||
|
// in the environment (part of a term that is only over Var), then emit
|
||||||
|
// this as a !sum instead.
|
||||||
|
if false {
|
||||||
|
var vs []*Value // TODO: Get values of this var.
|
||||||
|
if len(vs) == 1 {
|
||||||
|
return enc.value(vs[0])
|
||||||
|
}
|
||||||
|
n.Kind = yaml.SequenceNode
|
||||||
|
n.Tag = "!sum"
|
||||||
|
for _, elt := range vs {
|
||||||
|
n.Content = append(n.Content, enc.value(elt))
|
||||||
|
}
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
n.SetString(enc.idp.unique(d.id))
|
||||||
|
if !strings.HasPrefix(d.id.name, "$") {
|
||||||
|
n.Tag = "!var"
|
||||||
|
}
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("unknown domain type %T", v.Domain))
|
||||||
|
}
|
||||||
202
src/simd/_gen/unify/yaml_test.go
Normal file
202
src/simd/_gen/unify/yaml_test.go
Normal file
|
|
@ -0,0 +1,202 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package unify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustParse(expr string) Closure {
|
||||||
|
var c Closure
|
||||||
|
if err := yaml.Unmarshal([]byte(expr), &c); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func oneValue(t *testing.T, c Closure) *Value {
|
||||||
|
t.Helper()
|
||||||
|
var v *Value
|
||||||
|
var i int
|
||||||
|
for v = range c.All() {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Fatalf("expected 1 value, got %d", i)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func printYaml(val any) {
|
||||||
|
fmt.Println(prettyYaml(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func prettyYaml(val any) string {
|
||||||
|
b, err := yaml.Marshal(val)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var node yaml.Node
|
||||||
|
if err := yaml.Unmarshal(b, &node); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map lines to start offsets. We'll use this to figure out when nodes are
|
||||||
|
// "small" and should use inline style.
|
||||||
|
lines := []int{-1, 0}
|
||||||
|
for pos := 0; pos < len(b); {
|
||||||
|
next := bytes.IndexByte(b[pos:], '\n')
|
||||||
|
if next == -1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pos += next + 1
|
||||||
|
lines = append(lines, pos)
|
||||||
|
}
|
||||||
|
lines = append(lines, len(b))
|
||||||
|
|
||||||
|
// Strip comments and switch small nodes to inline style
|
||||||
|
cleanYaml(&node, lines, len(b))
|
||||||
|
|
||||||
|
b, err = yaml.Marshal(&node)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanYaml(node *yaml.Node, lines []int, endPos int) {
|
||||||
|
node.HeadComment = ""
|
||||||
|
node.FootComment = ""
|
||||||
|
node.LineComment = ""
|
||||||
|
|
||||||
|
for i, n2 := range node.Content {
|
||||||
|
end2 := endPos
|
||||||
|
if i < len(node.Content)-1 {
|
||||||
|
end2 = lines[node.Content[i+1].Line]
|
||||||
|
}
|
||||||
|
cleanYaml(n2, lines, end2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use inline style?
|
||||||
|
switch node.Kind {
|
||||||
|
case yaml.MappingNode, yaml.SequenceNode:
|
||||||
|
if endPos-lines[node.Line] < 40 {
|
||||||
|
node.Style = yaml.FlowStyle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func allYamlNodes(n *yaml.Node) iter.Seq[*yaml.Node] {
|
||||||
|
return func(yield func(*yaml.Node) bool) {
|
||||||
|
if !yield(n) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, n2 := range n.Content {
|
||||||
|
for n3 := range allYamlNodes(n2) {
|
||||||
|
if !yield(n3) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripString(t *testing.T) {
|
||||||
|
// Check that we can round-trip a string with regexp meta-characters in it.
|
||||||
|
const y = `!string test*`
|
||||||
|
t.Logf("input:\n%s", y)
|
||||||
|
|
||||||
|
v1 := oneValue(t, mustParse(y))
|
||||||
|
var buf1 strings.Builder
|
||||||
|
enc := yaml.NewEncoder(&buf1)
|
||||||
|
if err := enc.Encode(v1); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
enc.Close()
|
||||||
|
t.Logf("after parse 1:\n%s", buf1.String())
|
||||||
|
|
||||||
|
v2 := oneValue(t, mustParse(buf1.String()))
|
||||||
|
var buf2 strings.Builder
|
||||||
|
enc = yaml.NewEncoder(&buf2)
|
||||||
|
if err := enc.Encode(v2); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
enc.Close()
|
||||||
|
t.Logf("after parse 2:\n%s", buf2.String())
|
||||||
|
|
||||||
|
if buf1.String() != buf2.String() {
|
||||||
|
t.Fatal("parse 1 and parse 2 differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptyString(t *testing.T) {
|
||||||
|
// Regression test. Make sure an empty string is parsed as an exact string,
|
||||||
|
// not a regexp.
|
||||||
|
const y = `""`
|
||||||
|
t.Logf("input:\n%s", y)
|
||||||
|
|
||||||
|
v1 := oneValue(t, mustParse(y))
|
||||||
|
if !v1.Exact() {
|
||||||
|
t.Fatal("expected exact string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImport(t *testing.T) {
|
||||||
|
// Test a basic import
|
||||||
|
main := strings.NewReader("!import x/y.yaml")
|
||||||
|
fs := fstest.MapFS{
|
||||||
|
// Test a glob import with a relative path
|
||||||
|
"x/y.yaml": {Data: []byte("!import y/*.yaml")},
|
||||||
|
"x/y/z.yaml": {Data: []byte("42")},
|
||||||
|
}
|
||||||
|
cl, err := Read(main, "x.yaml", ReadOpts{FS: fs})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
x := 42
|
||||||
|
checkDecode(t, oneValue(t, cl), &x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImportEscape(t *testing.T) {
|
||||||
|
// Make sure an import can't escape its subdirectory.
|
||||||
|
main := strings.NewReader("!import x/y.yaml")
|
||||||
|
fs := fstest.MapFS{
|
||||||
|
"x/y.yaml": {Data: []byte("!import ../y/*.yaml")},
|
||||||
|
"y/z.yaml": {Data: []byte("42")},
|
||||||
|
}
|
||||||
|
_, err := Read(main, "x.yaml", ReadOpts{FS: fs})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("relative !import should have failed")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "must not contain") {
|
||||||
|
t.Fatalf("unexpected error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImportScope(t *testing.T) {
|
||||||
|
// Test that imports have different variable scopes.
|
||||||
|
main := strings.NewReader("[!import y.yaml, !import y.yaml]")
|
||||||
|
fs := fstest.MapFS{
|
||||||
|
"y.yaml": {Data: []byte("$v")},
|
||||||
|
}
|
||||||
|
cl1, err := Read(main, "x.yaml", ReadOpts{FS: fs})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cl2 := mustParse("[1, 2]")
|
||||||
|
res, err := Unify(cl1, cl2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
checkDecode(t, oneValue(t, res), []int{1, 2})
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue