runtime: save AVX2 and AVX-512 state on asynchronous preemption

Based on CL 669415 by shaojunyang@google.com.

This is a cherry-pick of CL 680900 from the dev.simd branch.

Change-Id: I574f15c3b18a7179a1573aaf567caf18d8602ef1
Reviewed-on: https://go-review.googlesource.com/c/go/+/693397
Reviewed-by: Cherry Mui <cherryyz@google.com>
Auto-Submit: Austin Clements <austin@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Austin Clements <austin@google.com>, 2025-06-12 15:33:41 -04:00, committed by Gopher Robot
parent af0c4fe2ca
commit 55d961b202
4 changed files with 229 additions and 56 deletions
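Before the per-file diffs, a sketch of the behavior this CL adds: on asynchronous preemption, the runtime now saves and restores the widest vector register state the CPU supports, instead of always spilling only the 16-byte X registers. The Go program below is illustrative only; the real dispatch is written in assembly against the unexported internal/cpu package, so golang.org/x/sys/cpu stands in for it here, and widestState is an invented helper, not runtime API.

// Illustrative sketch only: mirrors the CL's dispatch priority in plain Go.
// golang.org/x/sys/cpu stands in for the runtime's internal/cpu package.
package main

import (
	"fmt"

	"golang.org/x/sys/cpu"
)

// widestState reports which register file an asynchronous preemption
// would save on this CPU, in the same priority order as the CL:
// AVX-512 (Z and K registers), then AVX2 (Y), then bare SSE (X).
func widestState() string {
	x := cpu.X86
	switch {
	case x.HasAVX512F && x.HasAVX512CD && x.HasAVX512BW && x.HasAVX512DQ && x.HasAVX512VL:
		return "AVX-512: Z0-Z15 (64 bytes each) + K0-K7 (8 bytes each)"
	case x.HasAVX2:
		return "AVX2: Y0-Y15 (32 bytes each)"
	default:
		return "SSE: X0-X15 (16 bytes each)"
	}
}

func main() {
	fmt.Println("state to save on async preemption:", widestState())
}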

src/runtime/cpuflags.go

@@ -13,6 +13,7 @@ import (
 const (
 	offsetX86HasAVX    = unsafe.Offsetof(cpu.X86.HasAVX)
 	offsetX86HasAVX2   = unsafe.Offsetof(cpu.X86.HasAVX2)
+	offsetX86HasAVX512 = unsafe.Offsetof(cpu.X86.HasAVX512) // F+CD+BW+DQ+VL
 	offsetX86HasERMS   = unsafe.Offsetof(cpu.X86.HasERMS)
 	offsetX86HasRDTSCP = unsafe.Offsetof(cpu.X86.HasRDTSCP)
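This hunk adds only a constant, but it is the hook the assembly needs: package-level constants are exported to that package's assembly files through the generated go_asm.h header as const_<name>, which is how preempt_amd64.s below can test the flag with CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1. A minimal sketch of the pattern, using an invented stand-in struct rather than the real internal/cpu one:

// Sketch of the offset-constant pattern with an invented stand-in struct;
// the real flags live in internal/cpu.X86. unsafe.Offsetof is evaluated at
// compile time, so the results are ordinary Go constants.
package main

import (
	"fmt"
	"unsafe"
)

type x86Features struct {
	HasAVX    bool
	HasAVX2   bool
	HasAVX512 bool // F+CD+BW+DQ+VL folded into one flag, as in the CL
}

const (
	offsetX86HasAVX    = unsafe.Offsetof(x86Features{}.HasAVX)
	offsetX86HasAVX2   = unsafe.Offsetof(x86Features{}.HasAVX2)
	offsetX86HasAVX512 = unsafe.Offsetof(x86Features{}.HasAVX512)
)

func main() {
	// With this layout the three flags sit at byte offsets 0, 1, and 2.
	fmt.Println(offsetX86HasAVX, offsetX86HasAVX2, offsetX86HasAVX512)
}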

src/runtime/mkpreempt.go

@@ -285,7 +285,7 @@ func gen386(g *gen) {
 func genAMD64(g *gen) {
 	const xReg = "AX" // *xRegState
 
-	p := g.p
+	p, label := g.p, g.label
 
 	// Assign stack offsets.
 	var l = layout{sp: "SP"}
@@ -297,15 +297,33 @@ func genAMD64(g *gen) {
 			l.add("MOVQ", reg, 8)
 		}
 	}
-	lXRegs := layout{sp: xReg} // Non-GP registers
-	for _, reg := range regNamesAMD64 {
-		if strings.HasPrefix(reg, "X") {
-			lXRegs.add("MOVUPS", reg, 16)
+	// Create layouts for X, Y, and Z registers.
+	const (
+		numXRegs = 16
+		numZRegs = 16 // TODO: If we start using upper registers, change to 32
+		numKRegs = 8
+	)
+	lZRegs := layout{sp: xReg} // Non-GP registers
+	lXRegs, lYRegs := lZRegs, lZRegs
+	for i := range numZRegs {
+		lZRegs.add("VMOVDQU64", fmt.Sprintf("Z%d", i), 512/8)
+		if i < numXRegs {
+			// Use SSE-only instructions for X registers.
+			lXRegs.add("MOVUPS", fmt.Sprintf("X%d", i), 128/8)
+			lYRegs.add("VMOVDQU", fmt.Sprintf("Y%d", i), 256/8)
 		}
 	}
-	writeXRegs(g.goarch, &lXRegs)
-
-	// TODO: MXCSR register?
+	for i := range numKRegs {
+		lZRegs.add("KMOVQ", fmt.Sprintf("K%d", i), 8)
+	}
+	// The Z layout is the most general, so we line up the others with that one.
+	// We don't have to do this, but it results in a nice Go type. If we split
+	// this into multiple types, we probably should stop doing this.
+	for i := range lXRegs.regs {
+		lXRegs.regs[i].pos = lZRegs.regs[i].pos
+		lYRegs.regs[i].pos = lZRegs.regs[i].pos
+	}
+	writeXRegs(g.goarch, &lZRegs)
 
 	p("PUSHQ BP")
 	p("MOVQ SP, BP")
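The layout type's definition is outside this diff. A simplified stand-in (the type and field names here are guesses, not the real ones) shows what add does and why the K registers land at byte offset 1024 in the Z layout: 16 Z slots of 64 bytes each come first, so the 8 mask slots follow at 1024 and the whole area is 1088 bytes, the same offsets that appear in the generated assembly below.

// Simplified stand-in for mkpreempt.go's layout type: add records a
// register with its move instruction at the current offset, then advances
// the running offset by the register's size.
package main

import "fmt"

type regPos struct {
	reg string
	op  string
	pos int
}

type layout struct {
	regs  []regPos
	stack int
}

func (l *layout) add(op, reg string, size int) {
	l.regs = append(l.regs, regPos{reg: reg, op: op, pos: l.stack})
	l.stack += size
}

func main() {
	var lZ layout
	for i := 0; i < 16; i++ {
		lZ.add("VMOVDQU64", fmt.Sprintf("Z%d", i), 512/8)
	}
	for i := 0; i < 8; i++ {
		lZ.add("KMOVQ", fmt.Sprintf("K%d", i), 8)
	}
	fmt.Println("K0 at offset", lZ.regs[16].pos) // 1024
	fmt.Println("total size", lZ.stack)          // 1088
}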
@@ -333,16 +351,56 @@ func genAMD64(g *gen) {
 	p("MOVQ g_m(R14), %s", xReg)
 	p("MOVQ m_p(%s), %s", xReg, xReg)
 	p("LEAQ (p_xRegs+xRegPerP_scratch)(%s), %s", xReg, xReg)
-	lXRegs.save(g)
+	// Which registers do we need to save?
+	p("#ifdef GOEXPERIMENT_simd")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1")
+	p("JE saveAVX512")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1")
+	p("JE saveAVX2")
+	p("#endif")
+	// No features. Assume only SSE.
+	label("saveSSE:")
+	lXRegs.save(g)
+	p("JMP preempt")
+
+	label("saveAVX2:")
+	lYRegs.save(g)
+	p("JMP preempt")
+
+	label("saveAVX512:")
+	lZRegs.save(g)
+	p("JMP preempt")
+
+	label("preempt:")
 
 	p("CALL ·asyncPreempt2(SB)")
 
 	p("// Restore non-GPs from *p.xRegs.cache")
 	p("MOVQ g_m(R14), %s", xReg)
 	p("MOVQ m_p(%s), %s", xReg, xReg)
 	p("MOVQ (p_xRegs+xRegPerP_cache)(%s), %s", xReg, xReg)
-	lXRegs.restore(g)
+	p("#ifdef GOEXPERIMENT_simd")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1")
+	p("JE restoreAVX512")
+	p("CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1")
+	p("JE restoreAVX2")
+	p("#endif")
+	label("restoreSSE:")
+	lXRegs.restore(g)
+	p("JMP restoreGPs")
+
+	label("restoreAVX2:")
+	lYRegs.restore(g)
+	p("JMP restoreGPs")
+
+	label("restoreAVX512:")
+	lZRegs.restore(g)
+	p("JMP restoreGPs")
+
+	label("restoreGPs:")
 
 	p("// Restore GPs")
 	l.restore(g)
 	p("ADJSP $%d", -l.stack)

src/runtime/preempt_amd64.go

@@ -3,20 +3,28 @@
 package runtime
 
 type xRegs struct {
-	X0  [16]byte
-	X1  [16]byte
-	X2  [16]byte
-	X3  [16]byte
-	X4  [16]byte
-	X5  [16]byte
-	X6  [16]byte
-	X7  [16]byte
-	X8  [16]byte
-	X9  [16]byte
-	X10 [16]byte
-	X11 [16]byte
-	X12 [16]byte
-	X13 [16]byte
-	X14 [16]byte
-	X15 [16]byte
+	Z0  [64]byte
+	Z1  [64]byte
+	Z2  [64]byte
+	Z3  [64]byte
+	Z4  [64]byte
+	Z5  [64]byte
+	Z6  [64]byte
+	Z7  [64]byte
+	Z8  [64]byte
+	Z9  [64]byte
+	Z10 [64]byte
+	Z11 [64]byte
+	Z12 [64]byte
+	Z13 [64]byte
+	Z14 [64]byte
+	Z15 [64]byte
+	K0  uint64
+	K1  uint64
+	K2  uint64
+	K3  uint64
+	K4  uint64
+	K5  uint64
+	K6  uint64
+	K7  uint64
 }
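The generated type grows from 256 bytes (16 X slots of 16 bytes) to 1088 bytes: 16 Z slots of 64 bytes plus 8 K slots of 8 bytes. A quick size check with an equivalently laid out stand-in (arrays instead of the generated flat fields):

// Stand-in with the same size as the generated runtime type. K0 starts at
// offset 16*64 = 1024, matching the KMOVQ offsets in preempt_amd64.s below.
package main

import (
	"fmt"
	"unsafe"
)

type xRegsSize struct {
	Z [16][64]byte
	K [8]uint64
}

func main() {
	fmt.Println(unsafe.Sizeof(xRegsSize{}))     // 1088
	fmt.Println(unsafe.Offsetof(xRegsSize{}.K)) // 1024
}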

src/runtime/preempt_amd64.s

@@ -36,43 +36,149 @@ TEXT ·asyncPreempt(SB),NOSPLIT|NOFRAME,$0-0
 	MOVQ g_m(R14), AX
 	MOVQ m_p(AX), AX
 	LEAQ (p_xRegs+xRegPerP_scratch)(AX), AX
+#ifdef GOEXPERIMENT_simd
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1
+	JE saveAVX512
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
+	JE saveAVX2
+#endif
+saveSSE:
 	MOVUPS X0, 0(AX)
-	MOVUPS X1, 16(AX)
-	MOVUPS X2, 32(AX)
-	MOVUPS X3, 48(AX)
-	MOVUPS X4, 64(AX)
-	MOVUPS X5, 80(AX)
-	MOVUPS X6, 96(AX)
-	MOVUPS X7, 112(AX)
-	MOVUPS X8, 128(AX)
-	MOVUPS X9, 144(AX)
-	MOVUPS X10, 160(AX)
-	MOVUPS X11, 176(AX)
-	MOVUPS X12, 192(AX)
-	MOVUPS X13, 208(AX)
-	MOVUPS X14, 224(AX)
-	MOVUPS X15, 240(AX)
+	MOVUPS X1, 64(AX)
+	MOVUPS X2, 128(AX)
+	MOVUPS X3, 192(AX)
+	MOVUPS X4, 256(AX)
+	MOVUPS X5, 320(AX)
+	MOVUPS X6, 384(AX)
+	MOVUPS X7, 448(AX)
+	MOVUPS X8, 512(AX)
+	MOVUPS X9, 576(AX)
+	MOVUPS X10, 640(AX)
+	MOVUPS X11, 704(AX)
+	MOVUPS X12, 768(AX)
+	MOVUPS X13, 832(AX)
+	MOVUPS X14, 896(AX)
+	MOVUPS X15, 960(AX)
+	JMP preempt
+saveAVX2:
+	VMOVDQU Y0, 0(AX)
+	VMOVDQU Y1, 64(AX)
+	VMOVDQU Y2, 128(AX)
+	VMOVDQU Y3, 192(AX)
+	VMOVDQU Y4, 256(AX)
+	VMOVDQU Y5, 320(AX)
+	VMOVDQU Y6, 384(AX)
+	VMOVDQU Y7, 448(AX)
+	VMOVDQU Y8, 512(AX)
+	VMOVDQU Y9, 576(AX)
+	VMOVDQU Y10, 640(AX)
+	VMOVDQU Y11, 704(AX)
+	VMOVDQU Y12, 768(AX)
+	VMOVDQU Y13, 832(AX)
+	VMOVDQU Y14, 896(AX)
+	VMOVDQU Y15, 960(AX)
+	JMP preempt
+saveAVX512:
+	VMOVDQU64 Z0, 0(AX)
+	VMOVDQU64 Z1, 64(AX)
+	VMOVDQU64 Z2, 128(AX)
+	VMOVDQU64 Z3, 192(AX)
+	VMOVDQU64 Z4, 256(AX)
+	VMOVDQU64 Z5, 320(AX)
+	VMOVDQU64 Z6, 384(AX)
+	VMOVDQU64 Z7, 448(AX)
+	VMOVDQU64 Z8, 512(AX)
+	VMOVDQU64 Z9, 576(AX)
+	VMOVDQU64 Z10, 640(AX)
+	VMOVDQU64 Z11, 704(AX)
+	VMOVDQU64 Z12, 768(AX)
+	VMOVDQU64 Z13, 832(AX)
+	VMOVDQU64 Z14, 896(AX)
+	VMOVDQU64 Z15, 960(AX)
+	KMOVQ K0, 1024(AX)
+	KMOVQ K1, 1032(AX)
+	KMOVQ K2, 1040(AX)
+	KMOVQ K3, 1048(AX)
+	KMOVQ K4, 1056(AX)
+	KMOVQ K5, 1064(AX)
+	KMOVQ K6, 1072(AX)
+	KMOVQ K7, 1080(AX)
+	JMP preempt
+preempt:
 	CALL ·asyncPreempt2(SB)
 	// Restore non-GPs from *p.xRegs.cache
 	MOVQ g_m(R14), AX
 	MOVQ m_p(AX), AX
 	MOVQ (p_xRegs+xRegPerP_cache)(AX), AX
-	MOVUPS 240(AX), X15
-	MOVUPS 224(AX), X14
-	MOVUPS 208(AX), X13
-	MOVUPS 192(AX), X12
-	MOVUPS 176(AX), X11
-	MOVUPS 160(AX), X10
-	MOVUPS 144(AX), X9
-	MOVUPS 128(AX), X8
-	MOVUPS 112(AX), X7
-	MOVUPS 96(AX), X6
-	MOVUPS 80(AX), X5
-	MOVUPS 64(AX), X4
-	MOVUPS 48(AX), X3
-	MOVUPS 32(AX), X2
-	MOVUPS 16(AX), X1
+#ifdef GOEXPERIMENT_simd
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX512(SB), $1
+	JE restoreAVX512
+	CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
+	JE restoreAVX2
+#endif
+restoreSSE:
+	MOVUPS 960(AX), X15
+	MOVUPS 896(AX), X14
+	MOVUPS 832(AX), X13
+	MOVUPS 768(AX), X12
+	MOVUPS 704(AX), X11
+	MOVUPS 640(AX), X10
+	MOVUPS 576(AX), X9
+	MOVUPS 512(AX), X8
+	MOVUPS 448(AX), X7
+	MOVUPS 384(AX), X6
+	MOVUPS 320(AX), X5
+	MOVUPS 256(AX), X4
+	MOVUPS 192(AX), X3
+	MOVUPS 128(AX), X2
+	MOVUPS 64(AX), X1
 	MOVUPS 0(AX), X0
+	JMP restoreGPs
+restoreAVX2:
+	VMOVDQU 960(AX), Y15
+	VMOVDQU 896(AX), Y14
+	VMOVDQU 832(AX), Y13
+	VMOVDQU 768(AX), Y12
+	VMOVDQU 704(AX), Y11
+	VMOVDQU 640(AX), Y10
+	VMOVDQU 576(AX), Y9
+	VMOVDQU 512(AX), Y8
+	VMOVDQU 448(AX), Y7
+	VMOVDQU 384(AX), Y6
+	VMOVDQU 320(AX), Y5
+	VMOVDQU 256(AX), Y4
+	VMOVDQU 192(AX), Y3
+	VMOVDQU 128(AX), Y2
+	VMOVDQU 64(AX), Y1
+	VMOVDQU 0(AX), Y0
+	JMP restoreGPs
+restoreAVX512:
+	KMOVQ 1080(AX), K7
+	KMOVQ 1072(AX), K6
+	KMOVQ 1064(AX), K5
+	KMOVQ 1056(AX), K4
+	KMOVQ 1048(AX), K3
+	KMOVQ 1040(AX), K2
+	KMOVQ 1032(AX), K1
+	KMOVQ 1024(AX), K0
+	VMOVDQU64 960(AX), Z15
+	VMOVDQU64 896(AX), Z14
+	VMOVDQU64 832(AX), Z13
+	VMOVDQU64 768(AX), Z12
+	VMOVDQU64 704(AX), Z11
+	VMOVDQU64 640(AX), Z10
+	VMOVDQU64 576(AX), Z9
+	VMOVDQU64 512(AX), Z8
+	VMOVDQU64 448(AX), Z7
+	VMOVDQU64 384(AX), Z6
+	VMOVDQU64 320(AX), Z5
+	VMOVDQU64 256(AX), Z4
+	VMOVDQU64 192(AX), Z3
+	VMOVDQU64 128(AX), Z2
+	VMOVDQU64 64(AX), Z1
+	VMOVDQU64 0(AX), Z0
+	JMP restoreGPs
+restoreGPs:
 	// Restore GPs
 	MOVQ 104(SP), R15
 	MOVQ 96(SP), R14