[dev.simd] simd, cmd/compile: added .Masked() peephole opt for many operations.

This should get many of the low-hanging and important fruit.
Others can follow later.
It needs more testing.

Change-Id: Ic186b075987e85c87197ef9e1ca0b4f33ff96697
Reviewed-on: https://go-review.googlesource.com/c/go/+/697515
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Commit-Queue: David Chase <drchase@google.com>
TryBot-Bypass: David Chase <drchase@google.com>
This commit is contained in:
David Chase 2025-08-19 17:54:38 -04:00
parent 1334285862
commit cf31b15635
4 changed files with 2847 additions and 4 deletions

View file

@ -851,6 +851,15 @@
(ShiftAllLeftConcatUint64x2 ...) => (VPSHLDQ128 ...) (ShiftAllLeftConcatUint64x2 ...) => (VPSHLDQ128 ...)
(ShiftAllLeftConcatUint64x4 ...) => (VPSHLDQ256 ...) (ShiftAllLeftConcatUint64x4 ...) => (VPSHLDQ256 ...)
(ShiftAllLeftConcatUint64x8 ...) => (VPSHLDQ512 ...) (ShiftAllLeftConcatUint64x8 ...) => (VPSHLDQ512 ...)
(VPSLLWMasked128 x (MOVQconst [c]) mask) => (VPSLLWMasked128const [uint8(c)] x mask)
(VPSLLWMasked256 x (MOVQconst [c]) mask) => (VPSLLWMasked256const [uint8(c)] x mask)
(VPSLLWMasked512 x (MOVQconst [c]) mask) => (VPSLLWMasked512const [uint8(c)] x mask)
(VPSLLDMasked128 x (MOVQconst [c]) mask) => (VPSLLDMasked128const [uint8(c)] x mask)
(VPSLLDMasked256 x (MOVQconst [c]) mask) => (VPSLLDMasked256const [uint8(c)] x mask)
(VPSLLDMasked512 x (MOVQconst [c]) mask) => (VPSLLDMasked512const [uint8(c)] x mask)
(VPSLLQMasked128 x (MOVQconst [c]) mask) => (VPSLLQMasked128const [uint8(c)] x mask)
(VPSLLQMasked256 x (MOVQconst [c]) mask) => (VPSLLQMasked256const [uint8(c)] x mask)
(VPSLLQMasked512 x (MOVQconst [c]) mask) => (VPSLLQMasked512const [uint8(c)] x mask)
(ShiftAllRightInt16x8 ...) => (VPSRAW128 ...) (ShiftAllRightInt16x8 ...) => (VPSRAW128 ...)
(VPSRAW128 x (MOVQconst [c])) => (VPSRAW128const [uint8(c)] x) (VPSRAW128 x (MOVQconst [c])) => (VPSRAW128const [uint8(c)] x)
(ShiftAllRightInt16x16 ...) => (VPSRAW256 ...) (ShiftAllRightInt16x16 ...) => (VPSRAW256 ...)
@ -896,6 +905,15 @@
(ShiftAllRightConcatUint64x2 ...) => (VPSHRDQ128 ...) (ShiftAllRightConcatUint64x2 ...) => (VPSHRDQ128 ...)
(ShiftAllRightConcatUint64x4 ...) => (VPSHRDQ256 ...) (ShiftAllRightConcatUint64x4 ...) => (VPSHRDQ256 ...)
(ShiftAllRightConcatUint64x8 ...) => (VPSHRDQ512 ...) (ShiftAllRightConcatUint64x8 ...) => (VPSHRDQ512 ...)
(VPSRAWMasked128 x (MOVQconst [c]) mask) => (VPSRAWMasked128const [uint8(c)] x mask)
(VPSRAWMasked256 x (MOVQconst [c]) mask) => (VPSRAWMasked256const [uint8(c)] x mask)
(VPSRAWMasked512 x (MOVQconst [c]) mask) => (VPSRAWMasked512const [uint8(c)] x mask)
(VPSRADMasked128 x (MOVQconst [c]) mask) => (VPSRADMasked128const [uint8(c)] x mask)
(VPSRADMasked256 x (MOVQconst [c]) mask) => (VPSRADMasked256const [uint8(c)] x mask)
(VPSRADMasked512 x (MOVQconst [c]) mask) => (VPSRADMasked512const [uint8(c)] x mask)
(VPSRAQMasked128 x (MOVQconst [c]) mask) => (VPSRAQMasked128const [uint8(c)] x mask)
(VPSRAQMasked256 x (MOVQconst [c]) mask) => (VPSRAQMasked256const [uint8(c)] x mask)
(VPSRAQMasked512 x (MOVQconst [c]) mask) => (VPSRAQMasked512const [uint8(c)] x mask)
(ShiftLeftInt16x8 ...) => (VPSLLVW128 ...) (ShiftLeftInt16x8 ...) => (VPSLLVW128 ...)
(ShiftLeftInt16x16 ...) => (VPSLLVW256 ...) (ShiftLeftInt16x16 ...) => (VPSLLVW256 ...)
(ShiftLeftInt16x32 ...) => (VPSLLVW512 ...) (ShiftLeftInt16x32 ...) => (VPSLLVW512 ...)
@ -1086,3 +1104,166 @@
(moveMaskedUint16x32 x mask) => (VMOVDQU16Masked512 x (VPMOVVec16x32ToM <types.TypeMask> mask)) (moveMaskedUint16x32 x mask) => (VMOVDQU16Masked512 x (VPMOVVec16x32ToM <types.TypeMask> mask))
(moveMaskedUint32x16 x mask) => (VMOVDQU32Masked512 x (VPMOVVec32x16ToM <types.TypeMask> mask)) (moveMaskedUint32x16 x mask) => (VMOVDQU32Masked512 x (VPMOVVec32x16ToM <types.TypeMask> mask))
(moveMaskedUint64x8 x mask) => (VMOVDQU64Masked512 x (VPMOVVec64x8ToM <types.TypeMask> mask)) (moveMaskedUint64x8 x mask) => (VMOVDQU64Masked512 x (VPMOVVec64x8ToM <types.TypeMask> mask))
(VMOVDQU8Masked512 (VPABSB512 x) mask) => (VPABSBMasked512 x mask)
(VMOVDQU16Masked512 (VPABSW512 x) mask) => (VPABSWMasked512 x mask)
(VMOVDQU32Masked512 (VPABSD512 x) mask) => (VPABSDMasked512 x mask)
(VMOVDQU64Masked512 (VPABSQ512 x) mask) => (VPABSQMasked512 x mask)
(VMOVDQU32Masked512 (VPDPWSSD512 x y z) mask) => (VPDPWSSDMasked512 x y z mask)
(VMOVDQU32Masked512 (VPDPWSSDS512 x y z) mask) => (VPDPWSSDSMasked512 x y z mask)
(VMOVDQU32Masked512 (VPDPBUSD512 x y z) mask) => (VPDPBUSDMasked512 x y z mask)
(VMOVDQU32Masked512 (VPDPBUSDS512 x y z) mask) => (VPDPBUSDSMasked512 x y z mask)
(VMOVDQU32Masked512 (VADDPS512 x y) mask) => (VADDPSMasked512 x y mask)
(VMOVDQU64Masked512 (VADDPD512 x y) mask) => (VADDPDMasked512 x y mask)
(VMOVDQU8Masked512 (VPADDB512 x y) mask) => (VPADDBMasked512 x y mask)
(VMOVDQU16Masked512 (VPADDW512 x y) mask) => (VPADDWMasked512 x y mask)
(VMOVDQU32Masked512 (VPADDD512 x y) mask) => (VPADDDMasked512 x y mask)
(VMOVDQU64Masked512 (VPADDQ512 x y) mask) => (VPADDQMasked512 x y mask)
(VMOVDQU8Masked512 (VPADDSB512 x y) mask) => (VPADDSBMasked512 x y mask)
(VMOVDQU16Masked512 (VPADDSW512 x y) mask) => (VPADDSWMasked512 x y mask)
(VMOVDQU8Masked512 (VPADDUSB512 x y) mask) => (VPADDUSBMasked512 x y mask)
(VMOVDQU16Masked512 (VPADDUSW512 x y) mask) => (VPADDUSWMasked512 x y mask)
(VMOVDQU32Masked512 (VPANDD512 x y) mask) => (VPANDDMasked512 x y mask)
(VMOVDQU64Masked512 (VPANDQ512 x y) mask) => (VPANDQMasked512 x y mask)
(VMOVDQU32Masked512 (VPANDND512 x y) mask) => (VPANDNDMasked512 x y mask)
(VMOVDQU64Masked512 (VPANDNQ512 x y) mask) => (VPANDNQMasked512 x y mask)
(VMOVDQU8Masked512 (VPAVGB512 x y) mask) => (VPAVGBMasked512 x y mask)
(VMOVDQU16Masked512 (VPAVGW512 x y) mask) => (VPAVGWMasked512 x y mask)
(VMOVDQU32Masked512 (VBROADCASTSS512 x) mask) => (VBROADCASTSSMasked512 x mask)
(VMOVDQU64Masked512 (VBROADCASTSD512 x) mask) => (VBROADCASTSDMasked512 x mask)
(VMOVDQU8Masked512 (VPBROADCASTB512 x) mask) => (VPBROADCASTBMasked512 x mask)
(VMOVDQU16Masked512 (VPBROADCASTW512 x) mask) => (VPBROADCASTWMasked512 x mask)
(VMOVDQU32Masked512 (VPBROADCASTD512 x) mask) => (VPBROADCASTDMasked512 x mask)
(VMOVDQU64Masked512 (VPBROADCASTQ512 x) mask) => (VPBROADCASTQMasked512 x mask)
(VMOVDQU32Masked512 (VRNDSCALEPS512 [a] x) mask) => (VRNDSCALEPSMasked512 [a] x mask)
(VMOVDQU64Masked512 (VRNDSCALEPD512 [a] x) mask) => (VRNDSCALEPDMasked512 [a] x mask)
(VMOVDQU32Masked512 (VREDUCEPS512 [a] x) mask) => (VREDUCEPSMasked512 [a] x mask)
(VMOVDQU64Masked512 (VREDUCEPD512 [a] x) mask) => (VREDUCEPDMasked512 [a] x mask)
(VMOVDQU32Masked512 (VCVTTPS2DQ512 x) mask) => (VCVTTPS2DQMasked512 x mask)
(VMOVDQU8Masked512 (VPMOVZXBW512 x) mask) => (VPMOVZXBWMasked512 x mask)
(VMOVDQU32Masked512 (VCVTPS2UDQ512 x) mask) => (VCVTPS2UDQMasked512 x mask)
(VMOVDQU16Masked512 (VPMOVZXWD512 x) mask) => (VPMOVZXWDMasked512 x mask)
(VMOVDQU32Masked512 (VDIVPS512 x y) mask) => (VDIVPSMasked512 x y mask)
(VMOVDQU64Masked512 (VDIVPD512 x y) mask) => (VDIVPDMasked512 x y mask)
(VMOVDQU16Masked512 (VPMADDWD512 x y) mask) => (VPMADDWDMasked512 x y mask)
(VMOVDQU16Masked512 (VPMADDUBSW512 x y) mask) => (VPMADDUBSWMasked512 x y mask)
(VMOVDQU8Masked512 (VGF2P8AFFINEINVQB512 [a] x y) mask) => (VGF2P8AFFINEINVQBMasked512 [a] x y mask)
(VMOVDQU8Masked512 (VGF2P8AFFINEQB512 [a] x y) mask) => (VGF2P8AFFINEQBMasked512 [a] x y mask)
(VMOVDQU8Masked512 (VGF2P8MULB512 x y) mask) => (VGF2P8MULBMasked512 x y mask)
(VMOVDQU32Masked512 (VMAXPS512 x y) mask) => (VMAXPSMasked512 x y mask)
(VMOVDQU64Masked512 (VMAXPD512 x y) mask) => (VMAXPDMasked512 x y mask)
(VMOVDQU8Masked512 (VPMAXSB512 x y) mask) => (VPMAXSBMasked512 x y mask)
(VMOVDQU16Masked512 (VPMAXSW512 x y) mask) => (VPMAXSWMasked512 x y mask)
(VMOVDQU32Masked512 (VPMAXSD512 x y) mask) => (VPMAXSDMasked512 x y mask)
(VMOVDQU64Masked512 (VPMAXSQ512 x y) mask) => (VPMAXSQMasked512 x y mask)
(VMOVDQU8Masked512 (VPMAXUB512 x y) mask) => (VPMAXUBMasked512 x y mask)
(VMOVDQU16Masked512 (VPMAXUW512 x y) mask) => (VPMAXUWMasked512 x y mask)
(VMOVDQU32Masked512 (VPMAXUD512 x y) mask) => (VPMAXUDMasked512 x y mask)
(VMOVDQU64Masked512 (VPMAXUQ512 x y) mask) => (VPMAXUQMasked512 x y mask)
(VMOVDQU32Masked512 (VMINPS512 x y) mask) => (VMINPSMasked512 x y mask)
(VMOVDQU64Masked512 (VMINPD512 x y) mask) => (VMINPDMasked512 x y mask)
(VMOVDQU8Masked512 (VPMINSB512 x y) mask) => (VPMINSBMasked512 x y mask)
(VMOVDQU16Masked512 (VPMINSW512 x y) mask) => (VPMINSWMasked512 x y mask)
(VMOVDQU32Masked512 (VPMINSD512 x y) mask) => (VPMINSDMasked512 x y mask)
(VMOVDQU64Masked512 (VPMINSQ512 x y) mask) => (VPMINSQMasked512 x y mask)
(VMOVDQU8Masked512 (VPMINUB512 x y) mask) => (VPMINUBMasked512 x y mask)
(VMOVDQU16Masked512 (VPMINUW512 x y) mask) => (VPMINUWMasked512 x y mask)
(VMOVDQU32Masked512 (VPMINUD512 x y) mask) => (VPMINUDMasked512 x y mask)
(VMOVDQU64Masked512 (VPMINUQ512 x y) mask) => (VPMINUQMasked512 x y mask)
(VMOVDQU32Masked512 (VFMADD213PS512 x y z) mask) => (VFMADD213PSMasked512 x y z mask)
(VMOVDQU64Masked512 (VFMADD213PD512 x y z) mask) => (VFMADD213PDMasked512 x y z mask)
(VMOVDQU32Masked512 (VFMADDSUB213PS512 x y z) mask) => (VFMADDSUB213PSMasked512 x y z mask)
(VMOVDQU64Masked512 (VFMADDSUB213PD512 x y z) mask) => (VFMADDSUB213PDMasked512 x y z mask)
(VMOVDQU16Masked512 (VPMULHW512 x y) mask) => (VPMULHWMasked512 x y mask)
(VMOVDQU16Masked512 (VPMULHUW512 x y) mask) => (VPMULHUWMasked512 x y mask)
(VMOVDQU32Masked512 (VMULPS512 x y) mask) => (VMULPSMasked512 x y mask)
(VMOVDQU64Masked512 (VMULPD512 x y) mask) => (VMULPDMasked512 x y mask)
(VMOVDQU16Masked512 (VPMULLW512 x y) mask) => (VPMULLWMasked512 x y mask)
(VMOVDQU32Masked512 (VPMULLD512 x y) mask) => (VPMULLDMasked512 x y mask)
(VMOVDQU64Masked512 (VPMULLQ512 x y) mask) => (VPMULLQMasked512 x y mask)
(VMOVDQU32Masked512 (VFMSUBADD213PS512 x y z) mask) => (VFMSUBADD213PSMasked512 x y z mask)
(VMOVDQU64Masked512 (VFMSUBADD213PD512 x y z) mask) => (VFMSUBADD213PDMasked512 x y z mask)
(VMOVDQU8Masked512 (VPOPCNTB512 x) mask) => (VPOPCNTBMasked512 x mask)
(VMOVDQU16Masked512 (VPOPCNTW512 x) mask) => (VPOPCNTWMasked512 x mask)
(VMOVDQU32Masked512 (VPOPCNTD512 x) mask) => (VPOPCNTDMasked512 x mask)
(VMOVDQU64Masked512 (VPOPCNTQ512 x) mask) => (VPOPCNTQMasked512 x mask)
(VMOVDQU32Masked512 (VPORD512 x y) mask) => (VPORDMasked512 x y mask)
(VMOVDQU64Masked512 (VPORQ512 x y) mask) => (VPORQMasked512 x y mask)
(VMOVDQU8Masked512 (VPERMI2B512 x y z) mask) => (VPERMI2BMasked512 x y z mask)
(VMOVDQU16Masked512 (VPERMI2W512 x y z) mask) => (VPERMI2WMasked512 x y z mask)
(VMOVDQU32Masked512 (VPERMI2PS512 x y z) mask) => (VPERMI2PSMasked512 x y z mask)
(VMOVDQU32Masked512 (VPERMI2D512 x y z) mask) => (VPERMI2DMasked512 x y z mask)
(VMOVDQU64Masked512 (VPERMI2PD512 x y z) mask) => (VPERMI2PDMasked512 x y z mask)
(VMOVDQU64Masked512 (VPERMI2Q512 x y z) mask) => (VPERMI2QMasked512 x y z mask)
(VMOVDQU8Masked512 (VPERMB512 x y) mask) => (VPERMBMasked512 x y mask)
(VMOVDQU16Masked512 (VPERMW512 x y) mask) => (VPERMWMasked512 x y mask)
(VMOVDQU32Masked512 (VPERMPS512 x y) mask) => (VPERMPSMasked512 x y mask)
(VMOVDQU32Masked512 (VPERMD512 x y) mask) => (VPERMDMasked512 x y mask)
(VMOVDQU64Masked512 (VPERMPD512 x y) mask) => (VPERMPDMasked512 x y mask)
(VMOVDQU64Masked512 (VPERMQ512 x y) mask) => (VPERMQMasked512 x y mask)
(VMOVDQU32Masked512 (VRCP14PS512 x) mask) => (VRCP14PSMasked512 x mask)
(VMOVDQU64Masked512 (VRCP14PD512 x) mask) => (VRCP14PDMasked512 x mask)
(VMOVDQU32Masked512 (VRSQRT14PS512 x) mask) => (VRSQRT14PSMasked512 x mask)
(VMOVDQU64Masked512 (VRSQRT14PD512 x) mask) => (VRSQRT14PDMasked512 x mask)
(VMOVDQU32Masked512 (VPROLD512 [a] x) mask) => (VPROLDMasked512 [a] x mask)
(VMOVDQU64Masked512 (VPROLQ512 [a] x) mask) => (VPROLQMasked512 [a] x mask)
(VMOVDQU32Masked512 (VPRORD512 [a] x) mask) => (VPRORDMasked512 [a] x mask)
(VMOVDQU64Masked512 (VPRORQ512 [a] x) mask) => (VPRORQMasked512 [a] x mask)
(VMOVDQU32Masked512 (VPROLVD512 x y) mask) => (VPROLVDMasked512 x y mask)
(VMOVDQU64Masked512 (VPROLVQ512 x y) mask) => (VPROLVQMasked512 x y mask)
(VMOVDQU32Masked512 (VPRORVD512 x y) mask) => (VPRORVDMasked512 x y mask)
(VMOVDQU64Masked512 (VPRORVQ512 x y) mask) => (VPRORVQMasked512 x y mask)
(VMOVDQU32Masked512 (VSCALEFPS512 x y) mask) => (VSCALEFPSMasked512 x y mask)
(VMOVDQU64Masked512 (VSCALEFPD512 x y) mask) => (VSCALEFPDMasked512 x y mask)
(VMOVDQU16Masked512 (VPSHLDW512 [a] x y) mask) => (VPSHLDWMasked512 [a] x y mask)
(VMOVDQU32Masked512 (VPSHLDD512 [a] x y) mask) => (VPSHLDDMasked512 [a] x y mask)
(VMOVDQU64Masked512 (VPSHLDQ512 [a] x y) mask) => (VPSHLDQMasked512 [a] x y mask)
(VMOVDQU16Masked512 (VPSLLW512 x y) mask) => (VPSLLWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSLLD512 x y) mask) => (VPSLLDMasked512 x y mask)
(VMOVDQU64Masked512 (VPSLLQ512 x y) mask) => (VPSLLQMasked512 x y mask)
(VMOVDQU16Masked512 (VPSHRDW512 [a] x y) mask) => (VPSHRDWMasked512 [a] x y mask)
(VMOVDQU32Masked512 (VPSHRDD512 [a] x y) mask) => (VPSHRDDMasked512 [a] x y mask)
(VMOVDQU64Masked512 (VPSHRDQ512 [a] x y) mask) => (VPSHRDQMasked512 [a] x y mask)
(VMOVDQU16Masked512 (VPSRAW512 x y) mask) => (VPSRAWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSRAD512 x y) mask) => (VPSRADMasked512 x y mask)
(VMOVDQU64Masked512 (VPSRAQ512 x y) mask) => (VPSRAQMasked512 x y mask)
(VMOVDQU16Masked512 (VPSRLW512 x y) mask) => (VPSRLWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSRLD512 x y) mask) => (VPSRLDMasked512 x y mask)
(VMOVDQU64Masked512 (VPSRLQ512 x y) mask) => (VPSRLQMasked512 x y mask)
(VMOVDQU16Masked512 (VPSHLDVW512 x y z) mask) => (VPSHLDVWMasked512 x y z mask)
(VMOVDQU32Masked512 (VPSHLDVD512 x y z) mask) => (VPSHLDVDMasked512 x y z mask)
(VMOVDQU64Masked512 (VPSHLDVQ512 x y z) mask) => (VPSHLDVQMasked512 x y z mask)
(VMOVDQU16Masked512 (VPSLLVW512 x y) mask) => (VPSLLVWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSLLVD512 x y) mask) => (VPSLLVDMasked512 x y mask)
(VMOVDQU64Masked512 (VPSLLVQ512 x y) mask) => (VPSLLVQMasked512 x y mask)
(VMOVDQU16Masked512 (VPSHRDVW512 x y z) mask) => (VPSHRDVWMasked512 x y z mask)
(VMOVDQU32Masked512 (VPSHRDVD512 x y z) mask) => (VPSHRDVDMasked512 x y z mask)
(VMOVDQU64Masked512 (VPSHRDVQ512 x y z) mask) => (VPSHRDVQMasked512 x y z mask)
(VMOVDQU16Masked512 (VPSRAVW512 x y) mask) => (VPSRAVWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSRAVD512 x y) mask) => (VPSRAVDMasked512 x y mask)
(VMOVDQU64Masked512 (VPSRAVQ512 x y) mask) => (VPSRAVQMasked512 x y mask)
(VMOVDQU16Masked512 (VPSRLVW512 x y) mask) => (VPSRLVWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSRLVD512 x y) mask) => (VPSRLVDMasked512 x y mask)
(VMOVDQU64Masked512 (VPSRLVQ512 x y) mask) => (VPSRLVQMasked512 x y mask)
(VMOVDQU32Masked512 (VSQRTPS512 x) mask) => (VSQRTPSMasked512 x mask)
(VMOVDQU64Masked512 (VSQRTPD512 x) mask) => (VSQRTPDMasked512 x mask)
(VMOVDQU32Masked512 (VSUBPS512 x y) mask) => (VSUBPSMasked512 x y mask)
(VMOVDQU64Masked512 (VSUBPD512 x y) mask) => (VSUBPDMasked512 x y mask)
(VMOVDQU8Masked512 (VPSUBB512 x y) mask) => (VPSUBBMasked512 x y mask)
(VMOVDQU16Masked512 (VPSUBW512 x y) mask) => (VPSUBWMasked512 x y mask)
(VMOVDQU32Masked512 (VPSUBD512 x y) mask) => (VPSUBDMasked512 x y mask)
(VMOVDQU64Masked512 (VPSUBQ512 x y) mask) => (VPSUBQMasked512 x y mask)
(VMOVDQU8Masked512 (VPSUBSB512 x y) mask) => (VPSUBSBMasked512 x y mask)
(VMOVDQU16Masked512 (VPSUBSW512 x y) mask) => (VPSUBSWMasked512 x y mask)
(VMOVDQU8Masked512 (VPSUBUSB512 x y) mask) => (VPSUBUSBMasked512 x y mask)
(VMOVDQU16Masked512 (VPSUBUSW512 x y) mask) => (VPSUBUSWMasked512 x y mask)
(VMOVDQU32Masked512 (VPXORD512 x y) mask) => (VPXORDMasked512 x y mask)
(VMOVDQU64Masked512 (VPXORQ512 x y) mask) => (VPXORQMasked512 x y mask)
(VMOVDQU16Masked512 (VPSLLW512const [a] x) mask) => (VPSLLWMasked512const [a] x mask)
(VMOVDQU32Masked512 (VPSLLD512const [a] x) mask) => (VPSLLDMasked512const [a] x mask)
(VMOVDQU64Masked512 (VPSLLQ512const [a] x) mask) => (VPSLLQMasked512const [a] x mask)
(VMOVDQU16Masked512 (VPSRLW512const [a] x) mask) => (VPSRLWMasked512const [a] x mask)
(VMOVDQU32Masked512 (VPSRLD512const [a] x) mask) => (VPSRLDMasked512const [a] x mask)
(VMOVDQU64Masked512 (VPSRLQ512const [a] x) mask) => (VPSRLQMasked512const [a] x mask)
(VMOVDQU16Masked512 (VPSRAW512const [a] x) mask) => (VPSRAWMasked512const [a] x mask)
(VMOVDQU32Masked512 (VPSRAD512const [a] x) mask) => (VPSRADMasked512const [a] x mask)
(VMOVDQU64Masked512 (VPSRAQ512const [a] x) mask) => (VPSRAQMasked512const [a] x mask)

