diff --git a/Cargo.toml b/Cargo.toml
index 38c0964..b1d1930 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,3 +14,4 @@ log = "0.4"
 rusqlite = { version = "0.29.0" }
 chrono = "0.4.26"
 chrono-tz = "0.8.3"
+itertools = "0.11.0"
\ No newline at end of file
diff --git a/src/convertor.rs b/src/convertor.rs
index 8902622..8b71ae6 100644
--- a/src/convertor.rs
+++ b/src/convertor.rs
@@ -3,19 +3,15 @@ use chrono_tz::Tz;
 use rusqlite::{Connection, Result, Row};
 use std::collections::HashMap;
 
-use crate::dataset::{FSRSItem, Review};
+use crate::dataset::{FSRSItem, FSRSReview};
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RevlogEntry {
     id: i64,
     cid: i64,
     button_chosen: i32,
-    ease_factor: i64,
     review_kind: i64,
     delta_t: i32,
-    i: usize,
-    r_history: Vec<i32>,
-    t_history: Vec<i32>,
 }
 
 fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
@@ -23,12 +19,8 @@ fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
         id: row.get(0)?,
         cid: row.get(1)?,
         button_chosen: row.get(2)?,
-        ease_factor: row.get(3)?,
-        review_kind: row.get(4).unwrap_or_default(),
+        review_kind: row.get(3).unwrap_or_default(),
         delta_t: 0,
-        i: 0,
-        r_history: vec![],
-        t_history: vec![],
     })
 }
 
@@ -58,9 +50,10 @@ fn read_collection() -> Vec<RevlogEntry> {
     let current_timestamp = Utc::now().timestamp() * 1000;
 
     let query = format!(
-        "SELECT id, cid, ease, factor, type
+        "SELECT id, cid, ease, type
         FROM revlog
         WHERE (type != 4 OR ivl <= 0)
+        AND (factor != 0 or type != 3)
         AND id < {}
         AND cid < {}
         AND cid IN (
@@ -104,11 +97,14 @@ fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chr
     datetime.date_naive()
 }
 
-fn extract_time_series_feature(
+/// Given a list of revlog entries for a single card with length n, we create
+/// n-1 FSRS items, where each item contains the history of the preceding reviews.
+fn convert_to_fsrs_items(
     mut entries: Vec<RevlogEntry>,
     next_day_starts_at: i64,
     timezone: Tz,
-) -> Vec<RevlogEntry> {
+) -> Option<Vec<FSRSItem>> {
+    // Find the index of the first RevlogEntry in the last continuous group where review_kind = 0
     // 寻找最后一组连续 review_kind = 0 的第一个 RevlogEntry 的索引
     let mut index_to_keep = 0;
     let mut i = entries.len();
@@ -118,24 +114,33 @@ fn extract_time_series_feature(
         if entries[i].review_kind == 0 {
             index_to_keep = i;
         } else if index_to_keep != 0 {
-            break; // 找到了连续的 review_kind = 0 的组,退出循环
+            // Found a continuous group of review_kind = 0, exit the loop
+            // 找到了连续的 review_kind = 0 的组,退出循环
+            break;
         }
     }
 
+    // Remove all entries before this RevlogEntry
     // 删除此 RevlogEntry 之前的所有条目
     entries.drain(..index_to_keep);
 
-    // 去掉 review_kind = 4 的 RevlogEntry
-    entries.retain(|entry| entry.review_kind != 4);
-
-    // 去掉 review_kind = 3 且 ease_factor = 0 的 RevlogEntry
-    entries.retain(|entry| entry.review_kind != 3 || entry.ease_factor != 0);
+    // we ignore cards that don't start in the learning state
+    if let Some(entry) = entries.first() {
+        if entry.review_kind != 0 {
+            return None;
+        }
+    } else {
+        // no revlog entries
+        return None;
+    }
 
+    // Increment review_kind of all entries by 1
     // 将所有 review_kind + 1
     for entry in &mut entries {
         entry.review_kind += 1;
     }
 
+    // Convert the timestamp and keep the first RevlogEntry for each date
     // 转换时间戳并保留每个日期的第一个 RevlogEntry
     let mut unique_dates = std::collections::HashSet::new();
     entries.retain(|entry| {
@@ -143,6 +148,7 @@ fn extract_time_series_feature(
         unique_dates.insert(date)
     });
 
+    // Compute delta_t for the remaining RevlogEntries
     // 计算其余 RevlogEntry 的 delta_t
     for i in 1..entries.len() {
         let date_current = convert_to_date(entries[i].id, next_day_starts_at, timezone);
@@ -150,104 +156,217 @@ fn extract_time_series_feature(
         entries[i].delta_t = (date_current - date_previous).num_days() as i32;
     }
 
-    // 计算 i, r_history, t_history
-    for i in 0..entries.len() {
-        entries[i].i = i + 1; // 位置从 1 开始
-        // 除了第一个条目,其余条目将前面的 button_chosen 和 delta_t 加入 r_history 和 t_history
-        if i > 0 {
-            entries[i].r_history = entries[0..i].iter().map(|e| e.button_chosen).collect();
-            entries[i].t_history = entries[0..i].iter().map(|e| e.delta_t).collect();
-        }
-    }
-
+    // Find the RevlogEntry with review_kind = 0 where the preceding RevlogEntry has review_kind of 1 or 2, then remove it and all following RevlogEntries
     // 找到 review_kind = 0 且前一个 RevlogEntry 的 review_kind 是 1 或 2 的 RevlogEntry,然后删除其及其之后的所有 RevlogEntry
     if let Some(index_to_remove) = entries.windows(2).enumerate().find_map(|(i, window)| {
         if (window[0].review_kind == 1 || window[0].review_kind == 2) && window[1].review_kind == 0
         {
-            Some(i + 1) // 返回第一个符合条件的 RevlogEntry 的索引
+            // Return the index of the first RevlogEntry that meets the condition
+            // 返回第一个符合条件的 RevlogEntry 的索引
+            Some(i + 1)
        } else {
             None
         }
     }) {
-        entries.truncate(index_to_remove); // 截取从 0 到 index_to_remove 的部分,删除其后的所有条目
+        // Truncate from 0 to index_to_remove, removing all subsequent entries
+        // 截取从 0 到 index_to_remove 的部分,删除其后的所有条目
+        entries.truncate(index_to_remove);
     }
 
-    entries
-}
-
-fn convert_to_fsrs_items(revlogs: Vec<Vec<RevlogEntry>>) -> Vec<FSRSItem> {
-    revlogs
-        .into_iter()
-        .flat_map(|group| {
-            group
-                .into_iter()
-                .filter(|entry| entry.i != 1) // 过滤掉 i = 1 的 RevlogEntry
-                .map(|entry| FSRSItem {
-                    reviews: entry
-                        .r_history
-                        .iter()
-                        .zip(entry.t_history.iter())
-                        .map(|(&r, &t)| Review {
-                            rating: r,
-                            delta_t: t,
-                        })
-                        .collect(),
-                    delta_t: entry.delta_t as f32,
-                    label: match entry.button_chosen {
-                        1 => 0.0,
-                        2 | 3 | 4 => 1.0,
-                        _ => panic!("Unexpected value for button_chosen"),
-                    },
-                })
-        })
-        .collect()
-}
-
-fn remove_non_learning_first(revlogs_per_card: Vec<Vec<RevlogEntry>>) -> Vec<Vec<RevlogEntry>> {
-    let mut result = revlogs_per_card;
-    result.retain(|entries| {
-        if let Some(first_entry) = entries.first() {
-            first_entry.review_kind == 1
-        } else {
-            false
-        }
-    });
-    result
+    // Compute i, r_history, t_history
+    // 计算 i, r_history, t_history
+    // Except for the first entry, the remaining entries add the preceding button_chosen and delta_t to r_history and t_history
+    // 除了第一个条目,其余条目将前面的 button_chosen 和 delta_t 加入 r_history 和 t_history
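+    // For example, a card with four surviving reviews produces three items:
+    // one covering reviews 1-2, one covering reviews 1-3, and one covering
+    // reviews 1-4; the first review on its own is skipped, as it has no history.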
+    Some(
+        entries
+            .iter()
+            .enumerate()
+            .skip(1)
+            .map(|(idx, _)| {
+                let reviews = entries
+                    .iter()
+                    .take(idx + 1)
+                    .map(|r| FSRSReview {
+                        rating: r.button_chosen,
+                        delta_t: r.delta_t,
+                    })
+                    .collect();
+                FSRSItem { reviews }
+            })
+            .collect(),
+    )
 }
 
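+/// Read the local Anki collection and convert each card's review log into FSRS
+/// items, assuming a next-day cutoff of 4 AM and the Asia/Shanghai timezone
+/// (both currently hard-coded below).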
 pub fn anki_to_fsrs() -> Vec<FSRSItem> {
     let revlogs = read_collection();
     let revlogs_per_card = group_by_cid(revlogs);
-    let extracted_revlogs_per_card: Vec<Vec<RevlogEntry>> = revlogs_per_card
+    revlogs_per_card
         .into_iter()
-        .map(|entries| extract_time_series_feature(entries, 4, Tz::Asia__Shanghai))
-        .collect();
-
-    let filtered_revlogs_per_card = remove_non_learning_first(extracted_revlogs_per_card);
-
-    convert_to_fsrs_items(filtered_revlogs_per_card)
+        .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
+        .flatten()
+        .collect()
 }
 
-#[test]
-fn test() {
-    let revlogs = read_collection();
-    assert_eq!(revlogs.len(), 24394);
-    let revlogs_per_card = group_by_cid(revlogs);
-    assert_eq!(revlogs_per_card.len(), 3324);
-    let mut extracted_revlogs_per_card: Vec<Vec<RevlogEntry>> = revlogs_per_card
-        .into_iter()
-        .map(|entries| extract_time_series_feature(entries, 4, Tz::Asia__Shanghai))
-        .collect();
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::dataset::FSRSBatcher;
+    use burn::data::dataloader::batcher::Batcher;
+    use burn::tensor::Data;
+    use itertools::Itertools;
+
-    dbg!(&extracted_revlogs_per_card[0]);
-    extracted_revlogs_per_card = remove_non_learning_first(extracted_revlogs_per_card);
-    assert_eq!(
-        extracted_revlogs_per_card
+    // This test currently expects the following .anki21 file to be placed in tests/data/:
+    // https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
+    #[test]
+    fn test() {
+        let revlogs = read_collection();
+        let single_card_revlog = vec![revlogs
             .iter()
-            .map(|x| x.len())
-            .sum::<usize>(),
-        17614
-    );
-    let fsrs_items: Vec<FSRSItem> = convert_to_fsrs_items(extracted_revlogs_per_card);
-    assert_eq!(fsrs_items.len(), 14290);
+            .filter(|r| r.cid == 1528947214762)
+            .cloned()
+            .collect_vec()];
+        assert_eq!(revlogs.len(), 24394);
+        let revlogs_per_card = group_by_cid(revlogs);
+        assert_eq!(revlogs_per_card.len(), 3324);
+        let fsrs_items = revlogs_per_card
+            .into_iter()
+            .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
+            .flatten()
+            .collect_vec();
+
+        assert_eq!(fsrs_items.len(), 14290);
+        assert_eq!(
+            fsrs_items.iter().map(|x| x.reviews.len()).sum::<usize>(),
+            49382 + 14290
+        );
+
+        // convert a subset and check it matches expectations
+        let mut fsrs_items = single_card_revlog
+            .into_iter()
+            .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
+            .flatten()
+            .collect_vec();
+        assert_eq!(
+            &fsrs_items,
+            &[
+                FSRSItem {
+                    reviews: vec![
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 0
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 5
+                        }
+                    ],
+                },
+                FSRSItem {
+                    reviews: vec![
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 0
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 5
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 10
+                        }
+                    ],
+                },
+                FSRSItem {
+                    reviews: vec![
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 0
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 5
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 10
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 22
+                        }
+                    ],
+                },
+                FSRSItem {
+                    reviews: vec![
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 0
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 5
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 10
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 22
+                        },
+                        FSRSReview {
+                            rating: 2,
+                            delta_t: 56
+                        }
+                    ],
+                },
+                FSRSItem {
+                    reviews: vec![
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 0
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 5
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 10
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 22
+                        },
+                        FSRSReview {
+                            rating: 2,
+                            delta_t: 56
+                        },
+                        FSRSReview {
+                            rating: 3,
+                            delta_t: 64
+                        }
+                    ],
+                }
+            ]
+        );
+
+        use burn_ndarray::NdArrayDevice;
+        let device = NdArrayDevice::Cpu;
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let batcher = FSRSBatcher::<Backend>::new(device);
+        let res = batcher.batch(vec![fsrs_items.pop().unwrap()]);
+        assert_eq!(res.delta_ts.into_scalar(), 64.0);
+        assert_eq!(
+            res.r_historys.squeeze(1).to_data(),
+            Data::from([3.0, 3.0, 3.0, 3.0, 2.0])
+        );
+        assert_eq!(
+            res.t_historys.squeeze(1).to_data(),
+            Data::from([0.0, 5.0, 10.0, 22.0, 56.0])
+        );
+        assert_eq!(res.labels.to_data(), Data::from([1]));
+    }
 }
diff --git a/src/dataset.rs b/src/dataset.rs
index 5766029..a342db3 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -7,19 +7,33 @@ use serde::{Deserialize, Serialize};
 
 use crate::convertor::anki_to_fsrs;
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+/// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds
+/// to a single review, but contains the previous reviews of the card as well, after the
+/// first one.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
 pub struct FSRSItem {
-    pub reviews: Vec<Review>,
-    pub delta_t: f32,
-    pub label: f32,
+    pub reviews: Vec<FSRSReview>,
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct Review {
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+pub struct FSRSReview {
+    /// 1-4
     pub rating: i32,
+    /// The number of days that passed
     pub delta_t: i32,
 }
 
+impl FSRSItem {
+    // The previous reviews done before the current one.
+    pub(crate) fn history(&self) -> impl Iterator<Item = &FSRSReview> {
+        self.reviews.iter().take(self.reviews.len() - 1)
+    }
+
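+    // The review to be predicted: the batcher derives the training label from
+    // its rating (1 = forgotten, 2-4 = remembered) and uses its delta_t as the
+    // elapsed interval.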
+    pub(crate) fn current(&self) -> &FSRSReview {
+        self.reviews.last().unwrap()
+    }
+}
+
 pub struct FSRSBatcher<B: Backend> {
     device: B::Device,
 }
@@ -40,48 +54,42 @@ pub struct FSRSBatch<B: Backend> {
 impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
     fn batch(&self, items: Vec<FSRSItem>) -> FSRSBatch<B> {
-        let t_historys = items
+        let (time_histories, rating_histories) = items
             .iter()
             .map(|item| {
-                Data::new(
-                    item.reviews.iter().map(|r| r.delta_t).collect(),
-                    Shape {
-                        dims: [item.reviews.len()],
-                    },
+                let (delta_t, rating): (Vec<_>, _) =
+                    item.history().map(|r| (r.delta_t, r.rating)).unzip();
+                let count = delta_t.len();
+                let delta_t = Tensor::<B, 1>::from_data(
+                    Data::new(delta_t, Shape { dims: [count] }).convert(),
                 )
+                .unsqueeze();
+                let rating =
+                    Tensor::<B, 1>::from_data(Data::new(rating, Shape { dims: [count] }).convert())
+                        .unsqueeze();
+                (delta_t, rating)
             })
-            .map(|data| Tensor::<B, 1>::from_data(data.convert()))
-            .map(|tensor| tensor.unsqueeze())
-            .collect();
+            .unzip();
 
-        let r_historys = items
+        let (delta_ts, labels) = items
             .iter()
             .map(|item| {
-                Data::new(
-                    item.reviews.iter().map(|r| r.rating).collect(),
-                    Shape {
-                        dims: [item.reviews.len()],
-                    },
-                )
+                let current = item.current();
+                let delta_t =
+                    Tensor::<B, 1>::from_data(Data::from([current.delta_t.elem()]));
+                let label = match current.rating {
+                    1 => 0.0,
+                    _ => 1.0,
+                };
+                let label = Tensor::<B, 1, Int>::from_data(Data::from([label.elem()]));
+                (delta_t, label)
             })
-            .map(|data| Tensor::<B, 1>::from_data(data.convert()))
-            .map(|tensor| tensor.unsqueeze())
-            .collect();
-
-        let delta_ts = items
-            .iter()
-            .map(|item| Tensor::<B, 1>::from_data(Data::from([item.delta_t.elem()])))
-            .collect();
-
-        let labels = items
-            .iter()
-            .map(|item| Tensor::<B, 1, Int>::from_data(Data::from([item.label.elem()])))
-            .collect();
+            .unzip();
 
-        let t_historys = Tensor::cat(t_historys, 0)
+        let t_historys = Tensor::cat(time_histories, 0)
             .transpose()
             .to_device(&self.device); // [seq_len, batch_size]
-        let r_historys = Tensor::cat(r_historys, 0)
+        let r_historys = Tensor::cat(rating_histories, 0)
             .transpose()
             .to_device(&self.device); // [seq_len, batch_size]
         let delta_ts = Tensor::cat(delta_ts, 0).to_device(&self.device);
diff --git a/src/training.rs b/src/training.rs
index 36fd6dd..3a15e86 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -38,7 +38,7 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
         info!("retention: {}", &retention);
         info!("logits: {}", &logits);
         info!("labels: {}", &labels);
-        let loss = self.bceloss(retention.clone(), labels.clone().float());
+        let loss = self.bceloss(retention, labels.clone().float());
         ClassificationOutput::new(loss, logits, labels)
     }
 }
@@ -138,7 +138,7 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
 
     PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
         .record(
-            model_trained.clone().into_record(),
+            model_trained.into_record(),
             format!("{ARTIFACT_DIR}/model").into(),
         )
         .expect("Failed to save trained model");