mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
crypto/rsa: allocate nats on the stack for RSA 2048
With a small tweak and the help of the inliner, we preallocate enough nat backing space to do RSA-2048 on the stack. We keep the length of the preallocated slices at zero so they don't silently mask missing expandFor calls. Surprisingly enough, this doesn't move the CPU benchmark needle much, but probably reduces GC pressure on larger applications. name old time/op new time/op delta DecryptPKCS1v15/2048-8 1.25ms ± 0% 1.22ms ± 1% -1.68% (p=0.000 n=10+9) DecryptPKCS1v15/3072-8 3.78ms ± 0% 3.73ms ± 1% -1.33% (p=0.000 n=9+10) DecryptPKCS1v15/4096-8 8.62ms ± 0% 8.45ms ± 1% -1.98% (p=0.000 n=8+10) EncryptPKCS1v15/2048-8 140µs ± 1% 136µs ± 0% -2.43% (p=0.000 n=9+9) DecryptOAEP/2048-8 1.25ms ± 0% 1.24ms ± 0% -0.83% (p=0.000 n=8+10) EncryptOAEP/2048-8 140µs ± 0% 137µs ± 0% -1.82% (p=0.000 n=8+10) SignPKCS1v15/2048-8 1.29ms ± 0% 1.29ms ± 1% ~ (p=0.574 n=8+8) VerifyPKCS1v15/2048-8 139µs ± 0% 136µs ± 0% -2.12% (p=0.000 n=9+10) SignPSS/2048-8 1.30ms ± 0% 1.28ms ± 0% -0.96% (p=0.000 n=8+10) VerifyPSS/2048-8 140µs ± 0% 137µs ± 0% -1.99% (p=0.000 n=10+8) name old alloc/op new alloc/op delta DecryptPKCS1v15/2048-8 15.0kB ± 0% 0.5kB ± 0% -96.58% (p=0.000 n=10+10) DecryptPKCS1v15/3072-8 24.6kB ± 0% 3.3kB ± 0% -86.74% (p=0.000 n=10+10) DecryptPKCS1v15/4096-8 38.9kB ± 0% 4.5kB ± 0% -88.50% (p=0.000 n=10+10) EncryptPKCS1v15/2048-8 18.0kB ± 0% 1.2kB ± 0% -93.48% (p=0.000 n=10+10) DecryptOAEP/2048-8 15.2kB ± 0% 0.7kB ± 0% -95.10% (p=0.000 n=10+10) EncryptOAEP/2048-8 18.2kB ± 0% 1.4kB ± 0% -92.29% (p=0.000 n=10+10) SignPKCS1v15/2048-8 21.9kB ± 0% 0.8kB ± 0% -96.50% (p=0.000 n=10+10) VerifyPKCS1v15/2048-8 17.7kB ± 0% 0.9kB ± 0% -94.85% (p=0.000 n=10+10) SignPSS/2048-8 22.3kB ± 0% 1.2kB ± 0% -94.77% (p=0.000 n=10+10) VerifyPSS/2048-8 17.9kB ± 0% 1.1kB ± 0% -93.75% (p=0.000 n=10+10) name old allocs/op new allocs/op delta DecryptPKCS1v15/2048-8 124 ± 0% 3 ± 0% -97.58% (p=0.000 n=10+10) DecryptPKCS1v15/3072-8 140 ± 0% 9 ± 0% -93.57% (p=0.000 n=10+10) DecryptPKCS1v15/4096-8 158 ± 0% 9 ± 0% -94.30% (p=0.000 n=10+10) EncryptPKCS1v15/2048-8 80.0 ± 0% 7.0 ± 0% -91.25% (p=0.000 n=10+10) DecryptOAEP/2048-8 130 ± 0% 9 ± 0% -93.08% (p=0.000 n=10+10) EncryptOAEP/2048-8 86.0 ± 0% 13.0 ± 0% -84.88% (p=0.000 n=10+10) SignPKCS1v15/2048-8 162 ± 0% 4 ± 0% -97.53% (p=0.000 n=10+10) VerifyPKCS1v15/2048-8 79.0 ± 0% 6.0 ± 0% -92.41% (p=0.000 n=10+10) SignPSS/2048-8 167 ± 0% 9 ± 0% -94.61% (p=0.000 n=10+10) VerifyPSS/2048-8 84.0 ± 0% 11.0 ± 0% -86.90% (p=0.000 n=10+10) Change-Id: I511a2f5f6f596bbec68a0a411e83a9d04080d72a Reviewed-on: https://go-review.googlesource.com/c/go/+/445021 Run-TryBot: Filippo Valsorda <filippo@golang.org> Reviewed-by: Joedian Reid <joedian@golang.org> Reviewed-by: Roland Shoemaker <roland@golang.org> TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
parent
72d2c4c635
commit
58a2db181b
3 changed files with 95 additions and 75 deletions
|
|
@ -69,6 +69,20 @@ type nat struct {
|
||||||
limbs []uint
|
limbs []uint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// preallocTarget is the size in bits of the numbers used to implement the most
|
||||||
|
// common and most performant RSA key size. It's also enough to cover some of
|
||||||
|
// the operations of key sizes up to 4096.
|
||||||
|
const preallocTarget = 2048
|
||||||
|
const preallocLimbs = (preallocTarget + _W) / _W
|
||||||
|
|
||||||
|
// newNat returns a new nat with a size of zero, just like new(nat), but with
|
||||||
|
// the preallocated capacity to hold a number of up to preallocTarget bits.
|
||||||
|
// newNat inlines, so the allocation can live on the stack.
|
||||||
|
func newNat() *nat {
|
||||||
|
limbs := make([]uint, 0, preallocLimbs)
|
||||||
|
return &nat{limbs}
|
||||||
|
}
|
||||||
|
|
||||||
// expand expands x to n limbs, leaving its value unchanged.
|
// expand expands x to n limbs, leaving its value unchanged.
|
||||||
func (x *nat) expand(n int) *nat {
|
func (x *nat) expand(n int) *nat {
|
||||||
for len(x.limbs) > n {
|
for len(x.limbs) > n {
|
||||||
|
|
@ -104,40 +118,40 @@ func (x *nat) reset(n int) *nat {
|
||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// clone returns a new nat, with the same value and announced length as x.
|
// set assigns x = y, optionally resizing x to the appropriate size.
|
||||||
func (x *nat) clone() *nat {
|
func (x *nat) set(y *nat) *nat {
|
||||||
out := &nat{make([]uint, len(x.limbs))}
|
x.reset(len(y.limbs))
|
||||||
copy(out.limbs, x.limbs)
|
copy(x.limbs, y.limbs)
|
||||||
return out
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// natFromBig creates a new natural number from a big.Int.
|
// set assigns x = n, optionally resizing n to the appropriate size.
|
||||||
//
|
//
|
||||||
// The announced length of the resulting nat is based on the actual bit size of
|
// The announced length of x is set based on the actual bit size of the input,
|
||||||
// the input, ignoring leading zeroes.
|
// ignoring leading zeroes.
|
||||||
func natFromBig(x *big.Int) *nat {
|
func (x *nat) setBig(n *big.Int) *nat {
|
||||||
xLimbs := x.Bits()
|
bitSize := bigBitLen(n)
|
||||||
bitSize := bigBitLen(x)
|
|
||||||
requiredLimbs := (bitSize + _W - 1) / _W
|
requiredLimbs := (bitSize + _W - 1) / _W
|
||||||
|
x.reset(requiredLimbs)
|
||||||
|
|
||||||
out := &nat{make([]uint, requiredLimbs)}
|
|
||||||
outI := 0
|
outI := 0
|
||||||
shift := 0
|
shift := 0
|
||||||
for i := range xLimbs {
|
limbs := n.Bits()
|
||||||
xi := uint(xLimbs[i])
|
for i := range limbs {
|
||||||
out.limbs[outI] |= (xi << shift) & _MASK
|
xi := uint(limbs[i])
|
||||||
|
x.limbs[outI] |= (xi << shift) & _MASK
|
||||||
outI++
|
outI++
|
||||||
if outI == requiredLimbs {
|
if outI == requiredLimbs {
|
||||||
return out
|
return x
|
||||||
}
|
}
|
||||||
out.limbs[outI] = xi >> (_W - shift)
|
x.limbs[outI] = xi >> (_W - shift)
|
||||||
shift++ // this assumes bits.UintSize - _W = 1
|
shift++ // this assumes bits.UintSize - _W = 1
|
||||||
if shift == _W {
|
if shift == _W {
|
||||||
shift = 0
|
shift = 0
|
||||||
outI++
|
outI++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// fillBytes sets bytes to x as a zero-extended big-endian byte slice.
|
// fillBytes sets bytes to x as a zero-extended big-endian byte slice.
|
||||||
|
|
@ -175,31 +189,32 @@ func (x *nat) fillBytes(bytes []byte) []byte {
|
||||||
return bytes
|
return bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
// natFromBytes converts a slice of big-endian bytes into a nat.
|
// setBytes assigns x = b, where b is a slice of big-endian bytes, optionally
|
||||||
|
// resizing n to the appropriate size.
|
||||||
//
|
//
|
||||||
// The announced length of the output depends on the length of bytes. Unlike
|
// The announced length of the output depends only on the length of b. Unlike
|
||||||
// big.Int, creating a nat will not remove leading zeros.
|
// big.Int, creating a nat will not remove leading zeros.
|
||||||
func natFromBytes(bytes []byte) *nat {
|
func (x *nat) setBytes(b []byte) *nat {
|
||||||
bitSize := len(bytes) * 8
|
bitSize := len(b) * 8
|
||||||
requiredLimbs := (bitSize + _W - 1) / _W
|
requiredLimbs := (bitSize + _W - 1) / _W
|
||||||
|
x.reset(requiredLimbs)
|
||||||
|
|
||||||
out := &nat{make([]uint, requiredLimbs)}
|
|
||||||
outI := 0
|
outI := 0
|
||||||
shift := 0
|
shift := 0
|
||||||
for i := len(bytes) - 1; i >= 0; i-- {
|
for i := len(b) - 1; i >= 0; i-- {
|
||||||
bi := bytes[i]
|
bi := b[i]
|
||||||
out.limbs[outI] |= uint(bi) << shift
|
x.limbs[outI] |= uint(bi) << shift
|
||||||
shift += 8
|
shift += 8
|
||||||
if shift >= _W {
|
if shift >= _W {
|
||||||
shift -= _W
|
shift -= _W
|
||||||
out.limbs[outI] &= _MASK
|
x.limbs[outI] &= _MASK
|
||||||
outI++
|
outI++
|
||||||
if shift > 0 {
|
if shift > 0 {
|
||||||
out.limbs[outI] = uint(bi) >> (8 - shift)
|
x.limbs[outI] = uint(bi) >> (8 - shift)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// cmpEq returns 1 if x == y, and 0 otherwise.
|
// cmpEq returns 1 if x == y, and 0 otherwise.
|
||||||
|
|
@ -306,7 +321,7 @@ type modulus struct {
|
||||||
|
|
||||||
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
|
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
|
||||||
func rr(m *modulus) *nat {
|
func rr(m *modulus) *nat {
|
||||||
rr := new(nat).expandFor(m)
|
rr := newNat().expandFor(m)
|
||||||
// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
|
// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
|
||||||
// most significant limb to 1. We then get to R*R by shifting left by _W
|
// most significant limb to 1. We then get to R*R by shifting left by _W
|
||||||
// n + 1 times.
|
// n + 1 times.
|
||||||
|
|
@ -387,7 +402,7 @@ func modulusSize(m *modulus) int {
|
||||||
//
|
//
|
||||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||||
func (x *nat) shiftIn(y uint, m *modulus) *nat {
|
func (x *nat) shiftIn(y uint, m *modulus) *nat {
|
||||||
d := new(nat).resetFor(m)
|
d := newNat().resetFor(m)
|
||||||
|
|
||||||
// Eliminate bounds checks in the loop.
|
// Eliminate bounds checks in the loop.
|
||||||
size := len(m.nat.limbs)
|
size := len(m.nat.limbs)
|
||||||
|
|
@ -528,7 +543,7 @@ func (x *nat) modAdd(y *nat, m *modulus) *nat {
|
||||||
func (x *nat) montgomeryRepresentation(m *modulus) *nat {
|
func (x *nat) montgomeryRepresentation(m *modulus) *nat {
|
||||||
// A Montgomery multiplication (which computes a * b / R) by R * R works out
|
// A Montgomery multiplication (which computes a * b / R) by R * R works out
|
||||||
// to a multiplication by R, which takes the value out of the Montgomery domain.
|
// to a multiplication by R, which takes the value out of the Montgomery domain.
|
||||||
return x.montgomeryMul(x.clone(), m.RR, m)
|
return x.montgomeryMul(newNat().set(x), m.RR, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
|
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
|
||||||
|
|
@ -539,8 +554,8 @@ func (x *nat) montgomeryReduction(m *modulus) *nat {
|
||||||
// By Montgomery multiplying with 1 not in Montgomery representation, we
|
// By Montgomery multiplying with 1 not in Montgomery representation, we
|
||||||
// convert out back from Montgomery representation, because it works out to
|
// convert out back from Montgomery representation, because it works out to
|
||||||
// dividing by R.
|
// dividing by R.
|
||||||
t0 := x.clone()
|
t0 := newNat().set(x)
|
||||||
t1 := new(nat).expandFor(m)
|
t1 := newNat().expandFor(m)
|
||||||
t1.limbs[0] = 1
|
t1.limbs[0] = 1
|
||||||
return x.montgomeryMul(t0, t1, m)
|
return x.montgomeryMul(t0, t1, m)
|
||||||
}
|
}
|
||||||
|
|
@ -599,7 +614,7 @@ func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
|
||||||
func (x *nat) modMul(y *nat, m *modulus) *nat {
|
func (x *nat) modMul(y *nat, m *modulus) *nat {
|
||||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||||
// takes the result out of Montgomery representation.
|
// takes the result out of Montgomery representation.
|
||||||
xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m
|
xR := newNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -611,18 +626,23 @@ func (out *nat) exp(x *nat, e []byte, m *modulus) *nat {
|
||||||
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
|
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
|
||||||
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
|
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
|
||||||
// Using bit sizes that don't divide 8 are more complex to implement.
|
// Using bit sizes that don't divide 8 are more complex to implement.
|
||||||
table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1)
|
|
||||||
table[0] = x.clone().montgomeryRepresentation(m)
|
table := [(1 << 4) - 1]*nat{ // table[i] = x ^ (i+1)
|
||||||
|
// newNat calls are unrolled so they are allocated on the stack.
|
||||||
|
newNat(), newNat(), newNat(), newNat(), newNat(),
|
||||||
|
newNat(), newNat(), newNat(), newNat(), newNat(),
|
||||||
|
newNat(), newNat(), newNat(), newNat(), newNat(),
|
||||||
|
}
|
||||||
|
table[0].set(x).montgomeryRepresentation(m)
|
||||||
for i := 1; i < len(table); i++ {
|
for i := 1; i < len(table); i++ {
|
||||||
table[i] = new(nat).expandFor(m)
|
|
||||||
table[i].montgomeryMul(table[i-1], table[0], m)
|
table[i].montgomeryMul(table[i-1], table[0], m)
|
||||||
}
|
}
|
||||||
|
|
||||||
out.resetFor(m)
|
out.resetFor(m)
|
||||||
out.limbs[0] = 1
|
out.limbs[0] = 1
|
||||||
out.montgomeryRepresentation(m)
|
out.montgomeryRepresentation(m)
|
||||||
t0 := new(nat).expandFor(m)
|
t0 := newNat().expandFor(m)
|
||||||
t1 := new(nat).expandFor(m)
|
t1 := newNat().expandFor(m)
|
||||||
for _, b := range e {
|
for _, b := range e {
|
||||||
for _, j := range []int{4, 0} {
|
for _, j := range []int{4, 0} {
|
||||||
// Square four times.
|
// Square four times.
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,9 @@ func testModAddCommutative(a *nat, b *nat) bool {
|
||||||
mLimbs[i] = _MASK
|
mLimbs[i] = _MASK
|
||||||
}
|
}
|
||||||
m := modulusFromNat(&nat{mLimbs})
|
m := modulusFromNat(&nat{mLimbs})
|
||||||
aPlusB := a.clone()
|
aPlusB := new(nat).set(a)
|
||||||
aPlusB.modAdd(b, m)
|
aPlusB.modAdd(b, m)
|
||||||
bPlusA := b.clone()
|
bPlusA := new(nat).set(b)
|
||||||
bPlusA.modAdd(a, m)
|
bPlusA.modAdd(a, m)
|
||||||
return aPlusB.cmpEq(bPlusA) == 1
|
return aPlusB.cmpEq(bPlusA) == 1
|
||||||
}
|
}
|
||||||
|
|
@ -50,7 +50,7 @@ func testModSubThenAddIdentity(a *nat, b *nat) bool {
|
||||||
mLimbs[i] = _MASK
|
mLimbs[i] = _MASK
|
||||||
}
|
}
|
||||||
m := modulusFromNat(&nat{mLimbs})
|
m := modulusFromNat(&nat{mLimbs})
|
||||||
original := a.clone()
|
original := new(nat).set(a)
|
||||||
a.modSub(b, m)
|
a.modSub(b, m)
|
||||||
a.modAdd(b, m)
|
a.modAdd(b, m)
|
||||||
return a.cmpEq(original) == 1
|
return a.cmpEq(original) == 1
|
||||||
|
|
@ -66,12 +66,12 @@ func TestModSubThenAddIdentity(t *testing.T) {
|
||||||
func testMontgomeryRoundtrip(a *nat) bool {
|
func testMontgomeryRoundtrip(a *nat) bool {
|
||||||
one := &nat{make([]uint, len(a.limbs))}
|
one := &nat{make([]uint, len(a.limbs))}
|
||||||
one.limbs[0] = 1
|
one.limbs[0] = 1
|
||||||
aPlusOne := a.clone()
|
aPlusOne := new(nat).set(a)
|
||||||
aPlusOne.add(1, one)
|
aPlusOne.add(1, one)
|
||||||
m := modulusFromNat(aPlusOne)
|
m := modulusFromNat(aPlusOne)
|
||||||
monty := a.clone()
|
monty := new(nat).set(a)
|
||||||
monty.montgomeryRepresentation(m)
|
monty.montgomeryRepresentation(m)
|
||||||
aAgain := monty.clone()
|
aAgain := new(nat).set(monty)
|
||||||
aAgain.montgomeryMul(monty, one, m)
|
aAgain.montgomeryMul(monty, one, m)
|
||||||
return a.cmpEq(aAgain) == 1
|
return a.cmpEq(aAgain) == 1
|
||||||
}
|
}
|
||||||
|
|
@ -86,7 +86,7 @@ func TestMontgomeryRoundtrip(t *testing.T) {
|
||||||
func TestFromBig(t *testing.T) {
|
func TestFromBig(t *testing.T) {
|
||||||
expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
|
expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
|
||||||
theBig := new(big.Int).SetBytes(expected)
|
theBig := new(big.Int).SetBytes(expected)
|
||||||
actual := natFromBig(theBig).fillBytes(make([]byte, len(expected)))
|
actual := new(nat).setBig(theBig).fillBytes(make([]byte, len(expected)))
|
||||||
if !bytes.Equal(actual, expected) {
|
if !bytes.Equal(actual, expected) {
|
||||||
t.Errorf("%+x != %+x", actual, expected)
|
t.Errorf("%+x != %+x", actual, expected)
|
||||||
}
|
}
|
||||||
|
|
@ -94,7 +94,7 @@ func TestFromBig(t *testing.T) {
|
||||||
|
|
||||||
func TestFillBytes(t *testing.T) {
|
func TestFillBytes(t *testing.T) {
|
||||||
xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
|
xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
|
||||||
x := natFromBytes(xBytes)
|
x := new(nat).setBytes(xBytes)
|
||||||
for l := 20; l >= len(xBytes); l-- {
|
for l := 20; l >= len(xBytes); l-- {
|
||||||
buf := make([]byte, l)
|
buf := make([]byte, l)
|
||||||
rand.Read(buf)
|
rand.Read(buf)
|
||||||
|
|
@ -122,7 +122,7 @@ func TestFromBytes(t *testing.T) {
|
||||||
if len(xBytes) == 0 {
|
if len(xBytes) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
|
actual := new(nat).setBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
|
||||||
if !bytes.Equal(actual, xBytes) {
|
if !bytes.Equal(actual, xBytes) {
|
||||||
t.Errorf("%+x != %+x", actual, xBytes)
|
t.Errorf("%+x != %+x", actual, xBytes)
|
||||||
return false
|
return false
|
||||||
|
|
@ -169,9 +169,9 @@ func TestShiftIn(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for i, tt := range examples {
|
for i, tt := range examples {
|
||||||
m := modulusFromNat(natFromBytes(tt.m))
|
m := modulusFromNat(new(nat).setBytes(tt.m))
|
||||||
got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
|
got := new(nat).setBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
|
||||||
if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 {
|
if got.cmpEq(new(nat).setBytes(tt.expected).expandFor(m)) != 1 {
|
||||||
t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
|
t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -182,10 +182,10 @@ func TestModulusAndNatSizes(t *testing.T) {
|
||||||
// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
|
// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
|
||||||
// limbs, if they are not, they fit in three. This can be a problem because
|
// limbs, if they are not, they fit in three. This can be a problem because
|
||||||
// modulus strips leading zeroes and nat does not.
|
// modulus strips leading zeroes and nat does not.
|
||||||
m := modulusFromNat(natFromBytes([]byte{
|
m := modulusFromNat(new(nat).setBytes([]byte{
|
||||||
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
|
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
|
||||||
x := natFromBytes([]byte{
|
x := new(nat).setBytes([]byte{
|
||||||
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
|
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
|
||||||
x.expandFor(m) // must not panic for shrinking
|
x.expandFor(m) // must not panic for shrinking
|
||||||
|
|
@ -224,11 +224,11 @@ func TestExpand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMod(t *testing.T) {
|
func TestMod(t *testing.T) {
|
||||||
m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
|
m := modulusFromNat(new(nat).setBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
|
||||||
x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
|
x := new(nat).setBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
|
||||||
out := new(nat)
|
out := new(nat)
|
||||||
out.mod(x, m)
|
out.mod(x, m)
|
||||||
expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
|
expected := new(nat).setBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
|
||||||
if out.cmpEq(expected) != 1 {
|
if out.cmpEq(expected) != 1 {
|
||||||
t.Errorf("%+v != %+v", out, expected)
|
t.Errorf("%+v != %+v", out, expected)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -314,9 +314,9 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey
|
||||||
Dq: Dq,
|
Dq: Dq,
|
||||||
Qinv: Qinv,
|
Qinv: Qinv,
|
||||||
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
|
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
|
||||||
n: modulusFromNat(natFromBig(N)),
|
n: modulusFromNat(newNat().setBig(N)),
|
||||||
p: modulusFromNat(natFromBig(P)),
|
p: modulusFromNat(newNat().setBig(P)),
|
||||||
q: modulusFromNat(natFromBig(Q)),
|
q: modulusFromNat(newNat().setBig(Q)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
|
|
@ -454,12 +454,12 @@ var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key siz
|
||||||
func encrypt(pub *PublicKey, plaintext []byte) []byte {
|
func encrypt(pub *PublicKey, plaintext []byte) []byte {
|
||||||
boring.Unreachable()
|
boring.Unreachable()
|
||||||
|
|
||||||
N := modulusFromNat(natFromBig(pub.N))
|
N := modulusFromNat(newNat().setBig(pub.N))
|
||||||
m := natFromBytes(plaintext).expandFor(N)
|
m := newNat().setBytes(plaintext).expandFor(N)
|
||||||
e := intToBytes(pub.E)
|
e := intToBytes(pub.E)
|
||||||
|
|
||||||
out := make([]byte, modulusSize(N))
|
out := make([]byte, modulusSize(N))
|
||||||
return new(nat).exp(m, e, N).fillBytes(out)
|
return newNat().exp(m, e, N).fillBytes(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
|
// intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
|
||||||
|
|
@ -553,9 +553,9 @@ var ErrVerification = errors.New("crypto/rsa: verification error")
|
||||||
// in the future.
|
// in the future.
|
||||||
func (priv *PrivateKey) Precompute() {
|
func (priv *PrivateKey) Precompute() {
|
||||||
if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
|
if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
|
||||||
priv.Precomputed.n = modulusFromNat(natFromBig(priv.N))
|
priv.Precomputed.n = modulusFromNat(newNat().setBig(priv.N))
|
||||||
priv.Precomputed.p = modulusFromNat(natFromBig(priv.Primes[0]))
|
priv.Precomputed.p = modulusFromNat(newNat().setBig(priv.Primes[0]))
|
||||||
priv.Precomputed.q = modulusFromNat(natFromBig(priv.Primes[1]))
|
priv.Precomputed.q = modulusFromNat(newNat().setBig(priv.Primes[1]))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill in the backwards-compatibility *big.Int values.
|
// Fill in the backwards-compatibility *big.Int values.
|
||||||
|
|
@ -600,9 +600,9 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
|
||||||
|
|
||||||
N := priv.Precomputed.n
|
N := priv.Precomputed.n
|
||||||
if N == nil {
|
if N == nil {
|
||||||
N = modulusFromNat(natFromBig(priv.N))
|
N = modulusFromNat(newNat().setBig(priv.N))
|
||||||
}
|
}
|
||||||
c := natFromBytes(ciphertext).expandFor(N)
|
c := newNat().setBytes(ciphertext).expandFor(N)
|
||||||
if c.cmpGeq(N.nat) == 1 {
|
if c.cmpGeq(N.nat) == 1 {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
|
|
@ -612,18 +612,18 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
|
||||||
|
|
||||||
var m *nat
|
var m *nat
|
||||||
if priv.Precomputed.n == nil {
|
if priv.Precomputed.n == nil {
|
||||||
m = new(nat).exp(c, priv.D.Bytes(), N)
|
m = newNat().exp(c, priv.D.Bytes(), N)
|
||||||
} else {
|
} else {
|
||||||
t0 := new(nat)
|
t0 := newNat()
|
||||||
P, Q := priv.Precomputed.p, priv.Precomputed.q
|
P, Q := priv.Precomputed.p, priv.Precomputed.q
|
||||||
// m = c ^ Dp mod p
|
// m = c ^ Dp mod p
|
||||||
m = new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
|
m = newNat().exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
|
||||||
// m2 = c ^ Dq mod q
|
// m2 = c ^ Dq mod q
|
||||||
m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
|
m2 := newNat().exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
|
||||||
// m = m - m2 mod p
|
// m = m - m2 mod p
|
||||||
m.modSub(t0.mod(m2, P), P)
|
m.modSub(t0.mod(m2, P), P)
|
||||||
// m = m * Qinv mod p
|
// m = m * Qinv mod p
|
||||||
m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P)
|
m.modMul(newNat().setBig(priv.Precomputed.Qinv).expandFor(P), P)
|
||||||
// m = m * q mod N
|
// m = m * q mod N
|
||||||
m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
|
m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
|
||||||
// m = m + m2 mod N
|
// m = m + m2 mod N
|
||||||
|
|
@ -631,7 +631,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if check {
|
if check {
|
||||||
c1 := new(nat).exp(m, intToBytes(priv.E), N)
|
c1 := newNat().exp(m, intToBytes(priv.E), N)
|
||||||
if c1.cmpEq(c) != 1 {
|
if c1.cmpEq(c) != 1 {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue