Create batches with random seq_len
dae committed Aug 26, 2023
1 parent 5161a92 commit d99d925
Showing 5 changed files with 86 additions and 59 deletions.
105 changes: 53 additions & 52 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -15,3 +15,4 @@ rusqlite = { version = "0.29.0" }
 chrono = "0.4.26"
 chrono-tz = "0.8.3"
 itertools = "0.11.0"
+rand = "0.8.5"
36 changes: 32 additions & 4 deletions src/convertor.rs
@@ -1,6 +1,7 @@
 use chrono::prelude::*;
 use chrono_tz::Tz;
 use itertools::Itertools;
+use rand::Rng;
 use rusqlite::{Connection, Result, Row};
 use std::collections::HashMap;

@@ -193,13 +194,40 @@ fn convert_to_fsrs_items(
 pub fn anki_to_fsrs() -> Vec<FSRSItem> {
     let revlogs = read_collection().expect("read error");
     let revlogs_per_card = group_by_cid(revlogs);
-    let mut revlogs = revlogs_per_card
+    // collect FSRS items into a map by sequence size
+    let mut total_fsrs_items = 0;
+    let mut revlogs_by_seq_size: HashMap<usize, Vec<FSRSItem>> = HashMap::new();
+    revlogs_per_card
         .into_iter()
         .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
         .flatten()
-        .collect_vec();
-    revlogs.sort_by_key(|r| r.reviews.len());
-    revlogs
+        .for_each(|r| {
+            total_fsrs_items += 1;
+            revlogs_by_seq_size
+                .entry(r.reviews.len())
+                .or_default()
+                .push(r)
+        });
+    let mut sizes = revlogs_by_seq_size.keys().copied().collect_vec();
+    let mut rng = rand::thread_rng();
+    let mut out: Vec<FSRSItem> = Vec::with_capacity(total_fsrs_items);
+    while !sizes.is_empty() {
+        // pick a random sequence size
+        let size_idx = rng.gen_range(0..sizes.len() as u32) as usize;
+        let size = &mut sizes[size_idx];
+        let items = revlogs_by_seq_size.get_mut(size).unwrap();
+        // add up to 512 items from it to the output vector
+        for _ in 0..512 {
+            let Some(item) = items.pop() else {
+                // this size has run out of items; clear it from available sizes
+                sizes.swap_remove(size_idx);
+                break;
+            };
+            out.push(item);
+        }
+    }
+
+    out
 }

 #[cfg(test)]
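As a reading aid (not part of the commit), here is a minimal standalone sketch of the batching idea that the new anki_to_fsrs implements: items are bucketed by sequence length, then chunks are drained from a randomly chosen bucket so every chunk shares a single length. Plain Vec<u32> stands in for FSRSItem, and batch_by_random_len / batch_size are names invented for this illustration; the commit itself hard-codes a chunk size of 512.

use rand::Rng;
use std::collections::HashMap;

// Group items by length, then emit them in chunks of up to `batch_size`,
// each chunk drawn from a single randomly chosen length bucket.
fn batch_by_random_len(items: Vec<Vec<u32>>, batch_size: usize) -> Vec<Vec<u32>> {
    let total = items.len();
    let mut by_len: HashMap<usize, Vec<Vec<u32>>> = HashMap::new();
    for item in items {
        by_len.entry(item.len()).or_default().push(item);
    }
    let mut lens: Vec<usize> = by_len.keys().copied().collect();
    let mut rng = rand::thread_rng();
    let mut out = Vec::with_capacity(total);
    while !lens.is_empty() {
        // pick a random length bucket
        let idx = rng.gen_range(0..lens.len());
        let len = lens[idx];
        let bucket = by_len.get_mut(&len).unwrap();
        // drain up to batch_size items from that bucket
        for _ in 0..batch_size {
            match bucket.pop() {
                Some(item) => out.push(item),
                None => {
                    // bucket exhausted; stop offering this length
                    lens.swap_remove(idx);
                    break;
                }
            }
        }
    }
    out
}

fn main() {
    let items = vec![vec![1], vec![2, 3], vec![4, 5], vec![6, 7, 8], vec![9]];
    let ordered = batch_by_random_len(items, 2);
    // consecutive pairs now share a length, in a randomized bucket order
    println!("{ordered:?}");
}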
1 change: 0 additions & 1 deletion src/dataset.rs
@@ -169,7 +169,6 @@ fn test_from_anki() {
     use burn::data::dataloader::DataLoaderBuilder;
     let dataloader = DataLoaderBuilder::new(batcher)
         .batch_size(1)
-        .shuffle(42)
         .num_workers(4)
         .build(dataset);
     dbg!(
2 changes: 0 additions & 2 deletions src/training.rs
@@ -117,13 +117,11 @@ pub fn train<B: ADBackend<FloatElem = f32>>(

     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
-        // .shuffle(config.seed)
         .num_workers(config.num_workers)
         .build(FSRSDataset::train());

     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
-        // .shuffle(config.seed)
         .num_workers(config.num_workers)
         .build(FSRSDataset::test());

