diff --git a/Cargo.lock b/Cargo.lock index 11b1a8d..8e75610 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -605,9 +605,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.5" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" dependencies = [ "bytes", "prost-derive", @@ -615,9 +615,9 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.5" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" dependencies = [ "heck", "itertools", @@ -635,9 +635,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" dependencies = [ "anyhow", "itertools", @@ -648,9 +648,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.5" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" dependencies = [ "prost", ] @@ -892,6 +892,7 @@ dependencies = [ "rand_distr", "sha2", "sorted-vec", + "spqr", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index 112c7d6..cb4f642 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ libcrux-hmac = "0.0.3" libcrux-ml-kem = { version = "0.0.3", default-features = false, features = ["incremental", "mlkem768"] } log = "0.4.21" num_enum = "0.7.3" -prost = "0.13.1" +prost = "0.14.1" rand = "0.9" rand_core = "0.9" sha2 = "0.10" @@ -25,6 +25,9 @@ sorted-vec = "0.8.6" thiserror = "2.0.11" [dev-dependencies] +# When built directly, auto-enable the test-utils feature. +spqr = { path = ".", features = ["test-utils"] } + galois_field_2pm = "0.1.0" hmac = "0.12.1" matches = "0.1.10" @@ -32,10 +35,11 @@ rand_08 = { package = "rand", version = "0.8" } rand_distr = "0.5.1" [build-dependencies] -prost-build = "0.13.1" +prost-build = "0.14.1" [features] proof = [] +test-utils = [] [target.'cfg(not(any(windows, target_arch = "x86")))'.dependencies] # sha2's asm implementation uses standalone .S files that aren't compiled correctly on Windows, diff --git a/benches/chain.rs b/benches/chain.rs index 62f14a9..e594005 100644 --- a/benches/chain.rs +++ b/benches/chain.rs @@ -12,7 +12,7 @@ mod tests { #[bench] fn add_epoch(b: &mut Bencher) { - let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb()) + let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb_test()) .expect("should be valid"); let mut e: u64 = 0; b.iter(|| { @@ -28,7 +28,7 @@ mod tests { #[bench] fn send_key(b: &mut Bencher) { - let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb()) + let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb_test()) .expect("should be valid"); b.iter(|| { // Inner closure, the actual test @@ -38,7 +38,7 @@ mod tests { #[bench] fn recv_key(b: &mut Bencher) { - let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb()) + let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb_test()) .expect("should be valid"); let mut k: u32 = 0; b.iter(|| { @@ -50,7 +50,7 @@ mod tests { #[bench] fn recv_skip_key(b: &mut Bencher) { - let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb()) + let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb_test()) .expect("should be valid"); let mut k: u32 = 0; b.iter(|| { @@ -63,7 +63,7 @@ mod tests { #[bench] fn recv_with_truncate(b: &mut Bencher) { - let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb()) + let mut c = chain::Chain::new(b"1", Direction::A2B, ChainParams::default().into_pb_test()) .expect("should be valid"); let mut k: u32 = 0; b.iter(|| { diff --git a/benches/polynomial.rs b/benches/polynomial.rs index 374e501..59f263f 100644 --- a/benches/polynomial.rs +++ b/benches/polynomial.rs @@ -66,9 +66,9 @@ mod tests { #[bench] fn encoder_from_pb(b: &mut Bencher) { let encoder = PolyEncoder::encode_bytes(&[3u8; 1088]).expect("encode_bytes"); - let bytes = encoder.into_pb().encode_to_vec(); + let bytes = encoder.into_pb_test().encode_to_vec(); b.iter(|| { - black_box(PolyEncoder::from_pb( + black_box(PolyEncoder::from_pb_test( pqrpb::PolynomialEncoder::decode(bytes.as_slice()).unwrap(), )) }); @@ -82,9 +82,9 @@ mod tests { for i in 1..chunks_needed { decoder.add_chunk(&encoder.chunk_at(i)); } - let bytes = decoder.into_pb().encode_to_vec(); + let bytes = decoder.into_pb_test().encode_to_vec(); b.iter(|| { - black_box(PolyDecoder::from_pb( + black_box(PolyDecoder::from_pb_test( pqrpb::PolynomialDecoder::decode(bytes.as_slice()).unwrap(), )) }); diff --git a/src/chain.rs b/src/chain.rs index d509323..bed6cbf 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -37,7 +37,7 @@ const DEFAULT_CHAIN_PARAMS: ChainParams = ChainParams { }; impl ChainParams { - pub fn into_pb(self) -> ChainParamsPB { + pub(crate) fn into_pb(self) -> ChainParamsPB { ChainParamsPB { max_jump: if self.max_jump == DEFAULT_CHAIN_PARAMS.max_jump { 0 @@ -51,6 +51,13 @@ impl ChainParams { }, } } + + /// Public wrapper for test utilities and benchmarks. + /// For internal use, call `into_pb` directly. + #[cfg(feature = "test-utils")] + pub fn into_pb_test(self) -> ChainParamsPB { + self.into_pb() + } } impl ChainParamsPB { @@ -392,7 +399,7 @@ impl Chain { } #[hax_lib::opaque] // into_iter and map - pub fn into_pb(self) -> pqrpb::Chain { + pub(crate) fn into_pb(self) -> pqrpb::Chain { pqrpb::Chain { direction: self.dir.into(), current_epoch: self.current_epoch, @@ -411,7 +418,7 @@ impl Chain { } #[hax_lib::opaque] // into_iter and map - pub fn from_pb(pb: pqrpb::Chain) -> Result { + pub(crate) fn from_pb(pb: pqrpb::Chain) -> Result { Ok(Self { dir: pb.direction.try_into().map_err(|_| Error::StateDecode)?, current_epoch: pb.current_epoch, diff --git a/src/encoding.rs b/src/encoding.rs index 3f32714..926e0ef 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -5,8 +5,6 @@ pub mod gf; pub mod polynomial; pub mod round_robin; -use crate::proto::pq_ratchet as pqrpb; - #[derive(Debug, thiserror::Error, Copy, Clone, PartialEq)] pub enum EncodingError { #[error("Polynomial error: {0}")] @@ -29,28 +27,6 @@ pub struct Chunk { pub data: [u8; 32], } -impl Chunk { - pub fn into_pb(self) -> pqrpb::Chunk { - pqrpb::Chunk { - index: self.index as u32, - data: self.data[..].to_vec(), - } - } - - pub fn from_pb(pb: pqrpb::Chunk) -> Result { - Ok(Self { - index: pb - .index - .try_into() - .map_err(|_| EncodingError::ChunkIndexDecodingError)?, - data: pb - .data - .as_slice() - .try_into() - .map_err(|_| EncodingError::ChunkDataDecodingError)?, - }) - } -} #[hax_lib::attributes] pub trait Encoder { diff --git a/src/encoding/gf.rs b/src/encoding/gf.rs index d4ecdfc..81112c3 100644 --- a/src/encoding/gf.rs +++ b/src/encoding/gf.rs @@ -541,10 +541,6 @@ impl GF16 { Self { value } } - pub fn inv(&self) -> GF16 { - GF16::ONE.div_impl(self) - } - fn div_impl(&self, other: &Self) -> Self { // Within GF(p^n), inv(a) == a^(p^n-2). We're GF(2^16) == GF(65536), // so we can compute GF(65534). @@ -568,12 +564,6 @@ impl GF16 { } } - pub const fn const_add(&self, other: &Self) -> Self { - Self { - value: self.value ^ other.value, - } - } - pub const fn const_div(&self, other: &Self) -> Self { // Within GF(p^n), inv(a) == a^(p^n-2). We're GF(2^16) == GF(65536), // so we can compute GF(65534). @@ -593,9 +583,6 @@ impl GF16 { out } - pub const fn const_inv(&self) -> GF16 { - GF16::ONE.const_div(self) - } } #[cfg(test)] @@ -632,17 +619,7 @@ mod test { assert_eq!(a, b); } } - #[test] - fn inv() { - let mut rng = rand::rng(); - for _i in 0..100 { - let x = rng.next_u32() as u16; - assert_eq!( - (GF16 { value: x } * GF16 { value: x }.inv()).value, - GF16::ONE.value - ); - } - } + #[test] fn div() { let mut rng = rand::rng(); diff --git a/src/encoding/polynomial.rs b/src/encoding/polynomial.rs index 0989172..8308565 100644 --- a/src/encoding/polynomial.rs +++ b/src/encoding/polynomial.rs @@ -64,10 +64,14 @@ impl PartialEq for Pt { } // The highest degree polynomial that will be stored for Protocol V1 +// Used in hax_lib annotations +#[allow(dead_code)] pub const MAX_STORED_POLYNOMIAL_DEGREE_V1: usize = 35; // The highest degree polynomial that will be constructed in intermediate // calculations for Protocol V1 +// Used in hax_lib annotations +#[allow(dead_code)] pub const MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1: usize = 36; #[derive(Clone, PartialEq)] @@ -523,7 +527,7 @@ impl PolyEncoder { EncoderState::Polys(polys) => hax_lib::Prop::from(polys.len() == 16).and(hax_lib::prop::forall(|poly: &Poly| hax_lib::prop::implies(polys.contains(poly), poly.coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1))) })] - pub fn into_pb(self) -> proto::pq_ratchet::PolynomialEncoder { + pub(crate) fn into_pb(self) -> proto::pq_ratchet::PolynomialEncoder { let mut out = proto::pq_ratchet::PolynomialEncoder { idx: self.idx, pts: Vec::with_capacity(16), @@ -555,7 +559,7 @@ impl PolyEncoder { out } - pub fn from_pb(pb: proto::pq_ratchet::PolynomialEncoder) -> Result { + pub(crate) fn from_pb(pb: proto::pq_ratchet::PolynomialEncoder) -> Result { let s = if !pb.pts.is_empty() { if !pb.polys.is_empty() { return Err(PolynomialError::SerializationInvalid); @@ -679,6 +683,20 @@ impl PolyEncoder { data: (&out[..]).try_into().expect("should be exactly 32 bytes"), } } + + /// Public wrapper for test utilities and benchmarks. + /// For internal use, call `into_pb` directly. + #[cfg(feature = "test-utils")] + pub fn into_pb_test(self) -> proto::pq_ratchet::PolynomialEncoder { + self.into_pb() + } + + /// Public wrapper for test utilities and benchmarks. + /// For internal use, call `from_pb` directly. + #[cfg(feature = "test-utils")] + pub fn from_pb_test(pb: proto::pq_ratchet::PolynomialEncoder) -> Result { + Self::from_pb(pb) + } } #[hax_lib::attributes] @@ -724,6 +742,8 @@ pub struct PolyDecoder { #[hax_lib::attributes] impl PolyDecoder { + // Used in hax_lib annotations + #[allow(dead_code)] pub fn get_pts_needed(&self) -> usize { self.pts_needed } @@ -749,7 +769,7 @@ impl PolyDecoder { }) } - pub fn into_pb(self) -> proto::pq_ratchet::PolynomialDecoder { + pub(crate) fn into_pb(self) -> proto::pq_ratchet::PolynomialDecoder { let mut out = proto::pq_ratchet::PolynomialDecoder { pts_needed: self.pts_needed as u32, polys: 16, @@ -769,7 +789,7 @@ impl PolyDecoder { out } - pub fn from_pb(pb: proto::pq_ratchet::PolynomialDecoder) -> Result { + pub(crate) fn from_pb(pb: proto::pq_ratchet::PolynomialDecoder) -> Result { if pb.pts.len() != 16 { return Err(PolynomialError::SerializationInvalid); } @@ -791,6 +811,20 @@ impl PolyDecoder { } Ok(out) } + + /// Public wrapper for test utilities and benchmarks. + /// For internal use, call `into_pb` directly. + #[cfg(feature = "test-utils")] + pub fn into_pb_test(self) -> proto::pq_ratchet::PolynomialDecoder { + self.into_pb() + } + + /// Public wrapper for test utilities and benchmarks. + /// For internal use, call `from_pb` directly. + #[cfg(feature = "test-utils")] + pub fn from_pb_test(pb: proto::pq_ratchet::PolynomialDecoder) -> Result { + Self::from_pb(pb) + } } #[hax_lib::attributes] diff --git a/src/lib.rs b/src/lib.rs index 1dca577..cdca453 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,36 @@ // Copyright 2025 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only -pub mod authenticator; +pub(crate) mod authenticator; +#[cfg(feature = "test-utils")] pub mod chain; +#[cfg(not(feature = "test-utils"))] +pub(crate) mod chain; +#[cfg(feature = "test-utils")] pub mod encoding; +#[cfg(not(feature = "test-utils"))] +pub(crate) mod encoding; pub(crate) mod incremental_mlkem768; pub(crate) mod kdf; -pub mod proto; -pub mod serialize; +pub(crate) mod serialize; pub(crate) mod test; pub(crate) mod util; mod v1; +#[cfg(feature = "test-utils")] +pub mod proto; +#[cfg(not(feature = "test-utils"))] +mod proto; + use crate::chain::Chain; pub use crate::chain::ChainParams; use crate::proto::pq_ratchet as pqrpb; pub use crate::proto::pq_ratchet::{Direction, Version}; +// Re-export error types that are part of the public Error enum +pub use crate::authenticator::Error as AuthenticatorError; +pub use crate::encoding::polynomial::PolynomialError; +pub use crate::encoding::EncodingError; +pub use crate::serialize::Error as SerializationError; use prost::Message; use rand::{CryptoRng, Rng}; use std::cmp::Ordering;