From ee7f6be7419dda4abccbe98963bfee98c4d7fd8c Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 28 Aug 2023 11:44:48 +0800 Subject: [PATCH] Feat/sort FSRSItem by length to speed up training (#32) * sort FSRSItem by length to speed up training * correct shape of logits * cargo fmt * Create batches with random seq_len * implement BatchShuffledDataset * cargo fmt * rename batch_shuffle & add English comments * Use an SQL sort and .group_by() for separating by card id * We don't need InMemDataset * Move all tests into test modules This allows easily running all tests in one file at once, and allows tests to share code that is not used in production. Also removed the duplicate test_next_stability/difficulty tests. * Remove redundant test_ prefix/suffix * Apply cosine_annealing patch from Asuka to fix test Co-authored-by: Asuka Minato * Run tests in CI, except training Co-authored-by: Asuka Minato * Limit checks to pull requests; bust cache --------- Co-authored-by: Damien Elmes Co-authored-by: Asuka Minato --- .github/workflows/check.sh | 6 + .github/workflows/check.yml | 5 +- Cargo.lock | 105 +++++----- Cargo.toml | 1 + src/batch_shuffle.rs | 104 +++++++++ src/convertor.rs | 71 ++++--- src/cosine_annealing.rs | 56 ++--- src/dataset.rs | 406 ++++++++++++++++++------------------ src/lib.rs | 2 + src/model.rs | 253 +++++++++------------- src/training.rs | 54 +++-- src/weight_clipper.rs | 39 ++-- 12 files changed, 601 insertions(+), 501 deletions(-) create mode 100644 src/batch_shuffle.rs diff --git a/.github/workflows/check.sh b/.github/workflows/check.sh index 2dfe298..c07a7a2 100755 --- a/.github/workflows/check.sh +++ b/.github/workflows/check.sh @@ -9,3 +9,9 @@ cargo fmt --check || ( ) cargo clippy -- -Dwarnings + +install -d tests/data/ +pushd tests/data/ +wget https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip +unzip *.zip +SKIP_TRAINING=1 cargo test diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 8840638..bc9d5ab 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -1,6 +1,5 @@ name: Check code on: - push: pull_request: jobs: @@ -16,9 +15,9 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-rust-release-v5-${{ hashFiles('Cargo.lock') }} + key: ${{ runner.os }}-rust-release-v6-${{ hashFiles('Cargo.lock') }} restore-keys: | - ${{ runner.os }}-rust-release-v5 + ${{ runner.os }}-rust-release-v6 ${{ runner.os }}-rust-release - name: Run checks diff --git a/Cargo.lock b/Cargo.lock index 4499509..8161c42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -294,9 +294,9 @@ checksum = "a3e368af43e418a04d52505cf3dbc23dda4e3407ae2fa99fd0e4f308ce546acc" [[package]] name = "cc" -version = "1.0.82" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "305fe645edc1442a0fa8b6726ba61d422798d37a52e12eaecf4b022ebbb88f01" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "libc", ] @@ -533,9 +533,9 @@ dependencies = [ [[package]] name = "dashmap" -version = "5.5.0" +version = "5.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d" +checksum = "edd72493923899c6f10c641bdbdeddc7183d6396641d99c1a0d1597f37f92e28" dependencies = [ "cfg-if", "hashbrown 0.14.0", @@ -546,9 +546,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.7" +version = "0.3.8" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929" +checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" [[package]] name = "derivative" @@ -743,6 +743,7 @@ dependencies = [ "chrono-tz", "itertools", "log", + "rand", "rusqlite", "serde", ] @@ -1016,7 +1017,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi 0.3.2", - "rustix 0.38.8", + "rustix 0.38.9", "windows-sys 0.48.0", ] @@ -1389,7 +1390,7 @@ dependencies = [ "libc", "redox_syscall 0.3.5", "smallvec", - "windows-targets 0.48.2", + "windows-targets 0.48.5", ] [[package]] @@ -1479,9 +1480,9 @@ checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "png" -version = "0.17.9" +version = "0.17.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59871cc5b6cce7eaccca5a802b4173377a1c2ba90654246789a8fa2334426d11" +checksum = "dd75bf2d8dd3702b9707cdbc56a5b9ef42cec752eb8b3bafc01234558442aa64" dependencies = [ "bitflags 1.3.2", "crc32fast", @@ -1492,9 +1493,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.4.2" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f32154ba0af3a075eefa1eda8bb414ee928f62303a54ea85b8d6638ff1a6ee9e" +checksum = "31114a898e107c51bb1609ffaf55a0e011cf6a4d7f1170d0015a165082c0338b" [[package]] name = "ppv-lite86" @@ -1753,7 +1754,7 @@ dependencies = [ "libsqlite3-sys", "serde_json", "smallvec", - "time 0.3.25", + "time 0.3.27", "url", "uuid", ] @@ -1774,9 +1775,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.8" +version = "0.38.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ed4fa021d81c8392ce04db050a3da9a60299050b7ae1cf482d862b54a7218f" +checksum = "9bfe0f2582b4931a45d1fa608f8a8722e8b3c7ac54dd6d5f3b3212791fedef49" dependencies = [ "bitflags 2.4.0", "errno", @@ -1912,9 +1913,9 @@ checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" [[package]] name = "siphasher" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" [[package]] name = "smallvec" @@ -2039,20 +2040,20 @@ dependencies = [ "lazy_static", "libc", "nom 7.1.3", - "time 0.3.25", + "time 0.3.27", "winapi", ] [[package]] name = "tempfile" -version = "3.7.1" +version = "3.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix 0.38.8", + "rustix 0.38.9", "windows-sys 0.48.0", ] @@ -2109,9 +2110,9 @@ dependencies = [ [[package]] name = "thread-id" -version = "4.1.0" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ee93aa2b8331c0fec9091548843f2c90019571814057da3b783f9de09349d73" +checksum = "79474f573561cdc4871a0de34a51c92f7f5a56039113fbb5b9c9f96bdb756669" dependencies = [ "libc", "redox_syscall 0.2.16", @@ -2151,9 +2152,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.25" +version = "0.3.27" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea" +checksum = "0bb39ee79a6d8de55f48f2293a830e040392f1c5f16e336bdd1788cd0aadce07" dependencies = [ "deranged", "itoa", @@ -2170,9 +2171,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.11" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb71511c991639bb078fd5bf97757e03914361c48100d52878b8e52b46fb92cd" +checksum = "733d258752e9303d392b94b75230d07b0b9c489350c69b851fc6c065fde3e8f9" dependencies = [ "time-core", ] @@ -2391,7 +2392,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets 0.48.2", + "windows-targets 0.48.5", ] [[package]] @@ -2409,7 +2410,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.2", + "windows-targets 0.48.5", ] [[package]] @@ -2429,17 +2430,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1eeca1c172a285ee6c2c84c341ccea837e7c01b12fbb2d0fe3c9e550ce49ec8" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.48.2", - "windows_aarch64_msvc 0.48.2", - "windows_i686_gnu 0.48.2", - "windows_i686_msvc 0.48.2", - "windows_x86_64_gnu 0.48.2", - "windows_x86_64_gnullvm 0.48.2", - "windows_x86_64_msvc 0.48.2", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -2450,9 +2451,9 @@ checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b10d0c968ba7f6166195e13d593af609ec2e3d24f916f081690695cf5eaffb2f" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" @@ -2462,9 +2463,9 @@ checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_aarch64_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "571d8d4e62f26d4932099a9efe89660e8bd5087775a2ab5cdd8b747b811f1058" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" @@ -2474,9 +2475,9 @@ checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_gnu" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2229ad223e178db5fbbc8bd8d3835e51e566b8474bfca58d2e6150c48bb723cd" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" @@ -2486,9 +2487,9 @@ checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_i686_msvc" -version = "0.48.2" +version = 
"0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600956e2d840c194eedfc5d18f8242bc2e17c7775b6684488af3a9fff6fe3287" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" @@ -2498,9 +2499,9 @@ checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnu" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea99ff3f8b49fb7a8e0d305e5aec485bd068c2ba691b6e277d29eaeac945868a" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" @@ -2510,9 +2511,9 @@ checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1a05a1ece9a7a0d5a7ccf30ba2c33e3a61a30e042ffd247567d1de1d94120d" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" @@ -2522,9 +2523,9 @@ checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "windows_x86_64_msvc" -version = "0.48.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d419259aba16b663966e29e6d7c6ecfa0bb8425818bb96f6f1f3c3eb71a6e7b9" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "wrapcenum-derive" diff --git a/Cargo.toml b/Cargo.toml index cbbdf03..030ef88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,4 @@ rusqlite = { version = "0.29.0" } chrono = "0.4.26" chrono-tz = "0.8.3" itertools = "0.11.0" +rand = "0.8.5" diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs new file mode 100644 index 0000000..975f05a --- /dev/null +++ b/src/batch_shuffle.rs @@ -0,0 +1,104 @@ +use burn::data::dataset::Dataset; +use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng}; +use std::marker::PhantomData; + +pub struct BatchShuffledDataset { + dataset: D, + indices: Vec, + input: PhantomData, +} + +impl BatchShuffledDataset +where + D: Dataset, +{ + /// Creates a new shuffled dataset. + pub fn new(dataset: D, batch_size: usize, rng: &mut StdRng) -> Self { + let len = dataset.len(); + + // Calculate the number of batches + // 计算批数 + let num_batches = (len + batch_size - 1) / batch_size; + + // Create a vector of batch indices and shuffle it + // 创建一个批数索引的向量并打乱 + let mut batch_indices: Vec = (0..num_batches).collect(); + batch_indices.shuffle(rng); + + // Generate the corresponding item indices for each shuffled batch + // 为每个打乱的批次生成相应的元素索引 + let mut indices: Vec = Vec::new(); + for &batch_index in &batch_indices { + let start_index = batch_index * batch_size; + let end_index = std::cmp::min(start_index + batch_size, len); + indices.extend(start_index..end_index); + } + + Self { + dataset, + indices, + input: PhantomData, + } + } + + /// Creates a new shuffled dataset with a fixed seed. 
+    pub fn with_seed(dataset: D, batch_size: usize, seed: u64) -> Self {
+        let mut rng = StdRng::seed_from_u64(seed);
+        Self::new(dataset, batch_size, &mut rng)
+    }
+}
+
+impl<D, I> Dataset<I> for BatchShuffledDataset<D, I>
+where
+    D: Dataset<I>,
+    I: Clone + Send + Sync,
+{
+    fn get(&self, index: usize) -> Option<I> {
+        let index = match self.indices.get(index) {
+            Some(index) => index,
+            None => return None,
+        };
+        self.dataset.get(*index)
+    }
+
+    fn len(&self) -> usize {
+        self.dataset.len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn batch_shuffle() {
+        use crate::dataset::{FSRSDataset, FSRSItem};
+        let dataset = FSRSDataset::train();
+        let batch_size = 10;
+        let seed = 42;
+        let batch_shuffled_dataset: BatchShuffledDataset<FSRSDataset, FSRSItem> =
+            BatchShuffledDataset::with_seed(dataset, batch_size, seed);
+        for i in 0..batch_shuffled_dataset.len() {
+            println!("{:?}", batch_shuffled_dataset.get(i).unwrap());
+            if i > batch_size {
+                break;
+            }
+        }
+    }
+
+    #[test]
+    fn item_shuffle() {
+        use crate::dataset::{FSRSDataset, FSRSItem};
+        use burn::data::dataset::transform::ShuffledDataset;
+        let dataset = FSRSDataset::train();
+        let seed = 42;
+        let shuffled_dataset: ShuffledDataset<FSRSDataset, FSRSItem> =
+            ShuffledDataset::with_seed(dataset, seed);
+        for i in 0..shuffled_dataset.len() {
+            println!("{:?}", shuffled_dataset.get(i).unwrap());
+            if i > 10 {
+                break;
+            }
+        }
+    }
+}
diff --git a/src/convertor.rs b/src/convertor.rs
index b28ecb1..e9d62d2 100644
--- a/src/convertor.rs
+++ b/src/convertor.rs
@@ -1,12 +1,12 @@
 use chrono::prelude::*;
 use chrono_tz::Tz;
+use itertools::Itertools;
 use rusqlite::{Connection, Result, Row};
-use std::collections::HashMap;

 use crate::dataset::{FSRSItem, FSRSReview};

 #[derive(Clone, Debug, Default)]
-struct RevlogEntry {
+pub(crate) struct RevlogEntry {
     id: i64,
     cid: i64,
     button_chosen: i32,
@@ -63,25 +63,14 @@ fn read_collection() -> Result<Vec<RevlogEntry>> {
             WHERE queue != 0
             {suspended_cards_str}
             {flags_str}
-        )"
+        )
+        order by cid"
     ))?
     .query_and_then((current_timestamp, current_timestamp), row_to_revlog_entry)?
     .collect::<Result<Vec<_>>>()?;
     Ok(revlogs)
 }

-fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
-    let mut grouped = HashMap::new();
-    for revlog in revlogs {
-        grouped
-            .entry(revlog.cid)
-            .or_insert_with(Vec::new)
-            .push(revlog);
-    }
-
-    grouped.into_values().collect()
-}
-
 fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chrono::NaiveDate {
     let timestamp_seconds = timestamp - next_day_starts_at * 3600 * 1000; // subtract the specified number of hours
     let datetime = Utc
@@ -189,14 +178,23 @@ fn convert_to_fsrs_items(
     )
 }

-pub fn anki_to_fsrs() -> Vec<FSRSItem> {
-    let revlogs = read_collection().expect("read error");
-    let revlogs_per_card = group_by_cid(revlogs);
-    revlogs_per_card
+pub(crate) fn anki21_sample_file_converted_to_fsrs() -> Vec<FSRSItem> {
+    anki_to_fsrs(read_collection().expect("read error"))
+}
+
+/// Convert a series of revlog entries sorted by card id into FSRS items.
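+///
+/// Grouping by card relies on the `order by cid` added to the SQL query above:
+/// itertools' `group_by` only groups consecutive entries, so each card's
+/// entries must already be adjacent. The resulting items are then sorted by
+/// review count, so sequences of similar length land in the same batch.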
+pub(crate) fn anki_to_fsrs(revlogs: Vec<RevlogEntry>) -> Vec<FSRSItem> {
+    let mut revlogs = revlogs
         .into_iter()
-        .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
+        .group_by(|r| r.cid)
+        .into_iter()
+        .filter_map(|(_cid, entries)| {
+            convert_to_fsrs_items(entries.collect(), 4, Tz::Asia__Shanghai)
+        })
         .flatten()
-        .collect()
+        .collect_vec();
+    revlogs.sort_by_cached_key(|r| r.reviews.len());
+    revlogs
 }

 #[cfg(test)]
@@ -210,7 +208,7 @@ mod tests {
     // This test currently expects the following .anki21 file to be placed in tests/data/:
     // https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
     #[test]
-    fn test() {
+    fn conversion_works() {
         let revlogs = read_collection().unwrap();
         let single_card_revlog = vec![revlogs
             .iter()
@@ -218,14 +216,7 @@ mod tests {
             .cloned()
             .collect_vec()];
         assert_eq!(revlogs.len(), 24394);
-        let revlogs_per_card = group_by_cid(revlogs);
-        assert_eq!(revlogs_per_card.len(), 3324);
-        let fsrs_items = revlogs_per_card
-            .into_iter()
-            .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
-            .flatten()
-            .collect_vec();
-
+        let fsrs_items = anki_to_fsrs(revlogs);
         assert_eq!(fsrs_items.len(), 14290);
         assert_eq!(
             fsrs_items.iter().map(|x| x.reviews.len()).sum::<usize>(),
@@ -361,4 +352,24 @@ mod tests {
         );
         assert_eq!(res.labels.to_data(), Data::from([1]));
     }
+
+    #[test]
+    fn ordering_of_inputs_should_not_change() {
+        let revlogs = anki21_sample_file_converted_to_fsrs();
+        assert_eq!(
+            revlogs[0],
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 3
+                    }
+                ]
+            }
+        );
+    }
 }
diff --git a/src/cosine_annealing.rs b/src/cosine_annealing.rs
index c130d4a..ffb9ff6 100644
--- a/src/cosine_annealing.rs
+++ b/src/cosine_annealing.rs
@@ -62,31 +62,37 @@ impl LRScheduler for CosineAnnealingLR {
     }
 }

-#[test]
-fn test_lr_scheduler() {
-    let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);
-    let mut lrs = vec![];
-    for i in 0..200000 {
-        if i % 20000 == 0 {
-            lrs.push(lr_scheduler.current_lr);
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn lr_scheduler() {
+        let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);
+        let mut lrs = vec![];
+        for i in 0..200000 {
+            if i % 20000 == 0 {
+                lrs.push(lr_scheduler.current_lr);
+            }
+            lr_scheduler.step();
         }
-        lr_scheduler.step();
+        lrs.push(lr_scheduler.current_lr);
+        assert!(lrs
+            .iter()
+            .zip([
+                0.1,
+                0.09045084971874785,
+                0.06545084971874875,
+                0.034549150281253875,
+                0.009549150281252989,
+                0.0,
+                0.009549150281252692,
+                0.03454915028125239,
+                0.06545084971874746,
+                0.09045084971874952,
+                0.10000000000000353
+            ])
+            // Comparing with f64::EPSILON fails: the computed values seem to
+            // differ slightly between Linux and macOS, so use a looser tolerance.
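+            // f32::EPSILON as f64 is about 1.2e-7 (f64::EPSILON is ~2.2e-16),
+            // loose enough to absorb the platform difference yet still catch errors.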
+            .all(|(x, y)| (x - y).abs() < f32::EPSILON as f64));
     }
-    lrs.push(lr_scheduler.current_lr);
-    assert_eq!(
-        lrs,
-        vec![
-            0.1,
-            0.09045084971874785,
-            0.06545084971874875,
-            0.034549150281253875,
-            0.009549150281252989,
-            0.0,
-            0.009549150281252692,
-            0.03454915028125239,
-            0.06545084971874746,
-            0.09045084971874952,
-            0.10000000000000353
-        ]
-    )
 }
diff --git a/src/dataset.rs b/src/dataset.rs
index 8d90368..6d6004f 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -1,12 +1,11 @@
+use crate::convertor::anki21_sample_file_converted_to_fsrs;
 use burn::data::dataloader::batcher::Batcher;
 use burn::{
-    data::dataset::{Dataset, InMemDataset},
+    data::dataset::Dataset,
     tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor},
 };
 use serde::{Deserialize, Serialize};

-use crate::convertor::anki_to_fsrs;
-
 /// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds
 /// to a single review, but contains the previous reviews of the card as well, after the
 /// first one.
@@ -117,229 +116,238 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
 }

 pub struct FSRSDataset {
-    dataset: InMemDataset<FSRSItem>,
+    items: Vec<FSRSItem>,
 }

 impl Dataset<FSRSItem> for FSRSDataset {
     fn len(&self) -> usize {
-        self.dataset.len()
+        self.items.len()
     }

     fn get(&self, index: usize) -> Option<FSRSItem> {
-        self.dataset.get(index)
+        self.items.get(index).cloned()
     }
 }

 impl FSRSDataset {
     pub fn train() -> Self {
-        Self::new()
+        Self::new_from_test_file()
     }

     pub fn test() -> Self {
-        Self::new()
+        Self::new_from_test_file()
     }

     pub fn len(&self) -> usize {
-        self.dataset.len()
+        self.items.len()
     }

     pub fn is_empty(&self) -> bool {
-        self.dataset.is_empty()
+        self.items.is_empty()
     }

-    fn new() -> Self {
-        let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
-        Self { dataset }
+    fn new_from_test_file() -> Self {
+        anki21_sample_file_converted_to_fsrs().into()
     }
 }

-#[test]
-fn test_from_anki() {
-    use burn::data::dataloader::Dataset;
-    use burn::data::dataset::InMemDataset;
-
-    let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
-    dbg!(dataset.get(704).unwrap());
-
-    use burn_ndarray::NdArrayDevice;
-    let device = NdArrayDevice::Cpu;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let batcher = FSRSBatcher::<Backend>::new(device);
-    use burn::data::dataloader::DataLoaderBuilder;
-    let dataloader = DataLoaderBuilder::new(batcher)
-        .batch_size(1)
-        .shuffle(42)
-        .num_workers(4)
-        .build(dataset);
-    dbg!(
-        dataloader
-            .iter()
-            .next()
-            .expect("loader is empty")
-            .r_historys
-    );
-}
-
-#[test]
-fn test_batcher() {
-    use burn_ndarray::NdArrayBackend;
-    use burn_ndarray::NdArrayDevice;
-    type Backend = NdArrayBackend<f32>;
-    let device = NdArrayDevice::Cpu;
-    let batcher = FSRSBatcher::<Backend>::new(device);
-    let items = vec![
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 4,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 5,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 4,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 5,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 11,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 4,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 2,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 4,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 2,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 6,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 4,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 2,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 6,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 16,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 4,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 2,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 6,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 16,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 39,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 1,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 1,
-                    delta_t: 1,
-                },
-            ],
-        },
-        FSRSItem {
-            reviews: vec![
-                FSRSReview {
-                    rating: 1,
-                    delta_t: 0,
-                },
-                FSRSReview {
-                    rating: 1,
-                    delta_t: 1,
-                },
-                FSRSReview {
-                    rating: 3,
-                    delta_t: 1,
-                },
-            ],
-        },
-    ];
-    let batch = batcher.batch(items);
-    assert_eq!(
-        batch.t_historys.to_data(),
-        Data::from([
-            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
-            [0.0, 5.0, 0.0, 2.0, 2.0, 2.0, 0.0, 1.0],
-            [0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 0.0, 0.0],
-            [0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0]
-        ])
-    );
-    assert_eq!(
-        batch.r_historys.to_data(),
-        Data::from([
-            [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0],
-            [0.0, 3.0, 0.0, 3.0, 3.0, 3.0, 0.0, 1.0],
-            [0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 0.0, 0.0],
-            [0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0]
-        ])
-    );
-    assert_eq!(
-        batch.delta_ts.to_data(),
-        Data::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0])
-    );
-    assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
-}
+impl From<Vec<FSRSItem>> for FSRSDataset {
+    fn from(items: Vec<FSRSItem>) -> Self {
+        Self { items }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn from_anki() {
+        use burn::data::dataloader::Dataset;
+
+        let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs());
+        dbg!(dataset.get(704).unwrap());
+
+        use burn_ndarray::NdArrayDevice;
+        let device = NdArrayDevice::Cpu;
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let batcher = FSRSBatcher::<Backend>::new(device);
+        use burn::data::dataloader::DataLoaderBuilder;
+        let dataloader = DataLoaderBuilder::new(batcher)
+            .batch_size(1)
+            .shuffle(42)
+            .num_workers(4)
+            .build(dataset);
+        dbg!(
+            dataloader
+                .iter()
+                .next()
+                .expect("loader is empty")
+                .r_historys
+        );
+    }
+
+    #[test]
+    fn batcher() {
+        use burn_ndarray::NdArrayBackend;
+        use burn_ndarray::NdArrayDevice;
+        type Backend = NdArrayBackend<f32>;
+        let device = NdArrayDevice::Cpu;
+        let batcher = FSRSBatcher::<Backend>::new(device);
+        let items = vec![
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 5,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 5,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 11,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 2,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 2,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 6,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 2,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 6,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 16,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 4,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 2,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 6,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 16,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 39,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 1,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 1,
+                        delta_t: 1,
+                    },
+                ],
+            },
+            FSRSItem {
+                reviews: vec![
+                    FSRSReview {
+                        rating: 1,
+                        delta_t: 0,
+                    },
+                    FSRSReview {
+                        rating: 1,
+                        delta_t: 1,
+                    },
+                    FSRSReview {
+                        rating: 3,
+                        delta_t: 1,
+                    },
+                ],
+            },
+        ];
+        let batch = batcher.batch(items);
+        assert_eq!(
+            batch.t_historys.to_data(),
+            Data::from([
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 5.0, 0.0, 2.0, 2.0, 2.0, 0.0, 1.0],
+                [0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0]
+            ])
+        );
+        assert_eq!(
+            batch.r_historys.to_data(),
+            Data::from([
+                [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0],
+                [0.0, 3.0, 0.0, 3.0, 3.0, 3.0, 0.0, 1.0],
+                [0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0]
+            ])
+        );
+        assert_eq!(
+            batch.delta_ts.to_data(),
+            Data::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0])
+        );
+        assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
+    }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 5ac6c27..52e9574 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,3 +6,5 @@ pub mod dataset;
 pub mod model;
 pub mod training;
 mod weight_clipper;
+
+mod batch_shuffle;
diff --git a/src/model.rs b/src/model.rs
index 89888ea..c37787b 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -154,156 +154,6 @@ impl ModelConfig {
     }
 }

-#[test]
-fn test_w() {
-    use burn::tensor::Data;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    assert_eq!(
-        model.w().to_data(),
-        Data::from([
-            [0.4],
-            [0.6],
-            [2.4],
-            [5.8],
-            [4.93],
-            [0.94],
-            [0.86],
-            [0.01],
-            [1.49],
-            [0.14],
-            [0.94],
-            [2.18],
-            [0.05],
-            [0.34],
-            [1.26],
-            [0.29],
-            [2.61]
-        ])
-    )
-}
-
-#[test]
-fn test_power_forgetting_curve() {
-    use burn::tensor::Data;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    let delta_t = Tensor::<Backend, 2>::from_floats([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]);
-    let stability = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0], [4.0], [2.0]]);
-    let retention = model.power_forgetting_curve(delta_t, stability);
-    assert_eq!(
-        retention.to_data(),
-        Data::from([
-            [1.0],
-            [0.9473684],
-            [0.9310345],
-            [0.92307687],
-            [0.9],
-            [0.7826087]
-        ])
-    )
-}
-
-#[test]
-fn test_init_stability() {
-    use burn::tensor::Data;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0], [1.0], [2.0]]);
-    let stability = model.init_stability(rating);
-    assert_eq!(
-        stability.to_data(),
-        Data::from([[0.4], [0.6], [2.4], [5.8], [0.4], [0.6]])
-    )
-}
-
-#[test]
-fn test_init_difficulty() {
-    use burn::tensor::Data;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0], [1.0], [2.0]]);
-    let difficulty = model.init_difficulty(rating);
-    assert_eq!(
-        difficulty.to_data(),
-        Data::from([[6.81], [5.87], [4.93], [3.9899998], [6.81], [5.87]])
-    )
-}
-
-#[test]
-fn test_next_difficulty() {
-    use burn::tensor::Data;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    let difficulty = Tensor::<Backend, 2>::from_floats([[5.0], [5.0], [5.0], [5.0]]);
-    let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]]);
-    let next_difficulty = model.next_difficulty(difficulty, rating);
-    assert_eq!(
-        next_difficulty.to_data(),
-        Data::from([[6.7200003], [5.86], [5.0], [4.14]])
-    );
-    let next_difficulty = model.mean_reversion(next_difficulty);
-    assert_eq!(
-        next_difficulty.to_data(),
-        Data::from([[6.7021003], [5.8507], [4.9993], [4.1478996]])
-    )
-}
-
-#[test]
-fn test_next_stability() {
-    use burn::tensor::Data;
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    let stability = Tensor::<Backend, 2>::from_floats([[5.0], [5.0], [5.0], [5.0]]);
-    let difficulty = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]]);
-    let retention = Tensor::<Backend, 2>::from_floats([[0.9], [0.8], [0.7], [0.6]]);
-    let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]]);
-    let s_recall = model.stability_after_success(
-        stability.clone(),
-        difficulty.clone(),
-        retention.clone(),
-        rating.clone(),
-    );
-    assert_eq!(
-        s_recall.to_data(),
-        Data::from([[22.454704], [14.560361], [51.15574], [152.6869]])
-    );
-    let s_forget = model.stability_after_failure(stability, difficulty, retention);
-    assert_eq!(
-        s_forget.to_data(),
-        Data::from([[2.074517], [2.2729328], [2.526406], [2.8247323]])
-    );
-    let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
-    assert_eq!(
-        next_stability.to_data(),
-        Data::from([[2.074517], [14.560361], [51.15574], [152.6869]])
-    )
-}
-
-#[test]
-fn test_forward() {
-    use burn_ndarray::NdArrayBackend;
-    type Backend = NdArrayBackend<f32>;
-    let model = Model::<Backend>::new();
-    let delta_ts = Tensor::<Backend, 2>::from_floats([
-        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
-        [1.0, 1.0, 1.0, 1.0, 2.0, 2.0],
-    ]);
-    let ratings = Tensor::<Backend, 2>::from_floats([
-        [1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
-        [1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
-    ]);
-    let (stability, difficulty) = model.forward(delta_ts, ratings);
-    dbg!(&stability);
-    dbg!(&difficulty);
-}
-
 #[cfg(test)]
 mod tests {
     use burn::tensor::Data;
@@ -313,7 +163,106 @@ mod tests {
     use super::*;

     #[test]
-    fn test_next_difficulty() {
+    fn w() {
+        use burn::tensor::Data;
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let model = Model::<Backend>::new();
+        assert_eq!(
+            model.w().to_data(),
+            Data::from([
+                [0.4],
+                [0.6],
+                [2.4],
+                [5.8],
+                [4.93],
+                [0.94],
+                [0.86],
+                [0.01],
+                [1.49],
+                [0.14],
+                [0.94],
+                [2.18],
+                [0.05],
+                [0.34],
+                [1.26],
+                [0.29],
+                [2.61]
+            ])
+        )
+    }
+
+    #[test]
+    fn power_forgetting_curve() {
+        use burn::tensor::Data;
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let model = Model::<Backend>::new();
+        let delta_t = Tensor::<Backend, 2>::from_floats([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]);
+        let stability =
+            Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0], [4.0], [2.0]]);
+        let retention = model.power_forgetting_curve(delta_t, stability);
+        assert_eq!(
+            retention.to_data(),
+            Data::from([
+                [1.0],
+                [0.9473684],
+                [0.9310345],
+                [0.92307687],
+                [0.9],
+                [0.7826087]
+            ])
+        )
+    }
+
+    #[test]
+    fn init_stability() {
+        use burn::tensor::Data;
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let model = Model::<Backend>::new();
+        let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0], [1.0], [2.0]]);
+        let stability = model.init_stability(rating);
+        assert_eq!(
+            stability.to_data(),
+            Data::from([[0.4], [0.6], [2.4], [5.8], [0.4], [0.6]])
+        )
+    }
+
+    #[test]
+    fn init_difficulty() {
+        use burn::tensor::Data;
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let model = Model::<Backend>::new();
+        let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0], [1.0], [2.0]]);
+        let difficulty = model.init_difficulty(rating);
+        assert_eq!(
+            difficulty.to_data(),
+            Data::from([[6.81], [5.87], [4.93], [3.9899998], [6.81], [5.87]])
+        )
+    }
+
+    #[test]
+    fn forward() {
+        use burn_ndarray::NdArrayBackend;
+        type Backend = NdArrayBackend<f32>;
+        let model = Model::<Backend>::new();
+        let delta_ts = Tensor::<Backend, 2>::from_floats([
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [1.0, 1.0, 1.0, 1.0, 2.0, 2.0],
+        ]);
+        let ratings = Tensor::<Backend, 2>::from_floats([
+            [1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
+            [1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
+        ]);
+        let (stability, difficulty) = model.forward(delta_ts, ratings);
+        dbg!(&stability);
+        dbg!(&difficulty);
+    }
+
+    #[test]
+    fn next_difficulty() {
         let model = Model::<Backend>::new();
         let difficulty = Tensor::<Backend, 2>::from_floats([[5.0], [5.0], [5.0], [5.0]]);
         let rating = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]]);
@@ -332,7 +281,7 @@ mod tests {
     }

     #[test]
-    fn test_next_stability() {
+    fn next_stability() {
         let model = Model::<Backend>::new();
         let stability = Tensor::<Backend, 2>::from_floats([[5.0], [5.0], [5.0], [5.0]]);
         let difficulty = Tensor::<Backend, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]]);
diff --git a/src/training.rs b/src/training.rs
index efc6e56..3764f5f 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -1,5 +1,4 @@
-use std::path::Path;
-
+use crate::batch_shuffle::BatchShuffledDataset;
 use crate::cosine_annealing::CosineAnnealingLR;
 use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
@@ -15,6 +14,7 @@ use burn::{
     train::LearnerBuilder,
 };
 use log::info;
+use std::path::Path;

 impl<B: Backend<FloatElem = f32>> Model<B> {
     fn bceloss(&self, retentions: Tensor<B, 1>, labels: Tensor<B, 1>) -> Tensor<B, 1> {
@@ -38,8 +38,7 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
             delta_ts.clone().unsqueeze::<2>().transpose(),
             stability.clone(),
         );
-        let logits =
-            Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).reshape([2, -1]);
+        let logits = Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 1);
         info!("stability: {}", &stability);
         info!(
             "delta_ts: {}",
@@ -125,14 +124,14 @@ pub fn train<B: ADBackend<FloatElem = f32>>(

     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
-        .shuffle(config.seed)
-        .num_workers(config.num_workers)
-        .build(FSRSDataset::train());
+        .build(BatchShuffledDataset::with_seed(
+            FSRSDataset::train(),
+            config.batch_size,
+            config.seed,
+        ));

     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
-        .shuffle(config.seed)
-        .num_workers(config.num_workers)
         .build(FSRSDataset::test());

     let lr_scheduler = CosineAnnealingLR::init(
@@ -176,18 +175,27 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
         .expect("Failed to save trained model");
 }

-#[test]
-fn test() {
-    use burn_ndarray::NdArrayBackend;
-    use burn_ndarray::NdArrayDevice;
-    type Backend = NdArrayBackend<f32>;
-    type AutodiffBackend = burn_autodiff::ADBackendDecorator<Backend>;
-    let device = NdArrayDevice::Cpu;
-
-    let artifact_dir = ARTIFACT_DIR;
-    train::<AutodiffBackend>(
-        artifact_dir,
-        TrainingConfig::new(ModelConfig::new(), AdamConfig::new()),
-        device,
-    );
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn training() {
+        if std::env::var("SKIP_TRAINING").is_ok() {
+            println!("Skipping test in CI");
+            return;
+        }
+        use burn_ndarray::NdArrayBackend;
+        use burn_ndarray::NdArrayDevice;
+        type Backend = NdArrayBackend<f32>;
+        type AutodiffBackend = burn_autodiff::ADBackendDecorator<Backend>;
+        let device = NdArrayDevice::Cpu;
+
+        let artifact_dir = ARTIFACT_DIR;
+        train::<AutodiffBackend>(
+            artifact_dir,
+            TrainingConfig::new(ModelConfig::new(), AdamConfig::new()),
+            device,
+        );
+    }
 }
diff --git a/src/weight_clipper.rs b/src/weight_clipper.rs
index 6555d68..e71d9f2 100644
--- a/src/weight_clipper.rs
+++ b/src/weight_clipper.rs
@@ -28,21 +28,26 @@ pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights: Tensor<B, 1>) -> Tensor<B, 1> {
     Tensor::from_data(Data::new(val.clone(), weights.shape()))
 }

-#[test]
-fn weight_clipper_test() {
-    type Backend = burn_ndarray::NdArrayBackend<f32>;
-    //type AutodiffBackend = burn_autodiff::ADBackendDecorator<Backend>;
-
-    let tensor = Tensor::from_floats([
-        0.0, -1000.0, 1000.0, 0.0, // Ignored
-        1000.0, -1000.0, 1.0, 0.25, -0.1,
-    ]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5),
-
-    let param: Tensor<Backend, 1> = weight_clipper(tensor);
-    let values = &param.to_data().value;
-
-    assert_eq!(
-        *values,
-        vec![0.0, -1000.0, 1000.0, 0.0, 10.0, 0.1, 1.0, 0.25, 0.0]
-    );
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn weight_clipper_works() {
+        type Backend = burn_ndarray::NdArrayBackend<f32>;
+        //type AutodiffBackend = burn_autodiff::ADBackendDecorator<Backend>;
+
+        let tensor = Tensor::from_floats([
+            0.0, -1000.0, 1000.0, 0.0, // Ignored
+            1000.0, -1000.0, 1.0, 0.25, -0.1,
+        ]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5),
+
+        let param: Tensor<Backend, 1> = weight_clipper(tensor);
+        let values = &param.to_data().value;
+
+        assert_eq!(
+            *values,
+            vec![0.0, -1000.0, 1000.0, 0.0, 10.0, 0.1, 1.0, 0.25, 0.0]
+        );
+    }
 }
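
The core idea of this patch, batch-level shuffling over length-sorted items, can be exercised outside the crate. Below is a minimal standalone sketch of the same index arithmetic as BatchShuffledDataset::new (plain rand 0.8, with a Vec standing in for the dataset; this example is not part of the patch):

    use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};

    fn main() {
        let (len, batch_size) = (10usize, 4usize);
        // Same scheme as BatchShuffledDataset::new: shuffle whole batches, not items.
        let num_batches = (len + batch_size - 1) / batch_size;
        let mut batch_indices: Vec<usize> = (0..num_batches).collect();
        batch_indices.shuffle(&mut StdRng::seed_from_u64(42));
        let indices: Vec<usize> = batch_indices
            .iter()
            .flat_map(|&b| (b * batch_size)..std::cmp::min((b + 1) * batch_size, len))
            .collect();
        // Every index appears exactly once; only the batch order is randomized.
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..len).collect::<Vec<_>>());
        println!("{indices:?}");
    }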