File diff suppressed because it is too large Load diff

View file

@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"slices" "slices"
"strings"
"text/template" "text/template"
) )
@ -20,6 +21,7 @@ type tplRuleData struct {
ArgsOut string // e.g. "x y" ArgsOut string // e.g. "x y"
MaskInConvert string // e.g. "VPMOVVec32x8ToM" MaskInConvert string // e.g. "VPMOVVec32x8ToM"
MaskOutConvert string // e.g. "VPMOVMToVec32x8" MaskOutConvert string // e.g. "VPMOVMToVec32x8"
ElementSize int // e.g. 32
} }
var ( var (
@ -39,6 +41,42 @@ var (
`)) `))
) )
func (d tplRuleData) MaskOptimization() string {
asmNoMask := d.Asm
if i := strings.Index(asmNoMask, "Masked"); i == -1 {
return ""
}
asmNoMask = strings.ReplaceAll(asmNoMask, "Masked", "")
for _, nope := range []string{"VMOVDQU", "VPCOMPRESS", "VCOMPRESS", "VPEXPAND", "VEXPAND", "VPBLENDM", "VMOVUP"} {
if strings.HasPrefix(asmNoMask, nope) {
return ""
}
}
size := asmNoMask[len(asmNoMask)-3:]
if strings.HasSuffix(asmNoMask, "const") {
sufLen := len("128const")
size = asmNoMask[len(asmNoMask)-sufLen:][:3]
}
switch size {
case "128", "256":
// TODO don't handle these yet because they will require a feature guard check in rewrite
return ""
case "512":
default:
panic("Unexpected operation size on " + d.Asm)
}
switch d.ElementSize {
case 8, 16, 32, 64:
default:
panic(fmt.Errorf("Unexpected operation width %d on %v", d.ElementSize, d.Asm))
}
return fmt.Sprintf("(VMOVDQU%dMasked512 (%s %s) mask) => (%s %s mask)\n", d.ElementSize, asmNoMask, d.Args, d.Asm, d.Args)
}
// SSA rewrite rules need to appear in a most-to-least-specific order. This works for that. // SSA rewrite rules need to appear in a most-to-least-specific order. This works for that.
var tmplOrder = map[string]int{ var tmplOrder = map[string]int{
"masksftimm": 0, "masksftimm": 0,
@ -80,11 +118,9 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
buffer.WriteString(generatedHeader + "\n") buffer.WriteString(generatedHeader + "\n")
var allData []tplRuleData var allData []tplRuleData
var optData []tplRuleData // for peephole optimizations
for _, opr := range ops { for _, opr := range ops {
if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" {
continue
}
opInShape, opOutShape, maskType, immType, gOp := opr.shape() opInShape, opOutShape, maskType, immType, gOp := opr.shape()
asm := machineOpName(maskType, gOp) asm := machineOpName(maskType, gOp)
vregInCnt := len(gOp.In) vregInCnt := len(gOp.In)
@ -146,7 +182,9 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
data.GoType = goType(gOp) data.GoType = goType(gOp)
rearIdx := len(gOp.In) - 1 rearIdx := len(gOp.In) - 1
// Mask is at the end. // Mask is at the end.
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes) width := *gOp.In[rearIdx].ElemBits
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", width, *gOp.In[rearIdx].Lanes)
data.ElementSize = width
case PureKmaskIn: case PureKmaskIn:
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")) panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
} }
@ -196,6 +234,10 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
data.ArgsOut = "..." data.ArgsOut = "..."
} }
data.tplName = tplName data.tplName = tplName
if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" {
optData = append(optData, data)
continue
}
allData = append(allData, data) allData = append(allData, data)
} }
@ -207,5 +249,18 @@ func writeSIMDRules(ops []Operation) *bytes.Buffer {
} }
} }
seen := make(map[string]bool)
for _, data := range optData {
if data.tplName == "maskIn" {
rule := data.MaskOptimization()
if seen[rule] {
continue
}
seen[rule] = true
buffer.WriteString(rule)
}
}
return buffer return buffer
} }

