diff --git a/src/cmd/compile/internal/ssa/check.go b/src/cmd/compile/internal/ssa/check.go index 398f06053e1..4ea85613040 100644 --- a/src/cmd/compile/internal/ssa/check.go +++ b/src/cmd/compile/internal/ssa/check.go @@ -152,7 +152,7 @@ func checkFunc(f *Func) { case auxUInt8: // Cast to int8 due to requirement of AuxInt, check its comment for details. if v.AuxInt != int64(int8(v.AuxInt)) { - f.Fatalf("bad uint8 AuxInt value for %v", v) + f.Fatalf("bad uint8 AuxInt value for %v, saw %d but need %d", v, v.AuxInt, int64(int8(v.AuxInt))) } canHaveAuxInt = true case auxFloat32: diff --git a/src/cmd/compile/internal/ssagen/intrinsics.go b/src/cmd/compile/internal/ssagen/intrinsics.go index ce9a76f6b84..985d899a71e 100644 --- a/src/cmd/compile/internal/ssagen/intrinsics.go +++ b/src/cmd/compile/internal/ssagen/intrinsics.go @@ -1614,6 +1614,7 @@ func initIntrinsics(cfg *intrinsicBuildConfig) { return nil }, sys.AMD64) + addF(simdPackage, "Int8x16.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) addF(simdPackage, "Int16x8.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) addF(simdPackage, "Int32x4.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) @@ -1630,9 +1631,126 @@ func initIntrinsics(cfg *intrinsicBuildConfig) { addF(simdPackage, "Uint16x16.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) addF(simdPackage, "Uint32x8.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) addF(simdPackage, "Uint64x4.IsZero", opLen1(ssa.OpIsZeroVec, types.Types[types.TBOOL]), sys.AMD64) + + sfp := func(method string, hwop ssa.Op, vectype *types.Type) { // sfp registers one SelectFromPair intrinsic: all-constant selectors are lowered by selectFromPair, anything else stays a normal call + addF(simdPackage, method, + func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { + x, a, b, c, d, y := args[0], args[1], args[2], args[3], args[4], args[5] + if a.Op == ssa.OpConst8 && b.Op == ssa.OpConst8 && c.Op == ssa.OpConst8 && d.Op == ssa.OpConst8 { + return selectFromPair(x, a, b, c, d, y, s, hwop, vectype) + } else { 
return s.callResult(n, callNormal) + } + }, + sys.AMD64) + } + + sfp("Int32x4.SelectFromPair", ssa.OpconcatSelectedConstantInt32x4, types.TypeVec128) + sfp("Uint32x4.SelectFromPair", ssa.OpconcatSelectedConstantUint32x4, types.TypeVec128) + sfp("Float32x4.SelectFromPair", ssa.OpconcatSelectedConstantFloat32x4, types.TypeVec128) + + sfp("Int32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt32x8, types.TypeVec256) + sfp("Uint32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint32x8, types.TypeVec256) + sfp("Float32x8.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat32x8, types.TypeVec256) + + sfp("Int32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedInt32x16, types.TypeVec512) + sfp("Uint32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedUint32x16, types.TypeVec512) + sfp("Float32x16.SelectFromPairGrouped", ssa.OpconcatSelectedConstantGroupedFloat32x16, types.TypeVec512) + } } +func cscimm(a, b, c, d uint8) int64 { // packs four 2-bit selectors (a in the lowest bits) into one 8-bit immediate, sign-extended by se + return se(a + b<<2 + c<<4 + d<<6) +} + +const ( + _LLLL = iota + _HLLL + _LHLL + _HHLL + _LLHL + _HLHL + _LHHL + _HHHL + _LLLH + _HLLH + _LHLH + _HHLH + _LLHH + _HLHH + _LHHH + _HHHH +) + +func selectFromPair(x, _a, _b, _c, _d, y *ssa.Value, s *state, op ssa.Op, t *types.Type) *ssa.Value { + a, b, c, d := uint8(_a.AuxInt8()), uint8(_b.AuxInt8()), uint8(_c.AuxInt8()), uint8(_d.AuxInt8()) + pattern := (a&4)>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 // one "x or y?" bit per selector; a is masked like b, c and d so pattern stays in 0-15 and the switch below really is exhaustive + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + // TODO DETECT 0,1,2,3, 0,0,0,0 + return s.newValue2I(op, t, cscimm(a, b, c, d), x, x) + case _HHHH: + // TODO DETECT 0,1,2,3, 0,0,0,0 + return s.newValue2I(op, t, cscimm(a, b, c, d), y, y) + case _LLHH: + return s.newValue2I(op, t, cscimm(a, b, c, d), x, y) + case _HHLL: + return s.newValue2I(op, t, cscimm(a, b, c, d), y, x) + + case _HLLL: + z := s.newValue2I(op, t, cscimm(a, a, b, b), y, x) + return s.newValue2I(op, t, cscimm(0, 2, c, d), z, x) + case _LHLL: + z := 
s.newValue2I(op, t, cscimm(a, a, b, b), x, y) + return s.newValue2I(op, t, cscimm(0, 2, c, d), z, x) + case _HLHH: + z := s.newValue2I(op, t, cscimm(a, a, b, b), y, x) + return s.newValue2I(op, t, cscimm(0, 2, c, d), z, y) + case _LHHH: + z := s.newValue2I(op, t, cscimm(a, a, b, b), x, y) + return s.newValue2I(op, t, cscimm(0, 2, c, d), z, y) + + case _LLLH: + z := s.newValue2I(op, t, cscimm(c, c, d, d), x, y) + return s.newValue2I(op, t, cscimm(a, b, 0, 2), x, z) + case _LLHL: + z := s.newValue2I(op, t, cscimm(c, c, d, d), y, x) + return s.newValue2I(op, t, cscimm(a, b, 0, 2), x, z) + + case _HHLH: + z := s.newValue2I(op, t, cscimm(c, c, d, d), x, y) + return s.newValue2I(op, t, cscimm(a, b, 0, 2), y, z) + + case _HHHL: + z := s.newValue2I(op, t, cscimm(c, c, d, d), y, x) + return s.newValue2I(op, t, cscimm(a, b, 0, 2), y, z) + + case _LHLH: + z := s.newValue2I(op, t, cscimm(a, c, b, d), x, y) + return s.newValue2I(op, t, se(0b11_01_10_00), z, z) + case _HLHL: + z := s.newValue2I(op, t, cscimm(b, d, a, c), x, y) + return s.newValue2I(op, t, se(0b01_11_00_10), z, z) + case _HLLH: + z := s.newValue2I(op, t, cscimm(b, c, a, d), x, y) + return s.newValue2I(op, t, se(0b11_01_00_10), z, z) + case _LHHL: + z := s.newValue2I(op, t, cscimm(a, d, b, c), x, y) + return s.newValue2I(op, t, se(0b01_11_10_00), z, z) + } + panic("The preceding switch should have been exhaustive") +} + +// se smears the not-really-a-sign bit of a uint8 to conform to the conventions +// for representing AuxInt in ssa. 
+func se(x uint8) int64 { + return int64(int8(x)) +} + func opLen1(op ssa.Op, t *types.Type) func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return s.newValue1(op, t, args[0]) diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go index f05c6d6f66e..6deadde45e6 100644 --- a/src/simd/internal/simd_test/simd_test.go +++ b/src/simd/internal/simd_test/simd_test.go @@ -594,3 +594,224 @@ func TestIsZero(t *testing.T) { t.Errorf("Result incorrect, want true, got false") } } + +func TestSelectFromPairConst(t *testing.T) { + x := simd.LoadInt32x4Slice([]int32{0, 1, 2, 3}) + y := simd.LoadInt32x4Slice([]int32{4, 5, 6, 7}) + + llll := x.SelectFromPair(0, 1, 2, 3, y) + hhhh := x.SelectFromPair(4, 5, 6, 7, y) + llhh := x.SelectFromPair(0, 1, 6, 7, y) + hhll := x.SelectFromPair(6, 7, 0, 1, y) + + lllh := x.SelectFromPair(0, 1, 2, 7, y) + llhl := x.SelectFromPair(0, 1, 7, 2, y) + lhll := x.SelectFromPair(0, 7, 1, 2, y) + hlll := x.SelectFromPair(7, 0, 1, 2, y) + + hhhl := x.SelectFromPair(4, 5, 6, 0, y) + hhlh := x.SelectFromPair(4, 5, 0, 6, y) + hlhh := x.SelectFromPair(4, 0, 5, 6, y) + lhhh := x.SelectFromPair(0, 4, 5, 6, y) + + lhlh := x.SelectFromPair(0, 4, 1, 5, y) + hlhl := x.SelectFromPair(4, 0, 5, 1, y) + lhhl := x.SelectFromPair(0, 4, 5, 1, y) + hllh := x.SelectFromPair(4, 0, 1, 5, y) + + r := make([]int32, 4, 4) + + foo := func(v simd.Int32x4, a, b, c, d int32) { + v.StoreSlice(r) + checkSlices[int32](t, r, []int32{a, b, c, d}) + } + + foo(llll, 0, 1, 2, 3) + foo(hhhh, 4, 5, 6, 7) + foo(llhh, 0, 1, 6, 7) + foo(hhll, 6, 7, 0, 1) + + foo(lllh, 0, 1, 2, 7) + foo(llhl, 0, 1, 7, 2) + foo(lhll, 0, 7, 1, 2) + foo(hlll, 7, 0, 1, 2) + + foo(hhhl, 4, 5, 6, 0) + foo(hhlh, 4, 5, 0, 6) + foo(hlhh, 4, 0, 5, 6) + foo(lhhh, 0, 4, 5, 6) + + foo(lhlh, 0, 4, 1, 5) + foo(hlhl, 4, 0, 5, 1) + foo(lhhl, 0, 4, 5, 1) + foo(hllh, 4, 0, 1, 5) +} + +//go:noinline +func 
selectFromPairInt32x4(x simd.Int32x4, a, b, c, d uint8, y simd.Int32x4) simd.Int32x4 { + return x.SelectFromPair(a, b, c, d, y) +} + +func TestSelectFromPairVar(t *testing.T) { + x := simd.LoadInt32x4Slice([]int32{0, 1, 2, 3}) + y := simd.LoadInt32x4Slice([]int32{4, 5, 6, 7}) + + llll := selectFromPairInt32x4(x, 0, 1, 2, 3, y) + hhhh := selectFromPairInt32x4(x, 4, 5, 6, 7, y) + llhh := selectFromPairInt32x4(x, 0, 1, 6, 7, y) + hhll := selectFromPairInt32x4(x, 6, 7, 0, 1, y) + + lllh := selectFromPairInt32x4(x, 0, 1, 2, 7, y) + llhl := selectFromPairInt32x4(x, 0, 1, 7, 2, y) + lhll := selectFromPairInt32x4(x, 0, 7, 1, 2, y) + hlll := selectFromPairInt32x4(x, 7, 0, 1, 2, y) + + hhhl := selectFromPairInt32x4(x, 4, 5, 6, 0, y) + hhlh := selectFromPairInt32x4(x, 4, 5, 0, 6, y) + hlhh := selectFromPairInt32x4(x, 4, 0, 5, 6, y) + lhhh := selectFromPairInt32x4(x, 0, 4, 5, 6, y) + + lhlh := selectFromPairInt32x4(x, 0, 4, 1, 5, y) + hlhl := selectFromPairInt32x4(x, 4, 0, 5, 1, y) + lhhl := selectFromPairInt32x4(x, 0, 4, 5, 1, y) + hllh := selectFromPairInt32x4(x, 4, 0, 1, 5, y) + + r := make([]int32, 4, 4) + + foo := func(v simd.Int32x4, a, b, c, d int32) { + v.StoreSlice(r) + checkSlices[int32](t, r, []int32{a, b, c, d}) + } + + foo(llll, 0, 1, 2, 3) + foo(hhhh, 4, 5, 6, 7) + foo(llhh, 0, 1, 6, 7) + foo(hhll, 6, 7, 0, 1) + + foo(lllh, 0, 1, 2, 7) + foo(llhl, 0, 1, 7, 2) + foo(lhll, 0, 7, 1, 2) + foo(hlll, 7, 0, 1, 2) + + foo(hhhl, 4, 5, 6, 0) + foo(hhlh, 4, 5, 0, 6) + foo(hlhh, 4, 0, 5, 6) + foo(lhhh, 0, 4, 5, 6) + + foo(lhlh, 0, 4, 1, 5) + foo(hlhl, 4, 0, 5, 1) + foo(lhhl, 0, 4, 5, 1) + foo(hllh, 4, 0, 1, 5) +} + +func TestSelectFromPairConstGroupedFloat32x8(t *testing.T) { + x := simd.LoadFloat32x8Slice([]float32{0, 1, 2, 3, 10, 11, 12, 13}) + y := simd.LoadFloat32x8Slice([]float32{4, 5, 6, 7, 14, 15, 16, 17}) + + llll := x.SelectFromPairGrouped(0, 1, 2, 3, y) + hhhh := x.SelectFromPairGrouped(4, 5, 6, 7, y) + llhh := x.SelectFromPairGrouped(0, 1, 6, 7, y) + hhll := 
x.SelectFromPairGrouped(6, 7, 0, 1, y) + + lllh := x.SelectFromPairGrouped(0, 1, 2, 7, y) + llhl := x.SelectFromPairGrouped(0, 1, 7, 2, y) + lhll := x.SelectFromPairGrouped(0, 7, 1, 2, y) + hlll := x.SelectFromPairGrouped(7, 0, 1, 2, y) + + hhhl := x.SelectFromPairGrouped(4, 5, 6, 0, y) + hhlh := x.SelectFromPairGrouped(4, 5, 0, 6, y) + hlhh := x.SelectFromPairGrouped(4, 0, 5, 6, y) + lhhh := x.SelectFromPairGrouped(0, 4, 5, 6, y) + + lhlh := x.SelectFromPairGrouped(0, 4, 1, 5, y) + hlhl := x.SelectFromPairGrouped(4, 0, 5, 1, y) + lhhl := x.SelectFromPairGrouped(0, 4, 5, 1, y) + hllh := x.SelectFromPairGrouped(4, 0, 1, 5, y) + + r := make([]float32, 8, 8) + + foo := func(v simd.Float32x8, a, b, c, d float32) { + v.StoreSlice(r) + checkSlices[float32](t, r, []float32{a, b, c, d, 10 + a, 10 + b, 10 + c, 10 + d}) + } + + foo(llll, 0, 1, 2, 3) + foo(hhhh, 4, 5, 6, 7) + foo(llhh, 0, 1, 6, 7) + foo(hhll, 6, 7, 0, 1) + + foo(lllh, 0, 1, 2, 7) + foo(llhl, 0, 1, 7, 2) + foo(lhll, 0, 7, 1, 2) + foo(hlll, 7, 0, 1, 2) + + foo(hhhl, 4, 5, 6, 0) + foo(hhlh, 4, 5, 0, 6) + foo(hlhh, 4, 0, 5, 6) + foo(lhhh, 0, 4, 5, 6) + + foo(lhlh, 0, 4, 1, 5) + foo(hlhl, 4, 0, 5, 1) + foo(lhhl, 0, 4, 5, 1) + foo(hllh, 4, 0, 1, 5) +} + +func TestSelectFromPairConstGroupedUint32x16(t *testing.T) { + if !simd.HasAVX512() { + t.Skip("Test requires HasAVX512, not available on this hardware") + return + } + x := simd.LoadUint32x16Slice([]uint32{0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33}) + y := simd.LoadUint32x16Slice([]uint32{4, 5, 6, 7, 14, 15, 16, 17, 24, 25, 26, 27, 34, 35, 36, 37}) + + llll := x.SelectFromPairGrouped(0, 1, 2, 3, y) + hhhh := x.SelectFromPairGrouped(4, 5, 6, 7, y) + llhh := x.SelectFromPairGrouped(0, 1, 6, 7, y) + hhll := x.SelectFromPairGrouped(6, 7, 0, 1, y) + + lllh := x.SelectFromPairGrouped(0, 1, 2, 7, y) + llhl := x.SelectFromPairGrouped(0, 1, 7, 2, y) + lhll := x.SelectFromPairGrouped(0, 7, 1, 2, y) + hlll := x.SelectFromPairGrouped(7, 0, 1, 2, y) + + hhhl 
:= x.SelectFromPairGrouped(4, 5, 6, 0, y) + hhlh := x.SelectFromPairGrouped(4, 5, 0, 6, y) + hlhh := x.SelectFromPairGrouped(4, 0, 5, 6, y) + lhhh := x.SelectFromPairGrouped(0, 4, 5, 6, y) + + lhlh := x.SelectFromPairGrouped(0, 4, 1, 5, y) + hlhl := x.SelectFromPairGrouped(4, 0, 5, 1, y) + lhhl := x.SelectFromPairGrouped(0, 4, 5, 1, y) + hllh := x.SelectFromPairGrouped(4, 0, 1, 5, y) + + r := make([]uint32, 16, 16) + + foo := func(v simd.Uint32x16, a, b, c, d uint32) { + v.StoreSlice(r) + checkSlices[uint32](t, r, []uint32{a, b, c, d, + 10 + a, 10 + b, 10 + c, 10 + d, + 20 + a, 20 + b, 20 + c, 20 + d, + 30 + a, 30 + b, 30 + c, 30 + d, + }) + } + + foo(llll, 0, 1, 2, 3) + foo(hhhh, 4, 5, 6, 7) + foo(llhh, 0, 1, 6, 7) + foo(hhll, 6, 7, 0, 1) + + foo(lllh, 0, 1, 2, 7) + foo(llhl, 0, 1, 7, 2) + foo(lhll, 0, 7, 1, 2) + foo(hlll, 7, 0, 1, 2) + + foo(hhhl, 4, 5, 6, 0) + foo(hhlh, 4, 5, 0, 6) + foo(hlhh, 4, 0, 5, 6) + foo(lhhh, 0, 4, 5, 6) + + foo(lhlh, 0, 4, 1, 5) + foo(hlhl, 4, 0, 5, 1) + foo(lhhl, 0, 4, 5, 1) + foo(hllh, 4, 0, 1, 5) +} diff --git a/src/simd/pkginternal_test.go b/src/simd/pkginternal_test.go index 801cd0d17af..557a0537b4e 100644 --- a/src/simd/pkginternal_test.go +++ b/src/simd/pkginternal_test.go @@ -46,3 +46,187 @@ func TestConcatSelectedConstantGrouped32(t *testing.T) { z.StoreSlice(a) test_helpers.CheckSlices[uint32](t, a, []uint32{2, 0, 5, 7, 10, 8, 13, 15}) } + +func TestSelect2x4x32(t *testing.T) { + for a := range uint8(8) { + for b := range uint8(8) { + for c := range uint8(8) { + for d := range uint8(8) { + x := LoadInt32x4Slice([]int32{0, 1, 2, 3}) + y := LoadInt32x4Slice([]int32{4, 5, 6, 7}) + z := select2x4x32(x, a, b, c, d, y) + w := make([]int32, 4, 4) + z.StoreSlice(w) + if w[0] != int32(a) || w[1] != int32(b) || + w[2] != int32(c) || w[3] != int32(d) { + t.Errorf("Expected [%d %d %d %d] got %v", a, b, c, d, w) + } + } + } + } + } +} + +func TestSelect2x8x32Grouped(t *testing.T) { + for a := range uint8(8) { + for b := range uint8(8) { + 
for c := range uint8(8) { + for d := range uint8(8) { + x := LoadInt32x8Slice([]int32{0, 1, 2, 3, 10, 11, 12, 13}) + y := LoadInt32x8Slice([]int32{4, 5, 6, 7, 14, 15, 16, 17}) + z := select2x8x32Grouped(x, a, b, c, d, y) + w := make([]int32, 8, 8) + z.StoreSlice(w) + if w[0] != int32(a) || w[1] != int32(b) || + w[2] != int32(c) || w[3] != int32(d) || + w[4] != int32(10+a) || w[5] != int32(10+b) || + w[6] != int32(10+c) || w[7] != int32(10+d) { + t.Errorf("Expected [%d %d %d %d %d %d %d %d] got %v", a, b, c, d, 10+a, 10+b, 10+c, 10+d, w) + } + } + } + } + } +} + +// select2x4x32 returns a selection of 4 elements in x and y, numbered +// 0-7, where 0-3 are the four elements of x and 4-7 are the four elements +// of y. +func select2x4x32(x Int32x4, a, b, c, d uint8, y Int32x4) Int32x4 { + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstant(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstant(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstant(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstant(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstant(cscimm(c, c, d, d), y) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHLH: + z := 
x.concatSelectedConstant(cscimm(c, c, d, d), y) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstant(cscimm(a, c, b, d), y) + return z.concatSelectedConstant(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstant(cscimm(b, d, a, c), y) + return z.concatSelectedConstant(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstant(cscimm(b, c, a, d), y) + return z.concatSelectedConstant(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstant(cscimm(a, d, b, c), y) + return z.concatSelectedConstant(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// select2x8x32Grouped returns a pair of selection of 4 elements in x and y, +// numbered 0-7, where 0-3 are the four elements of x's two groups (lower and +// upper 128 bits) and 4-7 are the four elements of y's two groups. + +func select2x8x32Grouped(x Int32x8, a, b, c, d uint8, y Int32x8) Int32x8 { + // selections as being expressible in the concatSelectedConstant pattern, + // or not. Classification is by H and L, where H is a selection from 4-7 + // and L is a selection from 0-3. 
+ // _LLHH -> CSC(x,y, a, b, c&3, d&3) + // _HHLL -> CSC(y,x, a&3, b&3, c, d) + // _LLLL -> CSC(x,x, a, b, c, d) + // _HHHH -> CSC(y,y, a&3, b&3, c&3, d&3) + + // _LLLH -> z = CSC(x, y, c, c, d&3, d&3); CSC(x, z, a, b, 0, 2) + // _LLHL -> z = CSC(x, y, c&3, c&3, d, d); CSC(x, z, a, b, 0, 2) + // _HHLH -> z = CSC(x, y, c, c, d&3, d&3); CSC(y, z, a&3, b&3, 0, 2) + // _HHHL -> z = CSC(x, y, c&3, c&3, d, d); CSC(y, z, a&3, b&3, 0, 2) + + // _LHLL -> z = CSC(x, y, a, a, b&3, b&3); CSC(z, x, 0, 2, c, d) + // etc + + // _LHLH -> z = CSC(x, y, a, c, b&3, d&3); CSC(z, z, 0, 2, 1, 3) + // _HLHL -> z = CSC(x, y, b, d, a&3, c&3); CSC(z, z, 2, 0, 3, 1) + + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return 
y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} diff --git a/src/simd/shuffles_amd64.go b/src/simd/shuffles_amd64.go index 4445a88f31c..68c840730b3 100644 --- a/src/simd/shuffles_amd64.go +++ b/src/simd/shuffles_amd64.go @@ -13,3 +13,697 @@ package simd func (x Int32x4) FlattenedTranspose(y Int32x4) (a, b Int32x4) { return x.InterleaveLo(y), x.InterleaveHi(y) } + +// These constants represent the source pattern for the four parameters +// (a, b, c, d) passed to SelectFromPair and SelectFromPairGrouped. +// L means the element comes from the 'x' vector (Low), and +// H means it comes from the 'y' vector (High). +// The order of the letters corresponds to elements a, b, c, d. +// The underlying integer value is a bitmask where: +// Bit 0: Source of element 'a' (0 for x, 1 for y) +// Bit 1: Source of element 'b' (0 for x, 1 for y) +// Bit 2: Source of element 'c' (0 for x, 1 for y) +// Bit 3: Source of element 'd' (0 for x, 1 for y) +// Note that the least-significant bit is on the LEFT in this encoding. 
+const ( + _LLLL = iota // a:x, b:x, c:x, d:x + _HLLL // a:y, b:x, c:x, d:x + _LHLL // a:x, b:y, c:x, d:x + _HHLL // a:y, b:y, c:x, d:x + _LLHL // a:x, b:x, c:y, d:x + _HLHL // a:y, b:x, c:y, d:x + _LHHL // a:x, b:y, c:y, d:x + _HHHL // a:y, b:y, c:y, d:x + _LLLH // a:x, b:x, c:x, d:y + _HLLH // a:y, b:x, c:x, d:y + _LHLH // a:x, b:y, c:x, d:y + _HHLH // a:y, b:y, c:x, d:y + _LLHH // a:x, b:x, c:y, d:y + _HLHH // a:y, b:x, c:y, d:y + _LHHH // a:x, b:y, c:y, d:y + _HHHH // a:y, b:y, c:y, d:y +) + +// SelectFromPair returns the selection of four elements from the two +// vectors x and y, where selector values in the range 0-3 specify +// elements from x and values in the range 4-7 specify the 0-3 elements +// of y. When the selectors are constants and the selection can be +// implemented in a single instruction, it will be, otherwise it +// requires two. a is the source index of the least element in the +// output, and b, c, and d are the indices of the 2nd, 3rd, and 4th +// elements in the output. For example, +// {1,2,4,8}.SelectFromPair(2,3,5,7,{9,25,49,81}) returns {4,8,25,81} +// +// If the selectors are not constant this will translate to a function +// call. +// +// Asm: VSHUFPS, CPU Feature: AVX +func (x Int32x4) SelectFromPair(a, b, c, d uint8, y Int32x4) Int32x4 { + // pattern gets the concatenation of "x or y?" bits + // (0 == x, 1 == y) + // This will determine operand choice/order and whether a second + // instruction is needed. + pattern := (a&4)>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 // a is masked like b, c and d so pattern stays in 0-15 and the switch below really is exhaustive + + // a-d are masked down to their offsets within x or y + // this is not necessary for x, but this is easier on the + // eyes and reduces the risk of an error now or later. 
+ a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstant(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstant(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstant(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstant(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstant(cscimm(c, c, d, d), y) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstant(cscimm(c, c, d, d), y) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstant(cscimm(a, c, b, d), y) + return z.concatSelectedConstant(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstant(cscimm(b, d, a, c), y) + return z.concatSelectedConstant(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstant(cscimm(b, c, a, d), y) + return z.concatSelectedConstant(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstant(cscimm(a, d, b, c), y) + return z.concatSelectedConstant(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") 
+} + +// SelectFromPair returns the selection of four elements from the two +// vectors x and y, where selector values in the range 0-3 specify +// elements from x and values in the range 4-7 specify the 0-3 elements +// of y. When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. a is the source index of the least element in the +// output, and b, c, and d are the indices of the 2nd, 3rd, and 4th +// elements in the output. For example, +// {1,2,4,8}.SelectFromPair(2,3,5,7,{9,25,49,81}) returns {4,8,25,81} +// +// If the selectors are not constant this will translate to a function +// call. +// +// Asm: VSHUFPS, CPU Feature: AVX +func (x Uint32x4) SelectFromPair(a, b, c, d uint8, y Uint32x4) Uint32x4 { + pattern := (a&4)>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 // one "x or y?" bit per selector (0 == x, 1 == y); a is masked like b, c and d so pattern stays in 0-15 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstant(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstant(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstant(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstant(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstant(cscimm(c, c, d, d), y) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHLH: + z := 
x.concatSelectedConstant(cscimm(c, c, d, d), y) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstant(cscimm(a, c, b, d), y) + return z.concatSelectedConstant(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstant(cscimm(b, d, a, c), y) + return z.concatSelectedConstant(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstant(cscimm(b, c, a, d), y) + return z.concatSelectedConstant(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstant(cscimm(a, d, b, c), y) + return z.concatSelectedConstant(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPair returns the selection of four elements from the two +// vectors x and y, where selector values in the range 0-3 specify +// elements from x and values in the range 4-7 specify the 0-3 elements +// of y. When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. a is the source index of the least element in the +// output, and b, c, and d are the indices of the 2nd, 3rd, and 4th +// elements in the output. For example, +// {1,2,4,8}.SelectFromPair(2,3,5,7,{9,25,49,81}) returns {4,8,25,81} +// +// If the selectors are not constant this will translate to a function +// call. 
+// +// Asm: VSHUFPS, CPU Feature: AVX +func (x Float32x4) SelectFromPair(a, b, c, d uint8, y Float32x4) Float32x4 { + pattern := (a&4)>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 // one "x or y?" bit per selector (0 == x, 1 == y); a is masked like b, c and d so pattern stays in 0-15 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstant(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstant(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstant(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstant(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstant(cscimm(a, a, b, b), x) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstant(cscimm(a, a, b, b), y) + return z.concatSelectedConstant(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstant(cscimm(c, c, d, d), y) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return x.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstant(cscimm(c, c, d, d), y) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstant(cscimm(c, c, d, d), x) + return y.concatSelectedConstant(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstant(cscimm(a, c, b, d), y) + return z.concatSelectedConstant(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstant(cscimm(b, d, a, c), y) + return z.concatSelectedConstant(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstant(cscimm(b, c, a, d), y) + return z.concatSelectedConstant(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := 
x.concatSelectedConstant(cscimm(a, d, b, c), y) + return z.concatSelectedConstant(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPairGrouped returns, for each of the two 128-bit halves of +// the vectors x and y, the selection of four elements from x and y, +// where selector values in the range 0-3 specify elements from x and +// values in the range 4-7 specify the 0-3 elements of y. +// When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. a is the source index of the least element in the +// output, and b, c, and d are the indices of the 2nd, 3rd, and 4th +// elements in the output. For example, +// {1,2,4,8,16,32,64,128}.SelectFromPairGrouped(2,3,5,7,{9,25,49,81,121,169,225,289}) +// +// returns {4,8,25,81,64,128,169,289} +// +// If the selectors are not constant this will translate to a function +// call. +// +// Asm: VSHUFPS, CPU Feature: AVX +func (x Int32x8) SelectFromPairGrouped(a, b, c, d uint8, y Int32x8) Int32x8 { + pattern := (a&4)>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 // one "x or y?" bit per selector (0 == x, 1 == y); a is masked like b, c and d so pattern stays in 0-15 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := 
x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPairGrouped returns, for each of the two 128-bit halves of +// the vectors x and y, the selection of four elements from x and y, +// where selector values in the range 0-3 specify elements from x and +// values in the range 4-7 specify the 0-3 elements of y. +// When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. a is the source index of the least element in the +// output, and b, c, and d are the indices of the 2nd, 3rd, and 4th +// elements in the output. 
For example, +// {1,2,4,8,16,32,64,128}.SelectFromPair(2,3,5,7,{9,25,49,81,121,169,225,289}) +// +// returns {4,8,25,81,64,128,169,289} +// +// If the selectors are not constant this will translate to a function +// call. +// +// Asm: VSHUFPS, CPU Feature: AVX +func (x Uint32x8) SelectFromPairGrouped(a, b, c, d uint8, y Uint32x8) Uint32x8 { + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return 
z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPairGrouped returns, for each of the two 128-bit halves of +// the vectors x and y, the selection of four elements from x and y, +// where selector values in the range 0-3 specify elements from x and +// values in the range 4-7 specify the 0-3 elements of y. +// When the selectors are constants and can be the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. a is the source index of the least element in the +// output, and b, c, and d are the indices of the 2nd, 3rd, and 4th +// elements in the output. For example, +// {1,2,4,8,16,32,64,128}.SelectFromPair(2,3,5,7,{9,25,49,81,121,169,225,289}) +// +// returns {4,8,25,81,64,128,169,289} +// +// If the selectors are not constant this will translate to a function +// call. 
+// +// Asm: VSHUFPS, CPU Feature: AVX +func (x Float32x8) SelectFromPairGrouped(a, b, c, d uint8, y Float32x8) Float32x8 { + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case 
_HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPairGrouped returns, for each of the four 128-bit subvectors +// of the vectors x and y, the selection of four elements from x and y, +// where selector values in the range 0-3 specify elements from x and +// values in the range 4-7 specify the 0-3 elements of y. +// When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. +// +// If the selectors are not constant this will translate to a function +// call. +// +// Asm: VSHUFPS, CPU Feature: AVX512 +func (x Int32x16) SelectFromPairGrouped(a, b, c, d uint8, y Int32x16) Int32x16 { + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), 
y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPairGrouped returns, for each of the four 128-bit subvectors +// of the vectors x and y, the selection of four elements from x and y, +// where selector values in the range 0-3 specify elements from x and +// values in the range 4-7 specify the 0-3 elements of y. +// When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. +// +// If the selectors are not constant this will translate to a function +// call. 
+// +// Asm: VSHUFPS, CPU Feature: AVX512 +func (x Uint32x16) SelectFromPairGrouped(a, b, c, d uint8, y Uint32x16) Uint32x16 { + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + 
case _HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// SelectFromPairGrouped returns, for each of the four 128-bit subvectors +// of the vectors x and y, the selection of four elements from x and y, +// where selector values in the range 0-3 specify elements from x and +// values in the range 4-7 specify the 0-3 elements of y. +// When the selectors are constants and the selection +// can be implemented in a single instruction, it will be, otherwise +// it requires two. +// +// If the selectors are not constant this will translate to a function +// call. +// +// Asm: VSHUFPS, CPU Feature: AVX512 +func (x Float32x16) SelectFromPairGrouped(a, b, c, d uint8, y Float32x16) Float32x16 { + pattern := a>>2 + (b&4)>>1 + (c & 4) + (d&4)<<1 + + a, b, c, d = a&3, b&3, c&3, d&3 + + switch pattern { + case _LLLL: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + case _HHHH: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _LLHH: + return x.concatSelectedConstantGrouped(cscimm(a, b, c, d), y) + case _HHLL: + return y.concatSelectedConstantGrouped(cscimm(a, b, c, d), x) + + case _HLLL: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + case _LHLL: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), x) + + case _HLHH: + z := y.concatSelectedConstantGrouped(cscimm(a, a, b, b), x) + return z.concatSelectedConstantGrouped(cscimm(0, 2, c, d), y) + case _LHHH: + z := x.concatSelectedConstantGrouped(cscimm(a, a, b, b), y) + return z.concatSelectedConstantGrouped(cscimm(0, 
2, c, d), y) + + case _LLLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _LLHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return x.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHLH: + z := x.concatSelectedConstantGrouped(cscimm(c, c, d, d), y) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + case _HHHL: + z := y.concatSelectedConstantGrouped(cscimm(c, c, d, d), x) + return y.concatSelectedConstantGrouped(cscimm(a, b, 0, 2), z) + + case _LHLH: + z := x.concatSelectedConstantGrouped(cscimm(a, c, b, d), y) + return z.concatSelectedConstantGrouped(0b11_01_10_00 /* =cscimm(0, 2, 1, 3) */, z) + case _HLHL: + z := x.concatSelectedConstantGrouped(cscimm(b, d, a, c), y) + return z.concatSelectedConstantGrouped(0b01_11_00_10 /* =cscimm(2, 0, 3, 1) */, z) + case _HLLH: + z := x.concatSelectedConstantGrouped(cscimm(b, c, a, d), y) + return z.concatSelectedConstantGrouped(0b11_01_00_10 /* =cscimm(2, 0, 1, 3) */, z) + case _LHHL: + z := x.concatSelectedConstantGrouped(cscimm(a, d, b, c), y) + return z.concatSelectedConstantGrouped(0b01_11_10_00 /* =cscimm(0, 2, 3, 1) */, z) + } + panic("missing case, switch should be exhaustive") +} + +// cscimm converts the 4 vector element indices into a single +// uint8 for use as an immediate: a occupies bits 0-1, b bits 2-3, c bits 4-5, and d bits 6-7 (each index is expected to be in the range 0-3). +func cscimm(a, b, c, d uint8) uint8 { + return uint8(a + b<<2 + c<<4 + d<<6) +}