diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go index 71699078e29..5cbae40efe9 100644 --- a/src/crypto/internal/bigmod/nat.go +++ b/src/crypto/internal/bigmod/nat.go @@ -688,6 +688,25 @@ func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { return x } +// addMulVVW multiplies the multi-word value x by the single-word value y, +// adding the result to the multi-word value z in place and returning the final carry. +// It can be thought of as one row of a pen-and-paper column multiplication. It panics if z is empty or x is shorter than z. +func addMulVVW(z, x []uint, y uint) (carry uint) { + _ = x[len(z)-1] // bounds check elimination hint: one up-front check that x covers every index of z + for i := range z { + hi, lo := bits.Mul(x[i], y) // full double-width product, split into high and low words + lo, c := bits.Add(lo, z[i], 0) + // We use bits.Add with zero to get an add-with-carry instruction that + // absorbs the carry from the previous bits.Add. + hi, _ = bits.Add(hi, 0, c) + lo, c = bits.Add(lo, carry, 0) // fold in the carry word from the previous iteration + hi, _ = bits.Add(hi, 0, c) // carry-out ignored: x[i]*y + z[i] + carry always fits in two words + carry = hi // the high word becomes the carry into the next word + z[i] = lo + } + return carry +} + // Mul calculates x = x * y mod m. // // The length of both operands must be the same as the modulus. Both operands diff --git a/src/crypto/internal/bigmod/nat_generic.go b/src/crypto/internal/bigmod/nat_generic.go deleted file mode 100644 index a44d2ec5481..00000000000 --- a/src/crypto/internal/bigmod/nat_generic.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !wasm - -package bigmod - -import "math/bits" - -// addMulVVW multiplies the multi-word value x by the single-word value y, -// adding the result to the multi-word value z and returning the final carry. -// It can be thought of as one row of a pen-and-paper column multiplication.
-func addMulVVW(z, x []uint, y uint) (carry uint) { - _ = x[len(z)-1] // bounds check elimination hint - for i := range z { - hi, lo := bits.Mul(x[i], y) - lo, c := bits.Add(lo, z[i], 0) - // We use bits.Add with zero to get an add-with-carry instruction that - // absorbs the carry from the previous bits.Add. - hi, _ = bits.Add(hi, 0, c) - lo, c = bits.Add(lo, carry, 0) - hi, _ = bits.Add(hi, 0, c) - carry = hi - z[i] = lo - } - return carry -} diff --git a/src/crypto/internal/bigmod/nat_noasm.go b/src/crypto/internal/bigmod/nat_noasm.go index 2501a6fb4ce..dbec229f5d2 100644 --- a/src/crypto/internal/bigmod/nat_noasm.go +++ b/src/crypto/internal/bigmod/nat_noasm.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build purego || !(386 || amd64 || arm || arm64 || loong64 || ppc64 || ppc64le || riscv64 || s390x) +//go:build purego || !(386 || amd64 || arm || arm64 || loong64 || ppc64 || ppc64le || riscv64 || s390x || wasm) package bigmod diff --git a/src/crypto/internal/bigmod/nat_wasm.go b/src/crypto/internal/bigmod/nat_wasm.go index 81ffdb286f6..b4aaff74cf0 100644 --- a/src/crypto/internal/bigmod/nat_wasm.go +++ b/src/crypto/internal/bigmod/nat_wasm.go @@ -2,25 +2,30 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !purego + package bigmod +import "unsafe" // used by idx for pointer arithmetic over the operand words + // The generic implementation relies on 64x64->128 bit multiplication and // 64-bit add-with-carry, which are compiler intrinsics on many architectures. // Wasm doesn't support those. Here we implement it with 32x32->64 bit // operations, which is more efficient on Wasm. -// addMulVVW multiplies the multi-word value x by the single-word value y, -// adding the result to the multi-word value z and returning the final carry. -// It can be thought of as one row of a pen-and-paper column multiplication.
-func addMulVVW(z, x []uint, y uint) (carry uint) { +func idx(x *uint, i uintptr) *uint { + return (*uint)(unsafe.Pointer(uintptr(unsafe.Pointer(x)) + i*8)) +} + +func addMulVVWWasm(z, x *uint, y uint, n uintptr) (carry uint) { const mask32 = 1<<32 - 1 y0 := y & mask32 y1 := y >> 32 - _ = x[len(z)-1] // bounds check elimination hint - for i, zi := range z { - xi := x[i] + for i := range n { // n is the word count: z and x are bare pointers, so there is no slice length to range over + xi := *idx(x, i) // idx(x, i) points i*8 bytes past x; the read is not bounds-checked x0 := xi & mask32 x1 := xi >> 32 + zi := *idx(z, i) // word i of z, loaded the same way z0 := zi & mask32 z1 := zi >> 32 c0 := carry & mask32 @@ -38,7 +43,19 @@ func addMulVVW(z, x []uint, y uint) (carry uint) { h10 := w10 >> 32 carry = x1*y1 + h10 + h01 - z[i] = w10<<32 + l00 + *idx(z, i) = w10<<32 + l00 // write the result word back to word i of z in place } return carry } + +func addMulVVW1024(z, x *uint, y uint) (c uint) { + return addMulVVWWasm(z, x, y, 1024/_W) // fixed word count for 1024-bit operands; presumably _W is the word size in bits (defined outside this diff) +} + +func addMulVVW1536(z, x *uint, y uint) (c uint) { + return addMulVVWWasm(z, x, y, 1536/_W) // same kernel with the word count for 1536-bit operands +} + +func addMulVVW2048(z, x *uint, y uint) (c uint) { + return addMulVVWWasm(z, x, y, 2048/_W) // same kernel with the word count for 2048-bit operands +}