View file

@ -445,3 +445,36 @@ func TestBroadcastFloat32x8(t *testing.T) {
simd.BroadcastFloat32x8(123456789).StoreSlice(s) simd.BroadcastFloat32x8(123456789).StoreSlice(s)
checkSlices(t, s, []float32{123456789, 123456789, 123456789, 123456789, 123456789, 123456789, 123456789, 123456789}) checkSlices(t, s, []float32{123456789, 123456789, 123456789, 123456789, 123456789, 123456789, 123456789, 123456789})
} }
func TestBroadcastFloat64x2(t *testing.T) {
s := make([]float64, 2, 2)
simd.BroadcastFloat64x2(123456789).StoreSlice(s)
checkSlices(t, s, []float64{123456789, 123456789})
}
func TestBroadcastUint64x2(t *testing.T) {
s := make([]uint64, 2, 2)
simd.BroadcastUint64x2(123456789).StoreSlice(s)
checkSlices(t, s, []uint64{123456789, 123456789})
}
func TestMaskOpt512(t *testing.T) {
if !simd.HasAVX512() {
t.Skip("Test requires HasAVX512, not available on this hardware")
return
}
k := make([]int64, 8, 8)
s := make([]float64, 8, 8)
a := simd.LoadFloat64x8Slice([]float64{2, 0, 2, 0, 2, 0, 2, 0})
b := simd.LoadFloat64x8Slice([]float64{1, 1, 1, 1, 1, 1, 1, 1})
c := simd.LoadFloat64x8Slice([]float64{1, 2, 3, 4, 5, 6, 7, 8})
d := simd.LoadFloat64x8Slice([]float64{2, 4, 6, 8, 10, 12, 14, 16})
g := a.Greater(b)
e := c.Add(d).Masked(g)
e.StoreSlice(s)
g.AsInt64x8().StoreSlice(k)
checkSlices[int64](t, k, []int64{-1, 0, -1, 0, -1, 0, -1, 0})
checkSlices[float64](t, s, []float64{3, 0, 9, 0, 15, 0, 21, 0})
}