Skip to content

Commit

Permalink
more efficient caching
Browse files Browse the repository at this point in the history
  • Loading branch information
Hannah Davis committed Sep 16, 2024
1 parent 130622a commit 69749ff
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/bt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ type SubTree<V> = Option<Box<Node<V>>>;

/// Represents a node of a binary tree.
pub struct Node<V> {
value: V,
pub(crate) value: V,
left: SubTree<V>,
right: SubTree<V>,
}
Expand Down
12 changes: 12 additions & 0 deletions src/idpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ impl IdpfInput {
}
}

/// Create a new empty IDPF input.
pub fn empty_input() -> IdpfInput {
IdpfInput::from_bools(&[])
}

/// Get the length of the input in bits.
pub fn len(&self) -> usize {
self.index.len()
Expand Down Expand Up @@ -111,6 +116,13 @@ impl IdpfInput {
}
}

/// Return the single bit of this IDPF input at the given level.
pub fn next_branch(&self, level: usize) -> Self {
Self {
index: self.index[level - 1..level].to_owned().into(),
}
}

/// Return the bit at the specified level if the level is in bounds.
pub fn get(&self, level: usize) -> Option<bool> {
self.index.get(level).as_deref().copied()
Expand Down
74 changes: 49 additions & 25 deletions src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::io::{Cursor, Read};
use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};

use crate::{
bt::{BinaryTree, BinaryTreeError},
bt::{BinaryTree},
codec::{CodecError, Decode, Encode, ParameterizedDecode},
field::FieldElement,
idpf::{
Expand Down Expand Up @@ -60,10 +60,6 @@ pub enum VidpfError {
/// Failure when calling getrandom().
#[error("getrandom: {0}")]
GetRandom(#[from] getrandom::Error),

/// Failure when caching VIDPF evaluation.
#[error("cache tree: {0}")]
BinaryTreeError(#[from] BinaryTreeError<V>),
}

/// Represents the domain of an incremental point function.
Expand All @@ -79,7 +75,7 @@ pub struct Vidpf<W: VidpfValue, const NONCE_SIZE: usize> {
pub(crate) weight_parameter: W::ValueParameter,
}

impl<'a, W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
/// Creates a VIDPF instance.
///
/// # Arguments
Expand Down Expand Up @@ -248,27 +244,49 @@ impl<'a, W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
return Err(VidpfError::InvalidAttributeLength);
}

let mut state = VidpfEvalState::init_from_key(key);
let state = VidpfEvalState::init_from_key(key);
let path = input;
match cache_tree.get(input.prefix(0).index.as_bitslice()) {
Some(_) => Ok(()),
None => cache_tree.insert(IdpfInput::from_bytes(&[]).index.as_bitslice(), self.eval_next_cached(key.id, public, input, 0, &state, nonce)?),
}?;
Some(_) => (),
None => cache_tree
.insert(
IdpfInput::empty_input().index.as_bitslice(),
self.eval_next_cached(key.id, public, input, 0, &state, nonce)?,
)
.expect("inserting into top of empty tree"),
};

let mut cache_node = cache_tree
.get_node(IdpfInput::empty_input().index.as_bitslice())
.expect("previous match statement ensures initialization");

for level in 1..n {
if let Some(next_cache) = cache_tree.get(input.prefix(level).index.as_bitslice()){
();
if cache_node
.get(path.next_branch(level).index.as_bitslice())
.is_some()
{
cache_node = cache_node
.get_node(path.next_branch(level).index.as_bitslice())
.expect("existence of node ensured by above condition");
} else {
let current_cache = cache_tree.get(input.prefix(level-1).index.as_bitslice()).ok_or(VidpfError::BinaryTreeError)?;
let (state, share) = self.eval_next(key.id, public, input, level, &current_cache.state, nonce)?;
let cache = cache_node
.get(IdpfInput::empty_input().index.as_bitslice())
.expect("current node initialized by previous loop iteration");
let (state, share) =
self.eval_next(key.id, public, input, level, &cache.state, nonce)?;
let next_cache = VidpfEvalCache::<W>::init_from_state(state, share);
cache_tree.insert(input.prefix(level).index.as_bitslice(), next_cache)?;
cache_node
.insert(input.next_branch(level).index.as_bitslice(), next_cache)
.expect("current node initialized by previous loop iteration");
cache_node = cache_node
.get_node(path.next_branch(level).index.as_bitslice())
.expect("node was inserted by previous statement");
}
}
let final_cache = cache_tree.get(input.index.as_bitslice()).ok_or(VidpfError::BinaryTreeError)?;
Ok(VidpfValueShare {
share: final_cache.share,
proof: final_cache.state.proof,
})
let final_cache = cache_node
.get(IdpfInput::empty_input().index.as_bitslice())
.expect("node was inserted by last loop iteration");
Ok(final_cache.to_share())
}

/// [`Vidpf::eval_next`] evaluates the `input` at the given level using the provided initial
Expand Down Expand Up @@ -321,7 +339,7 @@ impl<'a, W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
Ok((next_state, y))
}

/// [`Vidpf::eval_next_cached`] evaluates the `input` at the given level using the provided initial
/// [`Vidpf::eval_next_cached`] evaluates the `input` at the given level using the provided initial
/// state, and returns a cache containing a new state and a share of the input's weight at that level.
fn eval_next_cached(
&self,
Expand Down Expand Up @@ -610,6 +628,7 @@ impl<W: VidpfValue> ParameterizedDecode<(usize, W::ValueParameter)> for VidpfPub
/// Vidpf evaluation state
///
/// Contains the values produced during input evaluation at a given level.
#[derive(Debug)]
pub struct VidpfEvalState {
seed: VidpfSeed,
control_bit: Choice,
Expand All @@ -629,16 +648,21 @@ impl VidpfEvalState {
/// Vidpf evaluation cache
///
/// Contains the values produced during input evaluation at a given level.
#[derive(Debug)]
pub struct VidpfEvalCache<W: VidpfValue> {
state: VidpfEvalState,
share: W,
}

impl<W: VidpfValue> VidpfEvalCache<W> {
fn init_from_state(state: VidpfEvalState, share: W) -> Self {
Self {
state,
share,
Self { state, share }
}

fn to_share(&self) -> VidpfValueShare<W> {
VidpfValueShare::<W> {
share: self.share.clone(),
proof: self.state.proof,
}
}
}
Expand Down Expand Up @@ -864,7 +888,7 @@ mod tests {
.gen_with_keys(&keys_with_same_id, &input, &weight, TEST_NONCE)
.unwrap_err();

assert_eq!(err.to_string(), VidpfError::TestWeight::SameKeyId.to_string());
assert_eq!(err.to_string(), VidpfError::SameKeyId.to_string());
}

#[test]
Expand Down

0 comments on commit 69749ff

Please sign in to comment.