mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
crypto/internal/fips140/aes: optimize ctrBlocks8Asm on amd64
Implement overflow-aware optimization in ctrBlocks8Asm: make a fast branch
for the case when there is no overflow. One branch per 8 blocks is faster than
7 increments in general purpose registers and transfers from them to XMM.
Added AES-192 and AES-256 modes to the AES-CTR benchmark.
Added a correctness test in ctr_test.go for the overflow optimization.
This improves performance, especially in AES-128 mode.
goos: windows
goarch: amd64
pkg: crypto/cipher
cpu: AMD Ryzen 7 5800H with Radeon Graphics
│ B/s │ B/s vs base
AESCTR/128/50-16 1.377Gi ± 0% 1.384Gi ± 0% +0.51% (p=0.028 n=20)
AESCTR/128/1K-16 6.164Gi ± 0% 6.892Gi ± 1% +11.81% (p=0.000 n=20)
AESCTR/128/8K-16 7.372Gi ± 0% 8.768Gi ± 1% +18.95% (p=0.000 n=20)
AESCTR/192/50-16 1.289Gi ± 0% 1.279Gi ± 0% -0.75% (p=0.001 n=20)
AESCTR/192/1K-16 5.734Gi ± 0% 6.011Gi ± 0% +4.83% (p=0.000 n=20)
AESCTR/192/8K-16 6.889Gi ± 1% 7.437Gi ± 0% +7.96% (p=0.000 n=20)
AESCTR/256/50-16 1.170Gi ± 0% 1.163Gi ± 0% -0.54% (p=0.005 n=20)
AESCTR/256/1K-16 5.235Gi ± 0% 5.391Gi ± 0% +2.98% (p=0.000 n=20)
AESCTR/256/8K-16 6.361Gi ± 0% 6.676Gi ± 0% +4.94% (p=0.000 n=20)
geomean 3.681Gi 3.882Gi +5.46%
The slight slowdown on 50-byte workloads is unrelated to this change,
because such workloads never use ctrBlocks8Asm.
Updates #76061
Change-Id: Idfd628ac8bb282d9c73c6adf048eb12274a41379
GitHub-Last-Rev: 5aadd39351
GitHub-Pull-Request: golang/go#76059
Reviewed-on: https://go-review.googlesource.com/c/go/+/714361
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: AHMAD ابو وليد <mizommz@gmail.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
parent
03fcb33c0e
commit
71f8f031b2
4 changed files with 185 additions and 29 deletions
|
|
@ -65,12 +65,12 @@ func BenchmarkAESGCM(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte) {
|
||||
func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte, keySize int) {
|
||||
b.SetBytes(int64(len(buf)))
|
||||
|
||||
var key [16]byte
|
||||
key := make([]byte, keySize)
|
||||
var iv [16]byte
|
||||
aes, _ := aes.NewCipher(key[:])
|
||||
aes, _ := aes.NewCipher(key)
|
||||
stream := mode(aes, iv[:])
|
||||
|
||||
b.ResetTimer()
|
||||
|
|
@ -87,15 +87,20 @@ const almost1K = 1024 - 5
|
|||
const almost8K = 8*1024 - 5
|
||||
|
||||
func BenchmarkAESCTR(b *testing.B) {
|
||||
b.Run("50", func(b *testing.B) {
|
||||
benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50))
|
||||
})
|
||||
b.Run("1K", func(b *testing.B) {
|
||||
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K))
|
||||
})
|
||||
b.Run("8K", func(b *testing.B) {
|
||||
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K))
|
||||
})
|
||||
for _, keyBits := range []int{128, 192, 256} {
|
||||
keySize := keyBits / 8
|
||||
b.Run(strconv.Itoa(keyBits), func(b *testing.B) {
|
||||
b.Run("50", func(b *testing.B) {
|
||||
benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50), keySize)
|
||||
})
|
||||
b.Run("1K", func(b *testing.B) {
|
||||
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K), keySize)
|
||||
})
|
||||
b.Run("8K", func(b *testing.B) {
|
||||
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K), keySize)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAESCBCEncrypt1K(b *testing.B) {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import (
|
|||
"crypto/internal/boring"
|
||||
"crypto/internal/cryptotest"
|
||||
fipsaes "crypto/internal/fips140/aes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
|
@ -117,6 +118,60 @@ func makeTestingCiphers(aesBlock cipher.Block, iv []byte) (genericCtr, multibloc
|
|||
return cipher.NewCTR(wrap(aesBlock), iv), cipher.NewCTR(aesBlock, iv)
|
||||
}
|
||||
|
||||
// TestCTR_AES_blocks8FastPathMatchesGeneric ensures the overflow-aware branch
|
||||
// produces identical keystreams to the generic counter walker across
|
||||
// representative IVs, including near-overflow cases.
|
||||
func TestCTR_AES_blocks8FastPathMatchesGeneric(t *testing.T) {
|
||||
key := make([]byte, aes.BlockSize)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := block.(*fipsaes.Block); !ok {
|
||||
t.Skip("requires crypto/internal/fips140/aes")
|
||||
}
|
||||
|
||||
keystream := make([]byte, 8*aes.BlockSize)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
hi uint64
|
||||
lo uint64
|
||||
}{
|
||||
{"Zero", 0, 0},
|
||||
{"NearOverflowMinus7", 1, ^uint64(0) - 7},
|
||||
{"NearOverflowMinus6", 2, ^uint64(0) - 6},
|
||||
{"Overflow", 0, ^uint64(0)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var iv [aes.BlockSize]byte
|
||||
binary.BigEndian.PutUint64(iv[0:8], tc.hi)
|
||||
binary.BigEndian.PutUint64(iv[8:], tc.lo)
|
||||
|
||||
generic, multiblock := makeTestingCiphers(block, iv[:])
|
||||
|
||||
genericOut := make([]byte, len(keystream))
|
||||
multiblockOut := make([]byte, len(keystream))
|
||||
|
||||
generic.XORKeyStream(genericOut, keystream)
|
||||
multiblock.XORKeyStream(multiblockOut, keystream)
|
||||
|
||||
if !bytes.Equal(multiblockOut, genericOut) {
|
||||
t.Fatalf("mismatch for iv %#x:%#x\n"+
|
||||
"asm keystream: %x\n"+
|
||||
"gen keystream: %x\n"+
|
||||
"asm counters: %x\n"+
|
||||
"gen counters: %x",
|
||||
tc.hi, tc.lo, multiblockOut, genericOut,
|
||||
extractCounters(block, multiblockOut),
|
||||
extractCounters(block, genericOut))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func randBytes(t *testing.T, r *rand.Rand, count int) []byte {
|
||||
t.Helper()
|
||||
buf := make([]byte, count)
|
||||
|
|
@ -297,3 +352,12 @@ func TestCTR_AES_multiblock_XORKeyStreamAt(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func extractCounters(block cipher.Block, keystream []byte) []byte {
|
||||
blockSize := block.BlockSize()
|
||||
res := make([]byte, len(keystream))
|
||||
for i := 0; i < len(keystream); i += blockSize {
|
||||
block.Decrypt(res[i:i+blockSize], keystream[i:i+blockSize])
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,19 +40,79 @@ func ctrBlocks(numBlocks int) {
|
|||
bswap := XMM()
|
||||
MOVOU(bswapMask(), bswap)
|
||||
|
||||
blocks := make([]VecVirtual, 0, numBlocks)
|
||||
blocks := make([]VecVirtual, numBlocks)
|
||||
|
||||
// Lay out counter block plaintext.
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
x := XMM()
|
||||
blocks = append(blocks, x)
|
||||
// For the 8-block case we optimize counter generation. We build the first
|
||||
// counter as usual, then check whether the remaining seven increments will
|
||||
// overflow. When they do not (the common case) we keep the work entirely in
|
||||
// XMM registers to avoid expensive general-purpose -> XMM moves. Otherwise
|
||||
// we fall back to the traditional scalar path.
|
||||
if numBlocks == 8 {
|
||||
for i := range blocks {
|
||||
blocks[i] = XMM()
|
||||
}
|
||||
|
||||
MOVQ(ivlo, x)
|
||||
PINSRQ(Imm(1), ivhi, x)
|
||||
PSHUFB(bswap, x)
|
||||
if i < numBlocks-1 {
|
||||
ADDQ(Imm(1), ivlo)
|
||||
ADCQ(Imm(0), ivhi)
|
||||
base := XMM()
|
||||
tmp := GP64()
|
||||
addVec := XMM()
|
||||
|
||||
MOVQ(ivlo, blocks[0])
|
||||
PINSRQ(Imm(1), ivhi, blocks[0])
|
||||
MOVAPS(blocks[0], base)
|
||||
PSHUFB(bswap, blocks[0])
|
||||
|
||||
// Check whether any of these eight counters will overflow.
|
||||
MOVQ(ivlo, tmp)
|
||||
ADDQ(Imm(uint64(numBlocks-1)), tmp)
|
||||
slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
|
||||
doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
|
||||
JC(LabelRef(slowLabel))
|
||||
|
||||
// Fast branch: create an XMM increment vector containing the value 1.
|
||||
// Adding it to the base counter yields each subsequent counter.
|
||||
XORQ(tmp, tmp)
|
||||
INCQ(tmp)
|
||||
PXOR(addVec, addVec)
|
||||
PINSRQ(Imm(0), tmp, addVec)
|
||||
|
||||
for i := 1; i < numBlocks; i++ {
|
||||
PADDQ(addVec, base)
|
||||
MOVAPS(base, blocks[i])
|
||||
}
|
||||
JMP(LabelRef(doneLabel))
|
||||
|
||||
Label(slowLabel)
|
||||
ADDQ(Imm(1), ivlo)
|
||||
ADCQ(Imm(0), ivhi)
|
||||
for i := 1; i < numBlocks; i++ {
|
||||
MOVQ(ivlo, blocks[i])
|
||||
PINSRQ(Imm(1), ivhi, blocks[i])
|
||||
if i < numBlocks-1 {
|
||||
ADDQ(Imm(1), ivlo)
|
||||
ADCQ(Imm(0), ivhi)
|
||||
}
|
||||
}
|
||||
|
||||
Label(doneLabel)
|
||||
|
||||
// Convert little-endian counters to big-endian after the branch since
|
||||
// both paths share the same shuffle sequence.
|
||||
for i := 1; i < numBlocks; i++ {
|
||||
PSHUFB(bswap, blocks[i])
|
||||
}
|
||||
} else {
|
||||
// Lay out counter block plaintext.
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
x := XMM()
|
||||
blocks[i] = x
|
||||
|
||||
MOVQ(ivlo, x)
|
||||
PINSRQ(Imm(1), ivhi, x)
|
||||
PSHUFB(bswap, x)
|
||||
if i < numBlocks-1 {
|
||||
ADDQ(Imm(1), ivlo)
|
||||
ADCQ(Imm(0), ivhi)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -286,41 +286,68 @@ TEXT ·ctrBlocks8Asm(SB), $0-48
|
|||
MOVOU bswapMask<>+0(SB), X0
|
||||
MOVQ SI, X1
|
||||
PINSRQ $0x01, DI, X1
|
||||
MOVAPS X1, X8
|
||||
PSHUFB X0, X1
|
||||
MOVQ SI, R8
|
||||
ADDQ $0x07, R8
|
||||
JC ctr8_slow
|
||||
XORQ R8, R8
|
||||
INCQ R8
|
||||
PXOR X9, X9
|
||||
PINSRQ $0x00, R8, X9
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X2
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X3
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X4
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X5
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X6
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X7
|
||||
PADDQ X9, X8
|
||||
MOVAPS X8, X8
|
||||
JMP ctr8_done
|
||||
|
||||
ctr8_slow:
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X2
|
||||
PINSRQ $0x01, DI, X2
|
||||
PSHUFB X0, X2
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X3
|
||||
PINSRQ $0x01, DI, X3
|
||||
PSHUFB X0, X3
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X4
|
||||
PINSRQ $0x01, DI, X4
|
||||
PSHUFB X0, X4
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X5
|
||||
PINSRQ $0x01, DI, X5
|
||||
PSHUFB X0, X5
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X6
|
||||
PINSRQ $0x01, DI, X6
|
||||
PSHUFB X0, X6
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X7
|
||||
PINSRQ $0x01, DI, X7
|
||||
PSHUFB X0, X7
|
||||
ADDQ $0x01, SI
|
||||
ADCQ $0x00, DI
|
||||
MOVQ SI, X8
|
||||
PINSRQ $0x01, DI, X8
|
||||
|
||||
ctr8_done:
|
||||
PSHUFB X0, X2
|
||||
PSHUFB X0, X3
|
||||
PSHUFB X0, X4
|
||||
PSHUFB X0, X5
|
||||
PSHUFB X0, X6
|
||||
PSHUFB X0, X7
|
||||
PSHUFB X0, X8
|
||||
MOVUPS (CX), X0
|
||||
PXOR X0, X1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue