From 4db079990e0e74e93c8e2dab8a0e6a764e4be06b Mon Sep 17 00:00:00 2001 From: gram-signal <84339875+gram-signal@users.noreply.github.com> Date: Thu, 8 May 2025 16:05:09 -0700 Subject: [PATCH] Clear send keys for old epochs when we start sending in a new one. --- src/chain.rs | 34 ++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ src/proto/pq_ratchet.proto | 1 + 3 files changed, 37 insertions(+) diff --git a/src/chain.rs b/src/chain.rs index 8c134bd..749b683 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -33,6 +33,7 @@ struct ChainEpoch { pub struct Chain { dir: Direction, current_epoch: Epoch, + send_epoch: Epoch, // links.len() <= EPOCHS_TO_KEEP links: VecDeque, // stores [link[current_epoch-N] .. link[current_epoch]] // next_root.len() == 32 @@ -149,6 +150,7 @@ impl ChainEpochDirection { fn next_key_internal(next: &mut [u8], ctr: &mut u32) -> (u32, [u8; 32]) { hax_lib::fstar!("admit()"); + assert!(!next.is_empty()); *ctr += 1; let mut gen = [0u8; 64]; kdf::hkdf_to_slice( @@ -217,6 +219,10 @@ impl ChainEpochDirection { prev: KeyHistory { data: pb.prev }, }) } + + fn clear_next(&mut self) { + self.next.clear(); + } } #[hax_lib::attributes] @@ -241,6 +247,7 @@ impl Chain { Self { dir, current_epoch: 0, + send_epoch: 0, links: VecDeque::from([ChainEpoch { send: Self::ced_for_direction(&gen, &dir), recv: Self::ced_for_direction(&gen, &dir.switch()), @@ -284,7 +291,16 @@ impl Chain { pub fn send_key(&mut self, epoch: Epoch) -> Result<(u32, Vec), Error> { hax_lib::fstar!("admit ()"); + if epoch < self.send_epoch { + return Err(Error::SendKeyEpochDecreased(self.send_epoch, epoch)); + } let epoch_index = self.epoch_idx(epoch)?; + if self.send_epoch != epoch { + self.send_epoch = epoch; + for i in 0..epoch_index { + self.links[i].send.clear_next(); + } + } Ok(self.links[epoch_index].send.next_key()) } @@ -299,6 +315,7 @@ impl Chain { pqrpb::Chain { a2b: matches!(self.dir, Direction::A2B), current_epoch: self.current_epoch, + send_epoch: self.send_epoch, links: self .links .into_iter() @@ -320,6 +337,7 @@ impl Chain { Direction::B2A }, current_epoch: pb.current_epoch, + send_epoch: pb.send_epoch, next_root: pb.next_root, links: pb .links @@ -403,4 +421,20 @@ mod test { assert_eq!(b2a.recv_key(0, idx).unwrap(), key); } } + + #[test] + fn clear_old_send_keys() { + let mut a2b = Chain::new(b"1", Direction::A2B); + a2b.send_key(0).unwrap(); + a2b.send_key(0).unwrap(); + a2b.add_epoch(EpochSecret { + epoch: 1, + secret: vec![2], + }); + a2b.send_key(1).unwrap(); + assert!(matches!( + a2b.send_key(0).unwrap_err(), + Error::SendKeyEpochDecreased(1, 0) + )); + } } diff --git a/src/lib.rs b/src/lib.rs index 04a6b8c..10a503b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,6 +95,8 @@ pub enum Error { KeyAlreadyRequested(u32), #[error("Erroneous data received from remote party")] ErroneousDataReceived, + #[error("Send key epoch decreased ({0} -> {1})")] + SendKeyEpochDecreased(u64, u64), } impl From for Error { diff --git a/src/proto/pq_ratchet.proto b/src/proto/pq_ratchet.proto index 743d41d..770962f 100644 --- a/src/proto/pq_ratchet.proto +++ b/src/proto/pq_ratchet.proto @@ -205,4 +205,5 @@ message Chain { uint64 current_epoch = 2; repeated Epoch links = 3; bytes next_root = 4; + uint64 send_epoch = 5; }