diff --git a/Cargo.lock b/Cargo.lock index 7e2d786..86da17c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,7 +1077,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.6.4" +version = "1.0.0" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index b365167..a79aea9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.6.4" +version = "1.0.0" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/benches/benchmark.rs b/benches/benchmark.rs index 0bc0802..add9db8 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -35,7 +35,7 @@ pub(crate) fn next_states(inf: &FSRS) -> NextStates { .unwrap() } -pub(crate) fn optimal_retention(inf: &FSRS, config: &SimulatorConfig) -> f64 { +pub(crate) fn optimal_retention(inf: &FSRS, config: &SimulatorConfig) -> f32 { inf.optimal_retention(config, &[], |_v| true).unwrap() } @@ -63,7 +63,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { let config = SimulatorConfig { deck_size: 3650, learn_span: 365, - max_cost_perday: f64::INFINITY, + max_cost_perday: f32::INFINITY, learn_limit: 10, loss_aversion: 1.0, ..Default::default() diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index 5f6cc0a..389ae68 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -280,164 +280,204 @@ mod tests { [ FSRSItem { reviews: vec![ + FSRSReview { + rating: 1, + delta_t: 0 + }, FSRSReview { rating: 4, - delta_t: 0, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 1, + rating: 4, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 3, + rating: 4, + delta_t: 1 } ] }, FSRSItem { reviews: vec![ FSRSReview { - rating: 1, - delta_t: 0, + rating: 4, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 1, + rating: 4, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 1 }, FSRSReview { rating: 3, - delta_t: 3, + delta_t: 2 } - ], + ] }, FSRSItem { reviews: vec![ FSRSReview { - rating: 1, - delta_t: 0, + rating: 4, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 1, + rating: 4, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 1 }, FSRSReview { rating: 3, - delta_t: 3, + delta_t: 1 } - ], + ] }, FSRSItem { reviews: vec![ FSRSReview { rating: 1, - delta_t: 0, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 1, + rating: 4, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 3, + rating: 4, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 1 } - ], + ] }, FSRSItem { reviews: vec![ FSRSReview { - rating: 1, - delta_t: 0, + rating: 4, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 1, + rating: 4, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 1 }, FSRSReview { rating: 3, - delta_t: 3, + delta_t: 1 } - ], + ] }, FSRSItem { reviews: vec![ + FSRSReview { + rating: 1, + delta_t: 0 + }, FSRSReview { rating: 4, - delta_t: 0, + delta_t: 0 }, FSRSReview { rating: 4, - delta_t: 3, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 10, + rating: 4, + delta_t: 1 } - ], + ] }, FSRSItem { reviews: vec![ FSRSReview { rating: 4, - delta_t: 0, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 0 }, FSRSReview { rating: 4, - delta_t: 1, + delta_t: 1 }, FSRSReview { rating: 3, - delta_t: 4, + delta_t: 3 } - ], + ] }, FSRSItem { reviews: vec![ FSRSReview { rating: 4, - delta_t: 0, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 0 }, FSRSReview { rating: 4, - delta_t: 1, + delta_t: 1 }, FSRSReview { rating: 3, - delta_t: 4, + delta_t: 1 } - ], + ] }, FSRSItem { reviews: vec![ FSRSReview { rating: 
4, - delta_t: 0, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 5, + rating: 4, + delta_t: 0 + }, + FSRSReview { + rating: 4, + delta_t: 1 }, FSRSReview { rating: 3, - delta_t: 11, + delta_t: 2 } - ], + ] }, FSRSItem { reviews: vec![ + FSRSReview { + rating: 3, + delta_t: 0 + }, FSRSReview { rating: 4, - delta_t: 0, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 1, + rating: 4, + delta_t: 0 }, FSRSReview { - rating: 3, - delta_t: 3, + rating: 4, + delta_t: 1 } - ], - }, + ] + } ] ); } diff --git a/src/convertor_tests.rs b/src/convertor_tests.rs index a426af1..2c732d9 100644 --- a/src/convertor_tests.rs +++ b/src/convertor_tests.rs @@ -1,8 +1,7 @@ -use std::collections::HashMap; - use crate::convertor_tests::RevlogReviewKind::*; use crate::dataset::FSRSBatcher; use crate::dataset::{FSRSItem, FSRSReview}; +use crate::optimal_retention::{RevlogEntry, RevlogReviewKind}; use crate::test_helpers::NdArrayAutodiff; use burn::backend::ndarray::NdArrayDevice; use burn::data::dataloader::batcher::Batcher; @@ -16,39 +15,6 @@ use rusqlite::Connection; use rusqlite::{Result, Row}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Copy, Debug, Default, PartialEq)] -pub struct RevlogEntry { - pub id: i64, - pub cid: i64, - pub usn: i32, - /// - In the V1 scheduler, 3 represents easy in the learning case. - /// - 0 represents manual rescheduling. - pub button_chosen: u8, - /// Positive values are in days, negative values in seconds. - pub interval: i32, - /// Positive values are in days, negative values in seconds. - pub last_interval: i32, - /// Card's ease after answering, stored as 10x the %, eg 2500 represents - /// 250%. - pub ease_factor: u32, - /// Amount of milliseconds taken to answer the card. - pub taken_millis: u32, - pub review_kind: RevlogReviewKind, -} - -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] -pub enum RevlogReviewKind { - #[default] - Learning = 0, - Review = 1, - Relearning = 2, - /// Old Anki versions called this "Cram" or "Early", and assigned it when - /// reviewing cards ahead. It is now only used for filtered decks with - /// rescheduling disabled. 
- Filtered = 3, - Manual = 4, -} - impl rusqlite::types::FromSql for RevlogReviewKind { fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult { let rusqlite::types::ValueRef::Integer(i) = value else { @@ -82,280 +48,6 @@ impl TryFrom<&Row<'_>> for RevlogEntry { } } -#[derive(Debug, PartialEq)] -struct SimulatorConfig { - learn_costs: Vec, - review_costs: Vec, - learn_buttons: Vec, - review_buttons: Vec, - first_rating_prob: Vec, - review_rating_prob: Vec, - first_rating_offset: Vec, - first_session_len: Vec, - forget_rating_offset: f32, - forget_session_len: f32, -} - -fn extract_simulation_config(df: Vec, day_cutoff: i64) -> SimulatorConfig { - /* - def rating_counts(x): - tmp = defaultdict(int, x.value_counts().to_dict()) - first = x.iloc[0] - tmp[first] -= 1 - return tmp - */ - fn rating_counts(entries: &[RevlogEntry]) -> [u32; 4] { - let mut counts = [0; 4]; - - for entry in entries.iter().skip(1) { - counts[entry.button_chosen as usize - 1] += 1; - } - - counts - } - /* - df1 = ( - df[(df["review_duration"] > 0) & (df["review_duration"] < 1200000)] - .groupby(by=["card_id", "real_days"]) - .agg( - { - "review_state": "first", - "review_rating": ["first", rating_counts], - "review_duration": "sum", - } - ) - .reset_index() - ) - */ - struct Df1Row { - card_id: i64, - first_review_state: u8, - first_review_rating: u8, - review_rating_counts: [u32; 4], - sum_review_duration: u32, - } - let df1 = { - let mut grouped_data = HashMap::new(); - for &row in df.iter() { - if row.taken_millis > 0 && row.taken_millis < 1200000 { - let real_days = (row.id / 1000 - day_cutoff) / 86400; - let key = (row.cid.clone(), real_days); - grouped_data.entry(key).or_insert_with(Vec::new).push(row); - } - } - - grouped_data - .into_iter() - .filter_map(|((card_id, _real_days), entries)| { - entries.first().map(|first_entry| { - let first_review_state = first_entry.review_kind as u8 + 1; - let first_review_rating = first_entry.button_chosen; - let review_rating_counts = rating_counts(&entries); - let sum_review_duration = - entries.iter().map(|entry| entry.taken_millis).sum::(); - - Df1Row { - card_id, - first_review_state, - first_review_rating, - review_rating_counts, - sum_review_duration, - } - }) - }) - .collect_vec() - }; - - let cost_dict = { - let mut cost_dict = HashMap::new(); - for row in df1.iter() { - cost_dict - .entry((row.first_review_state, row.first_review_rating)) - .or_insert_with(Vec::new) - .push(row.sum_review_duration); - } - // calculate the median of the sum_review_duration - fn median(x: &mut [u32]) -> u32 { - x.sort_unstable(); - let n = x.len(); - if n % 2 == 0 { - (x[n / 2 - 1] + x[n / 2]) / 2 - } else { - x[n / 2] - } - } - cost_dict - .into_iter() - .map(|(k, mut v)| (k, median(&mut v))) - .collect::>() - }; - - // [cost_dict[(1, i)] / 1000 for i in range(1, 5)] - let learn_costs = (1..5) - .map(|i| cost_dict.get(&(1, i)).map(|x| *x).unwrap_or_default() as f32 / 1000f32) - .collect_vec(); - // [cost_dict[(2, i)] / 1000 for i in range(1, 5)] - let review_costs = (1..5) - .map(|i| cost_dict.get(&(2, i)).map(|x| *x).unwrap_or_default() as f32 / 1000f32) - .collect_vec(); - /* - button_usage_dict = ( - df1.groupby(by=["first_review_state", "first_review_rating"])["card_id"] - .count() - .to_dict() - ) */ - let button_usage_dict = { - let mut button_usage_dict = HashMap::new(); - for row in df1.iter() { - button_usage_dict - .entry((row.first_review_state, row.first_review_rating)) - .or_insert_with(Vec::new) - .push(row.card_id); // is this correct? 
- } - button_usage_dict - .into_iter() - .map(|(x, y)| (x, y.len() as i64)) - .collect::>() - }; - // [button_usage_dict.get((1, i), 0) for i in range(1, 5)] - let learn_buttons = (1..5) - .map(|i| { - button_usage_dict - .get(&(1, i)) - .map(|x| *x) - .unwrap_or_default() - }) - .collect_vec(); - // [button_usage_dict.get((2, i), 0) for i in range(1, 5)] - let review_buttons = (1..5) - .map(|i| { - button_usage_dict - .get(&(2, i)) - .map(|x| *x) - .unwrap_or_default() - }) - .collect_vec(); - - // self.first_rating_prob = self.learn_buttons / self.learn_buttons.sum() - let first_rating_prob = learn_buttons - .iter() - .map(|x| *x as f32 / learn_buttons.iter().sum::() as f32) - .collect_vec(); - // self.review_buttons[1:] / self.review_buttons[1:].sum() - let review_rating_prob = review_buttons - .iter() - .skip(1) - .map(|x| *x as f32 / review_buttons.iter().skip(1).sum::() as f32) - .collect_vec(); - - // df2 = ( - // df1.groupby(by=["first_review_state", "first_review_rating"])[[1, 2, 3, 4]] - // .mean() - // .round(2) - // ) - - let df2 = { - let mut grouped = HashMap::new(); - for review in df1 { - grouped - .entry((review.first_review_state, review.first_review_rating)) - .or_insert_with(Vec::new) - .push(review); - } - grouped - .iter() - .map(|((state, rating), group)| { - let count = group.len() as f32; - let (sum1, sum2, sum3, sum4) = - group - .iter() - .fold((0, 0, 0, 0), |(sum1, sum2, sum3, sum4), review| { - ( - sum1 + review.review_rating_counts[0], - sum2 + review.review_rating_counts[1], - sum3 + review.review_rating_counts[2], - sum4 + review.review_rating_counts[3], - ) - }); - - let averages = [ - (sum1 as f32 / count * 100.0).round() / 100.0, - (sum2 as f32 / count * 100.0).round() / 100.0, - (sum3 as f32 / count * 100.0).round() / 100.0, - (sum4 as f32 / count * 100.0).round() / 100.0, - ]; - - ((*state, *rating), averages) - }) - .collect::>() - }; - // rating_offset_dict = sum([df2[g] * (g - 3) for g in range(1, 5)]).to_dict() - let rating_offset_dict = { - let mut rating_offset_dict = HashMap::new(); - for (k, averages) in df2.iter() { - let offset = averages - .iter() - .enumerate() - .map(|(i, &v)| ((i + 1) as f32 - 3.0) * v) - .sum::(); - rating_offset_dict.insert(k, (offset * 100.0).round() / 100.0); - } - rating_offset_dict - }; - // session_len_dict = sum([df2[g] for g in range(1, 5)]).to_dict() - let session_len_dict = { - let mut session_len_dict = HashMap::new(); - for (k, averages) in df2.iter() { - let sum = averages.iter().sum::(); - session_len_dict.insert(k, (sum * 100.0).round() / 100.0); - } - session_len_dict - }; - // [rating_offset_dict[(1, i)] for i in range(1, 5)] - let first_rating_offset = (1..5) - .map(|i| { - rating_offset_dict - .get(&(1, i)) - .map(|x| *x) - .unwrap_or_default() - }) - .collect_vec(); - - // [session_len_dict[(1, i)] for i in range(1, 5)] - let first_session_len = (1..5) - .map(|i| { - session_len_dict - .get(&(1, i)) - .map(|x| *x) - .unwrap_or_default() - }) - .collect_vec(); - - // rating_offset_dict[(2, 1)] - let forget_rating_offset = rating_offset_dict - .get(&(2, 1)) - .map(|x| *x) - .unwrap_or_default(); - // session_len_dict[(2, 1)] - let forget_session_len = session_len_dict - .get(&(2, 1)) - .map(|x| *x) - .unwrap_or_default(); - - SimulatorConfig { - learn_costs, - review_costs, - learn_buttons, - review_buttons, - first_rating_prob, - review_rating_prob, - first_rating_offset, - first_session_len, - forget_rating_offset, - forget_session_len, - } -} - fn filter_out_cram(entries: Vec) -> Vec { entries 
.into_iter() @@ -395,19 +87,6 @@ fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> Nai datetime.date_naive() } -fn keep_first_revlog_same_date( - mut entries: Vec, - next_day_starts_at: i64, - timezone: Tz, -) -> Vec { - let mut unique_dates = std::collections::HashSet::new(); - entries.retain(|entry| { - let date = convert_to_date(entry.id, next_day_starts_at, timezone); - unique_dates.insert(date) - }); - entries -} - /// Given a list of revlog entries for a single card with length n, we create /// n-1 FSRS items, where each item contains the history of the preceding reviews. @@ -419,7 +98,6 @@ fn convert_to_fsrs_items( // entries = filter_out_cram(entries); // entries = filter_out_manual(entries); entries = remove_revlog_before_last_first_learn(entries); - entries = keep_first_revlog_same_date(entries, next_day_starts_at, timezone); for i in 1..entries.len() { let date_current = convert_to_date(entries[i].id, next_day_starts_at, timezone); @@ -443,6 +121,7 @@ fn convert_to_fsrs_items( .collect(); FSRSItem { reviews } }) + .filter(|item| item.current().delta_t > 0) .collect(), ) } @@ -505,7 +184,7 @@ pub(crate) fn anki21_sample_file_converted_to_fsrs() -> Vec { anki_to_fsrs(read_collection().expect("read error")) } -fn read_collection() -> Result> { +pub(crate) fn read_collection() -> Result> { let db = Connection::open("tests/data/collection.anki21")?; let filter_out_suspended_cards = false; let filter_out_flags = []; @@ -555,29 +234,6 @@ fn read_collection() -> Result> { Ok(revlogs) } -#[test] -fn extract_simulator_config_from_revlog() { - let mut revlogs = read_collection().unwrap(); - revlogs.sort_by_cached_key(|r| (r.cid, r.id)); - let day_cutoff = 1720900800; - let simulator_config = extract_simulation_config(revlogs, day_cutoff); - assert_eq!( - simulator_config, - SimulatorConfig { - learn_costs: vec![30.061, 0., 17.298, 12.352], - review_costs: vec![19.139, 6.887, 5.83, 4.002], - learn_buttons: vec![690, 0, 512, 2364], - review_buttons: vec![788, 960, 11767, 331], - first_rating_prob: vec![0.19349411, 0., 0.14357824, 0.66292765], - review_rating_prob: vec![0.07351815, 0.9011334, 0.025348445], - first_rating_offset: vec![1.64, 0., 0.69, 1.11], - first_session_len: vec![2.74, 0., 1.32, 1.19], - forget_rating_offset: 1.28, - forget_session_len: 1.77 - } - ) -} - // This test currently expects the following .anki21 file to be placed in tests/data/: // https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip #[test] @@ -592,7 +248,10 @@ fn conversion_works() { let fsrs_items = anki_to_fsrs(revlogs); assert_eq!(fsrs_items.len(), 14290); assert_eq!( - fsrs_items.iter().map(|x| x.reviews.len()).sum::(), + fsrs_items + .iter() + .map(|x| x.long_term_review_cnt() + 1) + .sum::(), 49382 + 14290 ); @@ -611,11 +270,15 @@ fn conversion_works() { rating: 3, delta_t: 0 }, + FSRSReview { + rating: 4, + delta_t: 0 + }, FSRSReview { rating: 3, delta_t: 5 } - ], + ] }, FSRSItem { reviews: vec![ @@ -623,6 +286,10 @@ fn conversion_works() { rating: 3, delta_t: 0 }, + FSRSReview { + rating: 4, + delta_t: 0 + }, FSRSReview { rating: 3, delta_t: 5 @@ -631,7 +298,7 @@ fn conversion_works() { rating: 3, delta_t: 10 } - ], + ] }, FSRSItem { reviews: vec![ @@ -639,6 +306,10 @@ fn conversion_works() { rating: 3, delta_t: 0 }, + FSRSReview { + rating: 4, + delta_t: 0 + }, FSRSReview { rating: 3, delta_t: 5 @@ -651,7 +322,7 @@ fn conversion_works() { rating: 3, delta_t: 22 } - ], + ] }, FSRSItem { reviews: vec![ @@ -659,6 +330,10 @@ fn 
conversion_works() { rating: 3, delta_t: 0 }, + FSRSReview { + rating: 4, + delta_t: 0 + }, FSRSReview { rating: 3, delta_t: 5 @@ -675,7 +350,7 @@ fn conversion_works() { rating: 2, delta_t: 56 } - ], + ] }, FSRSItem { reviews: vec![ @@ -683,6 +358,10 @@ fn conversion_works() { rating: 3, delta_t: 0 }, + FSRSReview { + rating: 4, + delta_t: 0 + }, FSRSReview { rating: 3, delta_t: 5 @@ -703,7 +382,7 @@ fn conversion_works() { rating: 3, delta_t: 64 } - ], + ] } ] ); @@ -714,11 +393,11 @@ fn conversion_works() { assert_eq!(res.delta_ts.into_scalar(), 64.0); assert_eq!( res.r_historys.squeeze(1).to_data(), - Data::from([3.0, 3.0, 3.0, 3.0, 2.0]) + Data::from([3.0, 4.0, 3.0, 3.0, 3.0, 2.0]) ); assert_eq!( res.t_historys.squeeze(1).to_data(), - Data::from([0.0, 5.0, 10.0, 22.0, 56.0]) + Data::from([0.0, 0.0, 5.0, 10.0, 22.0, 56.0]) ); assert_eq!(res.labels.to_data(), Data::from([1])); } @@ -1198,92 +877,3 @@ fn test_remove_revlog_before_last_first_learn() { ] ); } - -#[test] -fn test_keep_first_revlog_same_date() { - let revlog_vec = vec![ - RevlogEntry { - id: 1581372095493, - cid: 1559076329057, - usn: 5212, - button_chosen: 1, - interval: -60, - last_interval: -60, - ease_factor: 0, - taken_millis: 60000, - review_kind: Learning, - }, - RevlogEntry { - id: 1581372260598, - cid: 1559076329057, - usn: 5212, - button_chosen: 3, - interval: -600, - last_interval: -60, - ease_factor: 0, - taken_millis: 46425, - review_kind: Learning, - }, - RevlogEntry { - id: 1581406251414, - cid: 1559076329057, - usn: 5213, - button_chosen: 2, - interval: -600, - last_interval: -600, - ease_factor: 0, - taken_millis: 17110, - review_kind: Learning, - }, - RevlogEntry { - id: 1581407568344, - cid: 1559076329057, - usn: 5213, - button_chosen: 3, - interval: 1, - last_interval: -600, - ease_factor: 2500, - taken_millis: 8861, - review_kind: Learning, - }, - RevlogEntry { - id: 1581454426227, - cid: 1559076329057, - usn: 5215, - button_chosen: 3, - interval: 3, - last_interval: 1, - ease_factor: 2500, - taken_millis: 13128, - review_kind: Review, - }, - ]; - let revlog_vec = keep_first_revlog_same_date(revlog_vec, 4, Tz::Asia__Shanghai); - assert_eq!( - revlog_vec, - vec![ - RevlogEntry { - id: 1581372095493, - cid: 1559076329057, - usn: 5212, - button_chosen: 1, - interval: -60, - last_interval: -60, - ease_factor: 0, - taken_millis: 60000, - review_kind: Learning, - }, - RevlogEntry { - id: 1581454426227, - cid: 1559076329057, - usn: 5215, - button_chosen: 3, - interval: 3, - last_interval: 1, - ease_factor: 2500, - taken_millis: 13128, - review_kind: Review, - }, - ] - ) -} diff --git a/src/dataset.rs b/src/dataset.rs index 1208ef9..676d432 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -36,6 +36,21 @@ impl FSRSItem { self.reviews.last().unwrap() } + pub(crate) fn long_term_review_cnt(&self) -> usize { + self.reviews + .iter() + .filter(|review| review.delta_t > 0) + .count() + } + + pub(crate) fn first_long_term_review(&self) -> FSRSReview { + self.reviews + .iter() + .find(|review| review.delta_t > 0) + .unwrap() + .clone() + } + pub(crate) fn r_matrix_index(&self) -> (u32, u32, u32) { let delta_t = self.current().delta_t as f64; let delta_t_bin = (2.48 * 3.62f64.powf(delta_t.log(3.62).floor()) * 100.0).round() as u32; @@ -205,18 +220,20 @@ pub fn filter_outlier( } // keep the items in trainset if they are not removed from filtered_items trainset.retain(|item| { - !removed_pairs[item.reviews[0].rating as usize].contains(&item.reviews[1].delta_t) + !removed_pairs[item.reviews[0].rating as usize] + 
.contains(&item.first_long_term_review().delta_t) }); (filtered_items, trainset) } -pub fn split_filter_data(items: Vec) -> (Vec, Vec) { - let (mut pretrainset, mut trainset) = - items.into_iter().partition(|item| item.reviews.len() == 2); +pub fn prepare_training_data(items: Vec) -> (Vec, Vec) { + let (mut pretrainset, mut trainset) = items + .into_iter() + .partition(|item| item.long_term_review_cnt() == 1); if std::env::var("FSRS_NO_OUTLIER").is_err() { (pretrainset, trainset) = filter_outlier(pretrainset, trainset); } - (pretrainset, trainset) + (pretrainset.clone(), [pretrainset, trainset].concat()) } #[cfg(test)] @@ -234,12 +251,12 @@ mod tests { FSRSItem { reviews: vec![ FSRSReview { - rating: 1, + rating: 3, delta_t: 0, }, FSRSReview { - rating: 4, - delta_t: 2, + rating: 3, + delta_t: 1, }, ], } @@ -433,7 +450,7 @@ mod tests { let dataset = anki21_sample_file_converted_to_fsrs(); let (mut pretrainset, mut trainset): (Vec, Vec) = dataset .into_iter() - .partition(|item| item.reviews.len() == 2); + .partition(|item| item.long_term_review_cnt() == 1); assert_eq!(pretrainset.len(), 3315); assert_eq!(trainset.len(), 10975); (pretrainset, trainset) = filter_outlier(pretrainset, trainset); diff --git a/src/inference.rs b/src/inference.rs index 66403e2..c646139 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -21,9 +21,9 @@ pub(crate) const S_MIN: f32 = 0.01; pub type Parameters = [f32]; use itertools::izip; -pub static DEFAULT_PARAMETERS: [f32; 17] = [ - 0.4872, 1.4003, 3.7145, 13.8206, 5.1618, 1.2298, 0.8975, 0.031, 1.6474, 0.1367, 1.0461, 2.1072, - 0.0793, 0.3246, 1.587, 0.2272, 2.8755, +pub static DEFAULT_PARAMETERS: [f32; 19] = [ + 0.4197, 1.1869, 3.0412, 15.2441, 7.1434, 0.6477, 1.0007, 0.0674, 1.6597, 0.1712, 1.1178, + 2.0225, 0.0904, 0.3025, 2.1214, 0.2498, 2.9466, 0.4891, 0.6468, ]; fn infer( @@ -172,25 +172,20 @@ impl FSRS { let current_memory_state_tensors = current_memory_state.map(MemoryStateTensors::from); let model = self.model(); let mut next_memory_states = (1..=4).map(|rating| { - Ok( - if let (Some(current_memory_state), 0) = (current_memory_state, days_elapsed) { - // When there's an existing memory state and no days have elapsed, we leave it unchanged. 
- current_memory_state - } else { - let state = MemoryState::from(model.step( - delta_t.clone(), - Tensor::from_data( - Data::new(vec![rating.elem()], Shape { dims: [1] }), - &self.device(), - ), - current_memory_state_tensors.clone(), - )); - if !state.stability.is_finite() || !state.difficulty.is_finite() { - return Err(FSRSError::InvalidInput); - } - state - }, - ) + Ok({ + let state = MemoryState::from(model.step( + delta_t.clone(), + Tensor::from_data( + Data::new(vec![rating.elem()], Shape { dims: [1] }), + &self.device(), + ), + current_memory_state_tensors.clone(), + )); + if !state.stability.is_finite() || !state.difficulty.is_finite() { + return Err(FSRSError::InvalidInput); + } + state + }) }); let mut get_next_state = || { @@ -382,8 +377,9 @@ mod tests { }; static PARAMETERS: &[f32] = &[ - 1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321, - 2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244, + 0.72466177, 1.6790825, 4.562257, 10.0608635, 7.7002444, 0.912309, 1.0909119, 0.03472257, + 1.4395499, 0.1712, 0.8977034, 2.1090207, 0.0904, 0.3025, 2.3506782, 0.23486121, 3.1349943, + 0.18253358, 0.13885707, ]; #[test] @@ -431,8 +427,8 @@ mod tests { assert_eq!( fsrs.memory_state(item, None).unwrap(), MemoryState { - stability: 43.05542, - difficulty: 7.7609 + stability: 29.448196, + difficulty: 7.7002444 } ); @@ -449,8 +445,8 @@ mod tests { .good .memory, MemoryState { - stability: 51.441338, - difficulty: 7.005062 + stability: 40.669125, + difficulty: 7.0292006 } ); Ok(()) @@ -469,8 +465,9 @@ mod tests { #[test] fn test_evaluate() -> Result<()> { let items = anki21_sample_file_converted_to_fsrs(); - let (mut pretrainset, mut trainset): (Vec, Vec) = - items.into_iter().partition(|item| item.reviews.len() == 2); + let (mut pretrainset, mut trainset): (Vec, Vec) = items + .into_iter() + .partition(|item| item.long_term_review_cnt() == 1); (pretrainset, trainset) = filter_outlier(pretrainset, trainset); let items = [pretrainset, trainset].concat(); let fsrs = FSRS::new(Some(&[]))?; @@ -478,20 +475,20 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.204_330, 0.031_510]), 5); + .assert_approx_eq(&Data::from([0.216539, 0.045964]), 5); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.202_188, 0.021_781]), 5); + .assert_approx_eq(&Data::from([0.202754, 0.032861]), 5); let (self_by_other, other_by_self) = fsrs .universal_metrics(items, &DEFAULT_PARAMETERS, |_| true) .unwrap(); Data::from([self_by_other, other_by_self]) - .assert_approx_eq(&Data::from([0.013_520, 0.019_003]), 5); + .assert_approx_eq(&Data::from([0.015230, 0.032233]), 5); Ok(()) } @@ -524,31 +521,31 @@ mod tests { NextStates { again: ItemState { memory: MemoryState { - stability: 3.9653313, - difficulty: 9.7949 + stability: 3.0271175, + difficulty: 9.80631 }, - interval: 4 + interval: 3 }, hard: ItemState { memory: MemoryState { - stability: 22.415548, - difficulty: 8.7779 + stability: 16.725859, + difficulty: 8.753278 }, - interval: 22 + interval: 17 }, good: ItemState { memory: MemoryState { - stability: 43.05542, - difficulty: 7.7609 + stability: 29.448196, + difficulty: 7.7002444 }, - interval: 43 + interval: 29 }, easy: ItemState { memory: MemoryState { - stability: 90.86977, - difficulty: 6.7439003 + stability: 64.947784, + difficulty: 6.647212 }, 
- interval: 91 + interval: 65 } } ); @@ -556,21 +553,6 @@ mod tests { Ok(()) } - #[test] - fn states_are_unchaged_when_no_days_elapsed() -> Result<()> { - let fsrs = FSRS::new(Some(&[]))?; - // the first time a card is seen, a memory state must be set - let mut state_a = fsrs.next_states(None, 1.0, 0)?.again.memory; - // but if no days have elapsed and it's reviewed again, the state should be unchanged - let state_b = fsrs.next_states(Some(state_a), 1.0, 0)?.again.memory; - assert_eq!(state_a, state_b); - // if a day elapses, it's counted - state_a = fsrs.next_states(Some(state_a), 1.0, 1)?.again.memory; - assert_ne!(state_a, state_b); - - Ok(()) - } - #[test] fn current_retrievability() { let fsrs = FSRS::new(None).unwrap(); @@ -589,13 +571,13 @@ mod tests { let fsrs = FSRS::new(Some(&[]))?; let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(); Data::from([memory_state.stability, memory_state.difficulty]) - .assert_approx_eq(&Data::from([9.999996, 7.4120417]), 5); + .assert_approx_eq(&Data::from([9.999996, 7.422087]), 5); let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.8).unwrap(); Data::from([memory_state.stability, memory_state.difficulty]) - .assert_approx_eq(&Data::from([4.170096, 9.491373]), 5); + .assert_approx_eq(&Data::from([4.170096, 9.545_82]), 5); let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.95).unwrap(); Data::from([memory_state.stability, memory_state.difficulty]) - .assert_approx_eq(&Data::from([21.712555, 2.80758]), 5); + .assert_approx_eq(&Data::from([21.712555, 2.593589]), 5); let memory_state = fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap(); Data::from([memory_state.stability, memory_state.difficulty]) .assert_approx_eq(&Data::from([19.999992, 10.0]), 5); diff --git a/src/lib.rs b/src/lib.rs index 13a7913..9dc7177 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,5 +21,7 @@ pub use inference::{ ItemProgress, ItemState, MemoryState, ModelEvaluation, NextStates, DEFAULT_PARAMETERS, }; pub use model::FSRS; -pub use optimal_retention::{simulate, Card, SimulatorConfig}; +pub use optimal_retention::{ + extract_simulation_config, simulate, Card, RevlogEntry, RevlogReviewKind, SimulatorConfig, +}; pub use training::CombinedProgressState; diff --git a/src/model.rs b/src/model.rs index e82e80a..0339c85 100644 --- a/src/model.rs +++ b/src/model.rs @@ -38,7 +38,7 @@ impl Model { Self { w: Param::from_tensor(Tensor::from_floats( - Data::new(initial_params, Shape { dims: [17] }), + Data::new(initial_params, Shape { dims: [19] }), &B::Device::default(), )), config, @@ -87,6 +87,10 @@ impl Model { .mask_where(last_s.clone().lower(new_s), last_s) } + fn stability_short_term(&self, last_s: Tensor, rating: Tensor) -> Tensor { + last_s * (self.w.get(17) * (rating - 3 + self.w.get(18))).exp() + } + fn mean_reversion(&self, new_d: Tensor) -> Tensor { self.w.get(7) * (self.w.get(4) - new_d.clone()) + new_d } @@ -96,7 +100,7 @@ impl Model { } fn init_difficulty(&self, rating: Tensor) -> Tensor { - self.w.get(4) - self.w.get(5) * (rating - 3) + self.w.get(4) - (self.w.get(5) * (rating - 1)).exp() + 1 } fn next_difficulty(&self, difficulty: Tensor, rating: Tensor) -> Tensor { @@ -110,7 +114,7 @@ impl Model { state: Option>, ) -> MemoryStateTensors { let (new_s, new_d) = if let Some(state) = state { - let retention = self.power_forgetting_curve(delta_t, state.stability.clone()); + let retention = self.power_forgetting_curve(delta_t.clone(), state.stability.clone()); let stability_after_success = self.stability_after_success( state.stability.clone(), 
state.difficulty.clone(), @@ -122,8 +126,11 @@ impl Model { state.difficulty.clone(), retention, ); + let stability_short_term = + self.stability_short_term(state.stability.clone(), rating.clone()); let mut new_stability = stability_after_success .mask_where(rating.clone().equal_elem(1), stability_after_failure); + new_stability = new_stability.mask_where(delta_t.equal_elem(0), stability_short_term); let mut new_difficulty = self.next_difficulty(state.difficulty.clone(), rating.clone()); new_difficulty = self.mean_reversion(new_difficulty).clamp(1.0, 10.0); @@ -207,7 +214,7 @@ impl FSRS { if let Some(parameters) = &mut parameters { if parameters.is_empty() { *parameters = DEFAULT_PARAMETERS.as_slice() - } else if parameters.len() != 17 { + } else if parameters.len() != 19 && parameters.len() != 17 { return Err(FSRSError::InvalidParameters); } } @@ -231,8 +238,15 @@ impl FSRS { pub(crate) fn parameters_to_model(parameters: &Parameters) -> Model { let config = ModelConfig::default(); let mut model = Model::new(config); + let new_params = if parameters.len() == 17 { + let mut new_params = parameters.to_vec(); + new_params.extend_from_slice(&[0.0, 0.0]); + new_params + } else { + parameters.to_vec() + }; model.w = Param::from_tensor(Tensor::from_floats( - Data::new(clip_parameters(parameters), Shape { dims: [17] }), + Data::new(clip_parameters(&new_params), Shape { dims: [19] }), &B::Device::default(), )); model @@ -291,12 +305,12 @@ mod tests { assert_eq!( difficulty.to_data(), Data::from([ - DEFAULT_PARAMETERS[4] + 2.0 * DEFAULT_PARAMETERS[5], - DEFAULT_PARAMETERS[4] + DEFAULT_PARAMETERS[5], DEFAULT_PARAMETERS[4], - DEFAULT_PARAMETERS[4] - DEFAULT_PARAMETERS[5], - DEFAULT_PARAMETERS[4] + 2.0 * DEFAULT_PARAMETERS[5], - DEFAULT_PARAMETERS[4] + DEFAULT_PARAMETERS[5] + DEFAULT_PARAMETERS[4] - DEFAULT_PARAMETERS[5].exp() + 1.0, + DEFAULT_PARAMETERS[4] - (2.0 * DEFAULT_PARAMETERS[5]).exp() + 1.0, + DEFAULT_PARAMETERS[4] - (3.0 * DEFAULT_PARAMETERS[5]).exp() + 1.0, + DEFAULT_PARAMETERS[4], + DEFAULT_PARAMETERS[4] - DEFAULT_PARAMETERS[5].exp() + 1.0, ]) ) } @@ -344,7 +358,7 @@ mod tests { next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.744371, 5.8746934, 5.005016, 4.1353383]) + Data::from([7.0109706, 6.077718, 5.144465, 4.211212]) ) } @@ -365,19 +379,24 @@ mod tests { s_recall.clone().backward(); assert_eq!( s_recall.to_data(), - Data::from([27.980768, 14.916422, 66.45966, 222.94603]) + Data::from([28.603035, 16.240442, 68.610886, 237.08693]) ); - let s_forget = model.stability_after_failure(stability, difficulty, retention); + let s_forget = model.stability_after_failure(stability.clone(), difficulty, retention); s_forget.clone().backward(); assert_eq!( s_forget.to_data(), - Data::from([1.9482934, 2.161251, 2.4528089, 2.8098207]) + Data::from([1.7989675, 2.089014, 2.4897401, 2.9990985]) ); let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget); next_stability.clone().backward(); assert_eq!( next_stability.to_data(), - Data::from([1.9482934, 14.916422, 66.45966, 222.94603]) + Data::from([1.7989675, 16.240442, 68.610886, 237.08693]) + ); + let next_stability = model.stability_short_term(stability, rating); + assert_eq!( + next_stability.to_data(), + Data::from([2.5794802, 4.206739, 6.8605514, 11.188516]) ) } @@ -386,5 +405,6 @@ mod tests { assert!(FSRS::new(Some(&[])).is_ok()); assert!(FSRS::new(Some(&[1.])).is_err()); assert!(FSRS::new(Some(DEFAULT_PARAMETERS.as_slice())).is_ok()); + 
assert!(FSRS::new(Some(&DEFAULT_PARAMETERS[..17])).is_ok()); } } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 70d67bf..f2e579f 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -2,7 +2,7 @@ use crate::error::{FSRSError, Result}; use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MIN}; use crate::{DEFAULT_PARAMETERS, FSRS}; use burn::tensor::backend::Backend; -use itertools::izip; +use itertools::{izip, Itertools}; use ndarray::{s, Array1, Array2, Ix0, Ix1, SliceInfoElem, Zip}; use ndarray_rand::rand_distr::Distribution; use ndarray_rand::RandomExt; @@ -13,6 +13,7 @@ use rand::{ }; use rayon::iter::IntoParallelIterator; use rayon::iter::ParallelIterator; +use std::collections::HashMap; use strum::EnumCount; #[derive(Debug, EnumCount)] @@ -43,21 +44,24 @@ impl From for SliceInfoElem { } } -const R_MIN: f64 = 0.75; -const R_MAX: f64 = 0.95; +const R_MIN: f32 = 0.75; +const R_MAX: f32 = 0.95; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct SimulatorConfig { pub deck_size: usize, pub learn_span: usize, - pub max_cost_perday: f64, - pub max_ivl: f64, - pub recall_costs: [f64; 3], - pub forget_cost: f64, - pub learn_cost: f64, - pub first_rating_prob: [f64; 4], - pub review_rating_prob: [f64; 3], - pub loss_aversion: f64, + pub max_cost_perday: f32, + pub max_ivl: f32, + pub learn_costs: [f32; 4], + pub review_costs: [f32; 4], + pub first_rating_prob: [f32; 4], + pub review_rating_prob: [f32; 3], + pub first_rating_offsets: [f32; 4], + pub first_session_lens: [f32; 4], + pub forget_rating_offset: f32, + pub forget_session_len: f32, + pub loss_aversion: f32, pub learn_limit: usize, pub review_limit: usize, } @@ -69,58 +73,82 @@ impl Default for SimulatorConfig { learn_span: 365, max_cost_perday: 1800.0, max_ivl: 36500.0, - recall_costs: [14.0, 10.0, 6.0], - forget_cost: 50.0, - learn_cost: 20.0, - first_rating_prob: [0.15, 0.2, 0.6, 0.05], - review_rating_prob: [0.3, 0.6, 0.1], - loss_aversion: 2.5, + learn_costs: [33.79, 24.3, 13.68, 6.5], + review_costs: [23.0, 11.68, 7.33, 5.6], + first_rating_prob: [0.24, 0.094, 0.495, 0.171], + review_rating_prob: [0.224, 0.631, 0.145], + first_rating_offsets: [-0.72, -0.15, -0.01, 0.0], + first_session_lens: [2.02, 1.28, 0.81, 0.0], + forget_rating_offset: -0.28, + forget_session_len: 1.05, + loss_aversion: 1.5, learn_limit: usize::MAX, review_limit: usize::MAX, } } } -fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) -> f64 { - let hard_penalty = if response == 2 { w[15] } else { 1.0 }; - let easy_bonus = if response == 4 { w[16] } else { 1.0 }; - s * (f64::exp(w[8]) +fn stability_after_success(w: &[f32], s: f32, r: f32, d: f32, rating: usize) -> f32 { + let hard_penalty = if rating == 2 { w[15] } else { 1.0 }; + let easy_bonus = if rating == 4 { w[16] } else { 1.0 }; + s * (f32::exp(w[8]) * (11.0 - d) * s.powf(-w[9]) - * (f64::exp((1.0 - r) * w[10]) - 1.0) + * (f32::exp((1.0 - r) * w[10]) - 1.0) * hard_penalty) .mul_add(easy_bonus, 1.0) } -fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 { - (w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f64::exp((1.0 - r) * w[14])) - .clamp(S_MIN.into(), s) +fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 { + (w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f32::exp((1.0 - r) * w[14])) + .clamp(S_MIN, s) +} + +fn stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 { + s * (w[17] * (rating_offset + 
session_len * w[18])).exp() +} + +fn init_d(w: &[f32], rating: usize, rating_offset: f32) -> f32 { + let new_d = w[4] - (w[5] * (rating - 1) as f32).exp() + 1.0 - w[6] * rating_offset; + new_d.clamp(1.0, 10.0) +} + +fn next_d(w: &[f32], d: f32, rating: usize) -> f32 { + let new_d = d - w[6] * (rating as f32 - 3.0); + mean_reversion(w, w[4], new_d).clamp(1.0, 10.0) +} + +fn mean_reversion(w: &[f32], init: f32, current: f32) -> f32 { + w[7] * init + (1.0 - w[7]) * current } pub struct Card { - pub difficulty: f64, - pub stability: f64, - pub last_date: f64, - pub due: f64, + pub difficulty: f32, + pub stability: f32, + pub last_date: f32, + pub due: f32, } pub fn simulate( config: &SimulatorConfig, - w: &[f64], - desired_retention: f64, + w: &[f32], + desired_retention: f32, seed: Option, existing_cards: Option>, -) -> (Array1, Array1, Array1, Array1) { +) -> (Array1, Array1, Array1, Array1) { let SimulatorConfig { deck_size, learn_span, max_cost_perday, max_ivl, - recall_costs, - forget_cost, - learn_cost, + learn_costs, + review_costs, first_rating_prob, review_rating_prob, + first_rating_offsets, + first_session_lens, + forget_rating_offset, + forget_session_len, loss_aversion, learn_limit, review_limit, @@ -128,7 +156,7 @@ pub fn simulate( let mut card_table = Array2::zeros((Column::COUNT, deck_size)); card_table .slice_mut(s![Column::Due, ..]) - .fill(learn_span as f64); + .fill(learn_span as f32); card_table.slice_mut(s![Column::Difficulty, ..]).fill(1e-10); card_table.slice_mut(s![Column::Stability, ..]).fill(1e-10); @@ -155,6 +183,11 @@ pub fn simulate( let mut rng = StdRng::seed_from_u64(seed.unwrap_or(42)); + let mut init_ratings = Array1::zeros(deck_size); + init_ratings.iter_mut().for_each(|rating| { + *rating = first_rating_choices[first_rating_dist.sample(&mut rng)]; + }); + // Main simulation loop for today in 0..learn_span { let old_stability = card_table.slice(s![Column::Stability, ..]); @@ -168,13 +201,13 @@ pub fn simulate( izip!(&mut delta_t, &old_last_date, &has_learned) .filter(|(.., &has_learned_flag)| has_learned_flag) .for_each(|(delta_t, &last_date, ..)| { - *delta_t = today as f64 - last_date; + *delta_t = today as f32 - last_date; }); let mut retrievability = Array1::zeros(deck_size); // Create an array for retrievability - fn power_forgetting_curve(t: f64, s: f64) -> f64 { - (t / s).mul_add(FACTOR, 1.0).powf(DECAY) + fn power_forgetting_curve(t: f32, s: f32) -> f32 { + (t / s).mul_add(FACTOR as f32, 1.0).powf(DECAY as f32) } // Calculate retrievability for entries where has_learned is true @@ -185,11 +218,11 @@ pub fn simulate( }); // Set 'cost' column to 0 - let mut cost = Array1::::zeros(deck_size); + let mut cost = Array1::::zeros(deck_size); // Create 'need_review' mask let old_due = card_table.slice(s![Column::Due, ..]); - let need_review = old_due.mapv(|x| x <= today as f64); + let need_review = old_due.mapv(|x| x <= today as f32); // dbg!(&need_review.mapv(|x| x as i32).sum()); @@ -215,10 +248,14 @@ pub fn simulate( // Sample 'rating' for 'need_review' entries let mut ratings = Array1::zeros(deck_size); - izip!(&mut ratings, &(&need_review & !&forget)) - .filter(|(_, &condition)| condition) - .for_each(|(rating, _)| { - *rating = review_rating_choices[review_rating_dist.sample(&mut rng)] + izip!(&mut ratings, &need_review, &forget) + .filter(|(_, &condition, _)| condition) + .for_each(|(rating, _, forget)| { + *rating = if *forget { + 1 + } else { + review_rating_choices[review_rating_dist.sample(&mut rng)] + }; }); // Update 'cost' column based on 
'need_review', 'forget' and 'ratings' @@ -226,14 +263,14 @@ pub fn simulate( .filter(|(_, &need_review_flag, _, _)| need_review_flag) .for_each(|(cost, _, &forget_flag, &rating)| { *cost = if forget_flag { - forget_cost * loss_aversion + review_costs[0] * loss_aversion } else { - recall_costs[rating - 2] + review_costs[rating - 1] } }); // Calculate cumulative sum of 'cost' - let mut cum_sum = Array1::::zeros(deck_size); + let mut cum_sum = Array1::::zeros(deck_size); cum_sum[0] = cost[0]; for i in 1..deck_size { cum_sum[i] = cum_sum[i - 1] + cost[i]; @@ -253,12 +290,12 @@ pub fn simulate( && (review_count <= review_limit) }); - let need_learn = old_due.mapv(|x| x == learn_span as f64); + let need_learn = old_due.mapv(|x| x == learn_span as f32); // Update 'cost' column based on 'need_learn' - izip!(&mut cost, &need_learn) - .filter(|(_, &need_learn_flag)| need_learn_flag) - .for_each(|(cost, _)| { - *cost = learn_cost; + izip!(&mut cost, &need_learn, &init_ratings) + .filter(|(_, &need_learn_flag, _)| need_learn_flag) + .for_each(|(cost, _, &rating)| { + *cost = learn_costs[rating - 1]; }); cum_sum[0] = cost[0]; @@ -280,13 +317,6 @@ pub fn simulate( need_learn_flag && (cum_cost <= max_cost_perday) && (learn_count <= learn_limit) }); - // Sample 'rating' for 'true_learn' entries - izip!(&mut ratings, &true_learn) - .filter(|(_, &true_learn_flag)| true_learn_flag) - .for_each(|(rating, _)| { - *rating = first_rating_choices[first_rating_dist.sample(&mut rng)] - }); - let mut new_stability = old_stability.to_owned(); let old_difficulty = card_table.slice(s![Column::Difficulty, ..]); // Iterate over slices and apply stability_after_failure function @@ -299,7 +329,9 @@ pub fn simulate( ) .filter(|(.., &condition)| condition) .for_each(|(new_stab, &stab, &retr, &diff, ..)| { - *new_stab = stability_after_failure(w, stab, retr, diff); + let post_lapse_stab = stability_after_failure(w, stab, retr, diff); + *new_stab = + stability_short_term(w, post_lapse_stab, forget_rating_offset, forget_session_len); }); // Iterate over slices and apply stability_after_success function @@ -319,51 +351,47 @@ pub fn simulate( // Initialize a new Array1 to store updated difficulty values let mut new_difficulty = old_difficulty.to_owned(); - // Update the difficulty values based on the condition 'true_review & forget' - izip!(&mut new_difficulty, &old_difficulty, &true_review, &forget) - .filter(|(.., &true_rev, &frgt)| true_rev && frgt) - .for_each(|(new_diff, &old_diff, ..)| { - *new_diff = (2.0f64.mul_add(w[6], old_diff)).clamp(1.0, 10.0); + // Update difficulty for review cards + izip!(&mut new_difficulty, &old_difficulty, &ratings, &true_review,) + .filter(|(.., &condition)| condition) + .for_each(|(new_diff, &old_diff, &rating, ..)| { + *new_diff = next_d(w, old_diff, rating); + if rating == 1 { + *new_diff -= (w[6] * forget_rating_offset).clamp(1.0, 10.0); + } }); - // Update the difficulty values based on the condition 'true_review & !forget' - izip!( - &mut new_difficulty, - &old_difficulty, - &ratings, - &(&true_review & !&forget) - ) - .filter(|(.., &condition)| condition) - .for_each(|(new_diff, &old_diff, &rating, ..)| { - *new_diff = w[6].mul_add(3.0 - rating as f64, old_diff).clamp(1.0, 10.0); - }); - // Update 'last_date' column where 'true_review' or 'true_learn' is true let mut new_last_date = old_last_date.to_owned(); izip!(&mut new_last_date, &true_review, &true_learn) .filter(|(_, &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag) .for_each(|(new_last_date, ..)| { - 
*new_last_date = today as f64; + *new_last_date = today as f32; }); + // Initialize stability and difficulty for new cards izip!( &mut new_stability, &mut new_difficulty, - &ratings, + &init_ratings, &true_learn ) .filter(|(.., &true_learn_flag)| true_learn_flag) .for_each(|(new_stab, new_diff, &rating, _)| { - *new_stab = w[rating - 1]; - *new_diff = (w[5].mul_add(-(rating as f64 - 3.0), w[4])).clamp(1.0, 10.0); + *new_stab = stability_short_term( + w, + w[rating - 1], + first_rating_offsets[rating - 1], + first_session_lens[rating - 1], + ); + *new_diff = init_d(w, rating, first_rating_offsets[rating - 1]); }); let old_interval = card_table.slice(s![Column::Interval, ..]); let mut new_interval = old_interval.to_owned(); izip!(&mut new_interval, &new_stability, &true_review, &true_learn) .filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag) .for_each(|(new_ivl, &new_stab, ..)| { - *new_ivl = (next_interval(new_stab as f32, desired_retention as f32) as f64) - .clamp(1.0, max_ivl); + *new_ivl = (next_interval(new_stab, desired_retention) as f32).clamp(1.0, max_ivl); }); let old_due = card_table.slice(s![Column::Due, ..]); @@ -371,7 +399,7 @@ pub fn simulate( izip!(&mut new_due, &new_interval, &true_review, &true_learn) .filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag) .for_each(|(new_due, &new_ivl, ..)| { - *new_due = today as f64 + new_ivl; + *new_due = today as f32 + new_ivl; }); // Update the card_table with the new values @@ -408,11 +436,11 @@ pub fn simulate( fn sample( config: &SimulatorConfig, - parameters: &[f64], - desired_retention: f64, + parameters: &[f32], + desired_retention: f32, n: usize, progress: &mut F, -) -> Result +) -> Result where F: FnMut() -> bool, { @@ -433,8 +461,8 @@ where let total_cost = cost_per_day.sum(); total_cost / total_memorized }) - .sum::() - / n as f64) + .sum::() + / n as f32) } const SAMPLE_SIZE: usize = 4; @@ -447,20 +475,23 @@ impl FSRS { config: &SimulatorConfig, parameters: &Parameters, mut progress: F, - ) -> Result + ) -> Result where F: FnMut(ItemProgress) -> bool + Send, { let parameters = if parameters.is_empty() { - &DEFAULT_PARAMETERS - } else if parameters.len() != 17 { - return Err(FSRSError::InvalidParameters); + DEFAULT_PARAMETERS.to_vec() + } else if parameters.len() != 19 { + if parameters.len() == 17 { + let mut parameters = parameters.to_vec(); + parameters.extend_from_slice(&[0.0, 0.0]); + parameters + } else { + return Err(FSRSError::InvalidParameters); + } } else { - parameters - } - .iter() - .map(|v| *v as f64) - .collect::>(); + parameters.to_vec() + }; let mut progress_info = ItemProgress { current: 0, // not provided for this method @@ -477,16 +508,16 @@ impl FSRS { /// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2446 fn brent( config: &SimulatorConfig, - parameters: &[f64], + parameters: &[f32], mut progress: F, - ) -> Result + ) -> Result where F: FnMut() -> bool, { let mintol = 1e-10; - let cg = 0.3819660; + let cg = 0.381_966; let maxiter = 64; - let tol = 0.01f64; + let tol = 0.01f32; let (xb, fb) = ( R_MIN, @@ -495,7 +526,7 @@ impl FSRS { let (mut x, mut v, mut w) = (xb, xb, xb); let (mut fx, mut fv, mut fw) = (fb, fb, fb); let (mut a, mut b) = (R_MIN, R_MAX); - let mut deltax: f64 = 0.0; + let mut deltax: f32 = 0.0; let mut iter = 0; let mut rat = 0.0; let mut u; @@ -505,7 +536,7 @@ impl FSRS { let tol2 = 2.0 * tol1; let xmid = 0.5 * (a + b); // check for convergence - if (x - 
xmid).abs() < 0.5f64.mul_add(-(b - a), tol2) { + if (x - xmid).abs() < 0.5f32.mul_add(-(b - a), tol2) { break; } if deltax.abs() <= tol1 { @@ -588,26 +619,301 @@ impl FSRS { } } +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub enum RevlogReviewKind { + #[default] + Learning = 0, + Review = 1, + Relearning = 2, + /// Old Anki versions called this "Cram" or "Early", and assigned it when + /// reviewing cards ahead. It is now only used for filtered decks with + /// rescheduling disabled. + Filtered = 3, + Manual = 4, +} + +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub struct RevlogEntry { + pub id: i64, + pub cid: i64, + pub usn: i32, + /// - In the V1 scheduler, 3 represents easy in the learning case. + /// - 0 represents manual rescheduling. + pub button_chosen: u8, + /// Positive values are in days, negative values in seconds. + pub interval: i32, + /// Positive values are in days, negative values in seconds. + pub last_interval: i32, + /// Card's ease after answering, stored as 10x the %, eg 2500 represents + /// 250%. + pub ease_factor: u32, + /// Amount of milliseconds taken to answer the card. + pub taken_millis: u32, + pub review_kind: RevlogReviewKind, +} + +pub fn extract_simulation_config(df: Vec, day_cutoff: i64) -> SimulatorConfig { + /* + def rating_counts(x): + tmp = defaultdict(int, x.value_counts().to_dict()) + first = x.iloc[0] + tmp[first] -= 1 + return tmp + */ + fn rating_counts(entries: &[RevlogEntry]) -> [u32; 4] { + let mut counts = [0; 4]; + + for entry in entries.iter().skip(1) { + counts[entry.button_chosen as usize - 1] += 1; + } + + counts + } + /* + df1 = ( + df[(df["review_duration"] > 0) & (df["review_duration"] < 1200000)] + .groupby(by=["card_id", "real_days"]) + .agg( + { + "review_state": "first", + "review_rating": ["first", rating_counts], + "review_duration": "sum", + } + ) + .reset_index() + ) + */ + struct Df1Row { + card_id: i64, + first_review_state: u8, + first_review_rating: u8, + review_rating_counts: [u32; 4], + sum_review_duration: u32, + } + let df1 = { + let mut grouped_data = HashMap::new(); + for &row in df.iter() { + if row.taken_millis > 0 && row.taken_millis < 1200000 { + let real_days = (row.id / 1000 - day_cutoff) / 86400; + let key = (row.cid, real_days); + grouped_data.entry(key).or_insert_with(Vec::new).push(row); + } + } + + grouped_data + .into_iter() + .filter_map(|((card_id, _real_days), entries)| { + entries.first().map(|first_entry| { + let first_review_state = first_entry.review_kind as u8 + 1; + let first_review_rating = first_entry.button_chosen; + let review_rating_counts = rating_counts(&entries); + let sum_review_duration = + entries.iter().map(|entry| entry.taken_millis).sum::(); + + Df1Row { + card_id, + first_review_state, + first_review_rating, + review_rating_counts, + sum_review_duration, + } + }) + }) + .collect_vec() + }; + + let cost_dict = { + let mut cost_dict = HashMap::new(); + for row in df1.iter() { + cost_dict + .entry((row.first_review_state, row.first_review_rating)) + .or_insert_with(Vec::new) + .push(row.sum_review_duration); + } + // calculate the median of the sum_review_duration + fn median(x: &mut [u32]) -> u32 { + x.sort_unstable(); + let n = x.len(); + if n % 2 == 0 { + (x[n / 2 - 1] + x[n / 2]) / 2 + } else { + x[n / 2] + } + } + cost_dict + .into_iter() + .map(|(k, mut v)| (k, median(&mut v))) + .collect::>() + }; + + // [cost_dict[(1, i)] / 1000 for i in range(1, 5)] + let learn_costs = (1..5) + .map(|i| cost_dict.get(&(1, i)).copied().unwrap_or_default() 
as f32 / 1000f32) + .collect_vec() + .try_into() + .unwrap(); + // [cost_dict[(2, i)] / 1000 for i in range(1, 5)] + let review_costs = (1..5) + .map(|i| cost_dict.get(&(2, i)).copied().unwrap_or_default() as f32 / 1000f32) + .collect_vec() + .try_into() + .unwrap(); + /* + button_usage_dict = ( + df1.groupby(by=["first_review_state", "first_review_rating"])["card_id"] + .count() + .to_dict() + ) */ + let button_usage_dict = { + let mut button_usage_dict = HashMap::new(); + for row in df1.iter() { + button_usage_dict + .entry((row.first_review_state, row.first_review_rating)) + .or_insert_with(Vec::new) + .push(row.card_id); // is this correct? + } + button_usage_dict + .into_iter() + .map(|(x, y)| (x, y.len() as i64)) + .collect::>() + }; + // [button_usage_dict.get((1, i), 0) for i in range(1, 5)] + let learn_buttons: [i64; 4] = (1..5) + .map(|i| button_usage_dict.get(&(1, i)).copied().unwrap_or_default()) + .collect_vec() + .try_into() + .unwrap(); + // [button_usage_dict.get((2, i), 0) for i in range(1, 5)] + let review_buttons: [i64; 4] = (1..5) + .map(|i| button_usage_dict.get(&(2, i)).copied().unwrap_or_default()) + .collect_vec() + .try_into() + .unwrap(); + + // self.first_rating_prob = self.learn_buttons / self.learn_buttons.sum() + let first_rating_prob = learn_buttons + .iter() + .map(|x| *x as f32 / learn_buttons.iter().sum::() as f32) + .collect_vec() + .try_into() + .unwrap(); + // self.review_buttons[1:] / self.review_buttons[1:].sum() + let review_rating_prob = review_buttons + .iter() + .skip(1) + .map(|x| *x as f32 / review_buttons.iter().skip(1).sum::() as f32) + .collect_vec() + .try_into() + .unwrap(); + + // df2 = ( + // df1.groupby(by=["first_review_state", "first_review_rating"])[[1, 2, 3, 4]] + // .mean() + // .round(2) + // ) + + let df2 = { + let mut grouped = HashMap::new(); + for review in df1 { + grouped + .entry((review.first_review_state, review.first_review_rating)) + .or_insert_with(Vec::new) + .push(review); + } + grouped + .iter() + .map(|((state, rating), group)| { + let count = group.len() as f32; + let (sum1, sum2, sum3, sum4) = + group + .iter() + .fold((0, 0, 0, 0), |(sum1, sum2, sum3, sum4), review| { + ( + sum1 + review.review_rating_counts[0], + sum2 + review.review_rating_counts[1], + sum3 + review.review_rating_counts[2], + sum4 + review.review_rating_counts[3], + ) + }); + + let averages = [ + (sum1 as f32 / count * 100.0).round() / 100.0, + (sum2 as f32 / count * 100.0).round() / 100.0, + (sum3 as f32 / count * 100.0).round() / 100.0, + (sum4 as f32 / count * 100.0).round() / 100.0, + ]; + + ((*state, *rating), averages) + }) + .collect::>() + }; + // rating_offset_dict = sum([df2[g] * (g - 3) for g in range(1, 5)]).to_dict() + let rating_offset_dict = { + let mut rating_offset_dict = HashMap::new(); + for (k, averages) in df2.iter() { + let offset = averages + .iter() + .enumerate() + .map(|(i, &v)| ((i + 1) as f32 - 3.0) * v) + .sum::(); + rating_offset_dict.insert(k, (offset * 100.0).round() / 100.0); + } + rating_offset_dict + }; + // session_len_dict = sum([df2[g] for g in range(1, 5)]).to_dict() + let session_len_dict = { + let mut session_len_dict = HashMap::new(); + for (k, averages) in df2.iter() { + let sum = averages.iter().sum::(); + session_len_dict.insert(k, (sum * 100.0).round() / 100.0); + } + session_len_dict + }; + // [rating_offset_dict[(1, i)] for i in range(1, 5)] + let first_rating_offsets = (1..5) + .map(|i| rating_offset_dict.get(&(1, i)).copied().unwrap_or_default()) + .collect_vec() + .try_into() + .unwrap(); + + // 
[session_len_dict[(1, i)] for i in range(1, 5)] + let first_session_lens = (1..5) + .map(|i| session_len_dict.get(&(1, i)).copied().unwrap_or_default()) + .collect_vec() + .try_into() + .unwrap(); + + // rating_offset_dict[(2, 1)] + let forget_rating_offset = rating_offset_dict.get(&(2, 1)).copied().unwrap_or_default(); + // session_len_dict[(2, 1)] + let forget_session_len = session_len_dict.get(&(2, 1)).copied().unwrap_or_default(); + + SimulatorConfig { + learn_costs, + review_costs, + first_rating_prob, + review_rating_prob, + first_rating_offsets, + first_session_lens, + forget_rating_offset, + forget_session_len, + ..Default::default() + } +} + #[cfg(test)] mod tests { - use itertools::Itertools; - use super::*; - use crate::DEFAULT_PARAMETERS; + use crate::{convertor_tests::read_collection, DEFAULT_PARAMETERS}; #[test] fn simulator() { let config = SimulatorConfig::default(); - let (memorized_cnt_per_day, _, _, _) = simulate( - &config, - &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(), - 0.9, - None, - None, - ); + let (memorized_cnt_per_day, _, _, _) = + simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None); assert_eq!( memorized_cnt_per_day[memorized_cnt_per_day.len() - 1], - 3199.9526251977177 + 8222.023 ) } @@ -617,7 +923,7 @@ mod tests { learn_span: 30, learn_limit: 60, review_limit: 200, - max_cost_perday: f64::INFINITY, + max_cost_perday: f32::INFINITY, ..Default::default() }; let cards = vec![ @@ -634,13 +940,7 @@ mod tests { due: 0.0, }, ]; - let memorization = simulate( - &config, - &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(), - 0.9, - None, - Some(cards), - ); + let memorization = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, Some(cards)); dbg!(memorization); } @@ -650,21 +950,15 @@ mod tests { learn_span: 30, learn_limit: 60, review_limit: 200, - max_cost_perday: f64::INFINITY, + max_cost_perday: f32::INFINITY, ..Default::default() }; - let results = simulate( - &config, - &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(), - 0.9, - None, - None, - ); + let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None); assert_eq!( results.1.to_vec(), vec![ - 0, 16, 27, 34, 84, 80, 91, 92, 104, 106, 109, 112, 133, 123, 139, 121, 136, 149, - 136, 159, 173, 178, 175, 180, 189, 181, 196, 200, 193, 196 + 0, 15, 16, 17, 70, 73, 77, 82, 75, 87, 86, 111, 113, 110, 105, 112, 124, 131, 127, + 119, 122, 163, 145, 150, 171, 150, 136, 163, 167, 156 ] ); assert_eq!( @@ -681,14 +975,56 @@ mod tests { let config = SimulatorConfig { deck_size: learn_span * learn_limit, learn_span, - max_cost_perday: f64::INFINITY, + max_cost_perday: f32::INFINITY, learn_limit, - loss_aversion: 2.5, + loss_aversion: 1.5, ..Default::default() }; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8419900928572013); + assert_eq!(optimal_retention, 0.7791796); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } + + #[test] + fn optimal_retention_with_old_parameters() -> Result<()> { + let learn_span = 1000; + let learn_limit = 10; + let fsrs = FSRS::new(None)?; + let config = SimulatorConfig { + deck_size: learn_span * learn_limit, + learn_span, + max_cost_perday: f32::INFINITY, + learn_limit, + loss_aversion: 1.5, + ..Default::default() + }; + let optimal_retention = fsrs + .optimal_retention(&config, &DEFAULT_PARAMETERS[..17], |_v| true) + .unwrap(); + assert_eq!(optimal_retention, 0.76764786); + Ok(()) + } + + #[test] + fn extract_simulator_config_from_revlog() { + let 
mut revlogs = read_collection().unwrap(); + revlogs.sort_by_cached_key(|r| (r.cid, r.id)); + let day_cutoff = 1720900800; + let simulator_config = extract_simulation_config(revlogs, day_cutoff); + assert_eq!( + simulator_config, + SimulatorConfig { + learn_costs: [30.061, 0., 17.298, 12.352], + review_costs: [19.139, 6.887, 5.83, 4.002], + first_rating_prob: [0.19349411, 0., 0.14357824, 0.662_927_6], + review_rating_prob: [0.07351815, 0.9011334, 0.025348445], + first_rating_offsets: [1.64, 0., 0.69, 1.11], + first_session_lens: [2.74, 0., 1.32, 1.19], + forget_rating_offset: 1.28, + forget_session_len: 1.77, + ..Default::default() + } + ) + } } diff --git a/src/parameter_clipper.rs b/src/parameter_clipper.rs index e7cd849..501e06e 100644 --- a/src/parameter_clipper.rs +++ b/src/parameter_clipper.rs @@ -14,24 +14,26 @@ pub(crate) fn parameter_clipper(parameters: Tensor) -> Tensor< pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec { // https://regex101.com/r/21mXNI/1 - const CLAMPS: [(f32, f32); 17] = [ + const CLAMPS: [(f32, f32); 19] = [ (S_MIN, INIT_S_MAX), (S_MIN, INIT_S_MAX), (S_MIN, INIT_S_MAX), (S_MIN, INIT_S_MAX), (1.0, 10.0), - (0.1, 5.0), - (0.1, 5.0), + (0.1, 4.0), + (0.1, 4.0), (0.0, 0.75), - (0.0, 4.0), + (0.0, 4.5), (0.0, 0.8), - (0.01, 3.0), - (0.5, 5.0), - (0.01, 0.2), + (0.01, 3.5), + (0.1, 5.0), + (0.01, 0.25), (0.01, 0.9), - (0.01, 3.0), + (0.01, 4.0), (0.0, 1.0), (1.0, 6.0), + (0.0, 2.0), + (0.0, 2.0), ]; let mut parameters = parameters.to_vec(); diff --git a/src/pre_training.rs b/src/pre_training.rs index 175b1ed..e8fe52d 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -23,25 +23,28 @@ type FirstRating = u32; type Count = u32; fn create_pretrain_data(fsrs_items: Vec) -> HashMap> { - // filter FSRSItem instances with exactly 2 reviews. + // filter FSRSItem instances with exactly 1 long term review. let items: Vec<_> = fsrs_items .into_iter() - .filter(|item| item.reviews.len() == 2) + .filter(|item| item.long_term_review_cnt() == 1) .collect(); // use a nested HashMap (groups) to group items first by the rating in the first FSRSReview // and then by the delta_t in the second FSRSReview. 
- // (first_rating -> second_delta_t -> vec![0/1 for fail/pass]) + // (first_rating -> first_long_term_delta_t -> vec![0/1 for fail/pass]) let mut groups = HashMap::new(); for item in items { let first_rating = item.reviews[0].rating; - let second_delta_t = item.reviews[1].delta_t; - let second_label = (item.reviews[1].rating > 1) as i32; + let first_long_term_review = item.first_long_term_review(); + let first_long_term_delta_t = first_long_term_review.delta_t; + let first_long_term_label = (first_long_term_review.rating > 1) as i32; let inner_map = groups.entry(first_rating).or_insert_with(HashMap::new); - let ratings = inner_map.entry(second_delta_t).or_insert_with(Vec::new); - ratings.push(second_label); + let ratings = inner_map + .entry(first_long_term_delta_t) + .or_insert_with(Vec::new); + ratings.push(first_long_term_label); } let mut results = HashMap::new(); @@ -293,9 +296,9 @@ mod tests { let count = Array1::from(vec![435.0, 97.0, 63.0, 38.0, 28.0]); let default_s0 = DEFAULT_PARAMETERS[0] as f64; let actual = loss(&delta_t, &recall, &count, 1.017056, default_s0); - assert_eq!(actual, 280.7447802452844); + assert_eq!(actual, 280.7489989949864); let actual = loss(&delta_t, &recall, &count, 1.017011, default_s0); - assert_eq!(actual, 280.7444462249327); + assert_eq!(actual, 280.74866497463466); } #[test] @@ -331,7 +334,7 @@ mod tests { }, ], )]); - let actual = search_parameters(pretrainset, 0.9430285915990116); + let actual = search_parameters(pretrainset, 0.943_028_57); Data::from([*actual.get(&first_rating).unwrap()]) .assert_approx_eq(&Data::from([0.908_688]), 6); } @@ -340,8 +343,9 @@ mod tests { fn test_pretrain() { use crate::convertor_tests::anki21_sample_file_converted_to_fsrs; let items = anki21_sample_file_converted_to_fsrs(); - let (mut pretrainset, mut trainset) = - items.into_iter().partition(|item| item.reviews.len() == 2); + let (mut pretrainset, mut trainset) = items + .into_iter() + .partition(|item| item.long_term_review_cnt() == 1); (pretrainset, trainset) = filter_outlier(pretrainset, trainset); let items = [pretrainset.clone(), trainset].concat(); let average_recall = calculate_average_recall(&items); @@ -359,6 +363,6 @@ mod tests { let mut rating_stability = HashMap::from([(2, 0.35)]); let rating_count = HashMap::from([(2, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.1217739, 0.35, 0.928426, 3.4544096]); + assert_eq!(actual, [0.123763576, 0.34999996, 0.8968067, 4.495269]); } } diff --git a/src/training.rs b/src/training.rs index 81308af..af04884 100644 --- a/src/training.rs +++ b/src/training.rs @@ -1,6 +1,6 @@ use crate::batch_shuffle::BatchShuffledDataLoaderBuilder; use crate::cosine_annealing::CosineAnnealingLR; -use crate::dataset::{split_filter_data, FSRSBatcher, FSRSDataset, FSRSItem}; +use crate::dataset::{prepare_training_data, FSRSBatcher, FSRSDataset, FSRSItem}; use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::parameter_clipper::parameter_clipper; @@ -207,7 +207,7 @@ impl FSRS { }; let average_recall = calculate_average_recall(&train_set); - let (pre_train_set, next_train_set) = split_filter_data(train_set); + let (pre_train_set, next_train_set) = prepare_training_data(train_set); if pre_train_set.len() + next_train_set.len() < 8 { finish_progress(); return Ok(DEFAULT_PARAMETERS.to_vec()); @@ -228,7 +228,7 @@ impl FSRS { let config = TrainingConfig::new( ModelConfig { - freeze_stability: true, + freeze_stability: false, initial_stability: 
Some(initial_stability), }, AdamConfig::new(), @@ -279,11 +279,11 @@ impl FSRS { let average_recall = calculate_average_recall(&train_set); let (pre_train_set, next_train_set) = train_set .into_iter() - .partition(|item| item.reviews.len() == 2); + .partition(|item| item.long_term_review_cnt() == 1); let initial_stability = pretrain(pre_train_set, average_recall).unwrap(); let config = TrainingConfig::new( ModelConfig { - freeze_stability: true, + freeze_stability: false, initial_stability: Some(initial_stability), }, AdamConfig::new(),
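
The substance of this 0.6.4 → 1.0.0 diff is the move to a 19-parameter model: initial difficulty becomes exponential in the first rating, and a new pair of weights (w[17], w[18]) rescales stability on same-day (delta_t == 0) reviews — which is also why the `states_are_unchaged_when_no_days_elapsed` test is deleted; same-day reviews now deliberately change the memory state. Below is a scalar sketch of the two new rules (the real code in `src/model.rs` operates on burn tensors); with S = 5.0 the short-term values reproduce the `stability_short_term` test vector in this diff (2.5795, 4.2067, 6.8606, 11.1885).

```rust
// Scalar sketch of the two new FSRS update rules, evaluated with the new
// 19-element DEFAULT_PARAMETERS from src/inference.rs.
const W: [f32; 19] = [
    0.4197, 1.1869, 3.0412, 15.2441, 7.1434, 0.6477, 1.0007, 0.0674, 1.6597,
    0.1712, 1.1178, 2.0225, 0.0904, 0.3025, 2.1214, 0.2498, 2.9466, 0.4891,
    0.6468,
];

/// New initial difficulty, exponential in the first rating (was linear):
/// D0(r) = w4 - e^(w5 * (r - 1)) + 1, so D0(1) == w4 exactly.
fn init_difficulty(rating: f32) -> f32 {
    W[4] - (W[5] * (rating - 1.0)).exp() + 1.0
}

/// New same-day update: S' = S * e^(w17 * (r - 3 + w18)), so Again (r = 1)
/// shrinks stability and Easy (r = 4) grows it.
fn stability_short_term(s: f32, rating: f32) -> f32 {
    s * (W[17] * (rating - 3.0 + W[18])).exp()
}

fn main() {
    for r in 1..=4 {
        println!(
            "rating {r}: D0 = {:.4}, S = 5.0 -> {:.4}",
            init_difficulty(r as f32),
            stability_short_term(5.0, r as f32)
        );
    }
}
```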
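
Backward compatibility is handled by padding: `FSRS::new` now accepts 17- or 19-element slices, and `parameters_to_model` / `optimal_retention` extend legacy sets with two zeros. Since e^0 = 1, w[17] = w[18] = 0 turns the short-term update into a no-op, so old parameter sets keep their old behaviour. A standalone sketch of that rule (the crate itself returns `FSRSError::InvalidParameters` rather than a string):

```rust
// Sketch of the 17 -> 19 migration applied in parameters_to_model() and
// optimal_retention() in this diff.
fn pad_legacy_parameters(parameters: &[f32]) -> Result<Vec<f32>, &'static str> {
    match parameters.len() {
        19 => Ok(parameters.to_vec()),
        17 => {
            // w[17] = w[18] = 0.0 => stability_short_term multiplies by
            // e^0 == 1, leaving same-day stability unchanged (old behaviour).
            let mut padded = parameters.to_vec();
            padded.extend_from_slice(&[0.0, 0.0]);
            Ok(padded)
        }
        _ => Err("invalid parameter count"),
    }
}
```

Accordingly, the new model test asserts that `FSRS::new(Some(&DEFAULT_PARAMETERS[..17]))` succeeds.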
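
On the data side, same-day reviews are now kept in each item's history, so `reviews.len() == 2` no longer identifies a single-interval item. The split therefore keys on `long_term_review_cnt() == 1` (reviews with `delta_t > 0`), the converter keeps only items whose last review is long-term (the `item.current().delta_t > 0` filter, which also replaces the removed `keep_first_revlog_same_date` dedup), and `split_filter_data` becomes `prepare_training_data`, whose train set now deliberately includes the pretrain items. A condensed sketch with the outlier step elided:

```rust
// Condensed sketch of the new split logic from src/dataset.rs. The real
// prepare_training_data() also runs filter_outlier() unless the
// FSRS_NO_OUTLIER environment variable is set; that step is elided here.
#[derive(Clone, Debug)]
struct FSRSReview {
    rating: u32,
    delta_t: u32,
}

#[derive(Clone, Debug)]
struct FSRSItem {
    reviews: Vec<FSRSReview>,
}

impl FSRSItem {
    /// Count of reviews separated from the previous one by at least a day.
    fn long_term_review_cnt(&self) -> usize {
        self.reviews.iter().filter(|r| r.delta_t > 0).count()
    }
}

fn prepare_training_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSItem>) {
    let (pretrainset, trainset): (Vec<_>, Vec<_>) = items
        .into_iter()
        .partition(|item| item.long_term_review_cnt() == 1);
    // The train set now contains the pretrain items as well.
    (pretrainset.clone(), [pretrainset, trainset].concat())
}

fn main() {
    // Two same-day reviews plus one 5-day review => one long-term review,
    // so this item lands in the pretrain set (and also in the train set).
    let item = FSRSItem {
        reviews: vec![
            FSRSReview { rating: 3, delta_t: 0 },
            FSRSReview { rating: 4, delta_t: 0 },
            FSRSReview { rating: 3, delta_t: 5 },
        ],
    };
    let (pretrain, train) = prepare_training_data(vec![item]);
    assert_eq!((pretrain.len(), train.len()), (1, 1));
}
```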
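
The simulator (now in f32 throughout) is reworked to match: the flat `recall_costs` / `forget_cost` / `learn_cost` fields give way to per-button `learn_costs` and `review_costs`, each card's first rating is sampled once up front into `init_ratings`, a lapse passes the post-lapse stability through the short-term rule with the forget-session statistics, and a brand-new card is seeded by passing w[rating - 1] through the same rule with the per-rating `first_rating_offsets` / `first_session_lens`. A scalar sketch of the new-card seeding; the two formulas are copied from `src/optimal_retention.rs`, while the `seed_new_card` wrapper is illustrative:

```rust
// Scalar sketch of new-card initialization in the reworked simulator.
fn stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 {
    s * (w[17] * (rating_offset + session_len * w[18])).exp()
}

fn init_d(w: &[f32], rating: usize, rating_offset: f32) -> f32 {
    let new_d = w[4] - (w[5] * (rating - 1) as f32).exp() + 1.0 - w[6] * rating_offset;
    new_d.clamp(1.0, 10.0)
}

/// Initial (stability, difficulty) for a card first answered with `rating`
/// (1..=4), given the per-rating same-day statistics from SimulatorConfig.
fn seed_new_card(
    w: &[f32; 19],
    rating: usize,
    first_rating_offsets: &[f32; 4],
    first_session_lens: &[f32; 4],
) -> (f32, f32) {
    let s = stability_short_term(
        w,
        w[rating - 1],
        first_rating_offsets[rating - 1],
        first_session_lens[rating - 1],
    );
    let d = init_d(w, rating, first_rating_offsets[rating - 1]);
    (s, d)
}
```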
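
Finally, `RevlogEntry`, `RevlogReviewKind` and `extract_simulation_config` move out of the test-only convertor module into `src/optimal_retention.rs` and are re-exported from `lib.rs`, so a per-user `SimulatorConfig` can be derived from real review logs. A hypothetical end-to-end sketch — the meaning of `simulate`'s four per-day arrays (memorized-card count, review count, learn count, cost) is inferred from the tests, and `revlogs` / `day_cutoff` stand in for caller-supplied data:

```rust
use fsrs::{extract_simulation_config, simulate, RevlogEntry, DEFAULT_PARAMETERS};

// Derive simulator statistics from real logs, then simulate a year of study
// at 90% desired retention with the default 19 parameters.
fn plan_year(revlogs: Vec<RevlogEntry>, day_cutoff: i64) {
    let config = extract_simulation_config(revlogs, day_cutoff);
    let (memorized_per_day, reviews_per_day, _learns_per_day, _cost_per_day) =
        simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None);
    println!(
        "after {} days: {:.0} cards memorized, {} reviews on the last day",
        config.learn_span,
        memorized_per_day[memorized_per_day.len() - 1],
        reviews_per_day[reviews_per_day.len() - 1],
    );
}
```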