Skip to content

Commit

Permalink
Move current review into the list of reviews, instead of storing sepa…
Browse files Browse the repository at this point in the history
…rately
  • Loading branch information
dae committed Aug 22, 2023
1 parent d327628 commit a423fe1
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 63 deletions.
52 changes: 30 additions & 22 deletions src/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,16 @@ fn convert_to_fsrs_items(
.iter()
.enumerate()
.skip(1)
.map(|(idx, entry)| {
.map(|(idx, _)| {
let reviews = entries
.iter()
.take(idx)
.take(idx + 1)
.map(|r| FSRSReview {
rating: r.button_chosen,
delta_t: r.delta_t,
})
.collect();
FSRSItem {
reviews,
delta_t: entry.delta_t,
rating: entry.button_chosen,
}
FSRSItem { reviews }
})
.collect(),
)
Expand Down Expand Up @@ -241,7 +237,7 @@ mod tests {
assert_eq!(fsrs_items.len(), 14290);
assert_eq!(
fsrs_items.iter().map(|x| x.reviews.len()).sum::<usize>(),
49382
49382 + 14290
);

// convert a subset and check it matches expectations
Expand All @@ -254,12 +250,16 @@ mod tests {
&fsrs_items,
&[
FSRSItem {
reviews: vec![FSRSReview {
rating: 3,
delta_t: 0
}],
delta_t: 5,
rating: 3
reviews: vec![
FSRSReview {
rating: 3,
delta_t: 0
},
FSRSReview {
rating: 3,
delta_t: 5
}
],
},
FSRSItem {
reviews: vec![
Expand All @@ -270,10 +270,12 @@ mod tests {
FSRSReview {
rating: 3,
delta_t: 5
},
FSRSReview {
rating: 3,
delta_t: 10
}
],
delta_t: 10,
rating: 3
},
FSRSItem {
reviews: vec![
Expand All @@ -288,10 +290,12 @@ mod tests {
FSRSReview {
rating: 3,
delta_t: 10
},
FSRSReview {
rating: 3,
delta_t: 22
}
],
delta_t: 22,
rating: 3
},
FSRSItem {
reviews: vec![
Expand All @@ -310,10 +314,12 @@ mod tests {
FSRSReview {
rating: 3,
delta_t: 22
},
FSRSReview {
rating: 2,
delta_t: 56
}
],
delta_t: 56,
rating: 2
},
FSRSItem {
reviews: vec![
Expand All @@ -336,10 +342,12 @@ mod tests {
FSRSReview {
rating: 2,
delta_t: 56
},
FSRSReview {
rating: 3,
delta_t: 64
}
],
delta_t: 64,
rating: 3
}
]
);
Expand Down
78 changes: 37 additions & 41 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@ use serde::{Deserialize, Serialize};

use crate::convertor::anki_to_fsrs;

/// Represents a single review on a card, and contains the previous reviews for that card.
/// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds
/// to a single review, but contains the previous reviews of the card as well, after the
/// first one.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct FSRSItem {
/// The previous reviews done to the card before the current one
pub reviews: Vec<FSRSReview>,
/// 1-4
pub rating: i32,
/// The number of days that passed
pub delta_t: i32,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
Expand All @@ -26,6 +23,17 @@ pub struct FSRSReview {
pub delta_t: i32,
}

impl FSRSItem {
// The previous reviews done before the current one.
pub(crate) fn history(&self) -> impl Iterator<Item = &FSRSReview> {
self.reviews.iter().take(self.reviews.len() - 1)
}

pub(crate) fn current(&self) -> &FSRSReview {
self.reviews.last().unwrap()
}
}

pub struct FSRSBatcher<B: Backend> {
device: B::Device,
}
Expand All @@ -46,54 +54,42 @@ pub struct FSRSBatch<B: Backend> {

impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
fn batch(&self, items: Vec<FSRSItem>) -> FSRSBatch<B> {
let t_historys = items
.iter()
.map(|item| {
Data::new(
item.reviews.iter().map(|r| r.delta_t).collect(),
Shape {
dims: [item.reviews.len()],
},
)
})
.map(|data| Tensor::<B, 1>::from_data(data.convert()))
.map(|tensor| tensor.unsqueeze())
.collect();

let r_historys = items
let (time_histories, rating_histories) = items
.iter()
.map(|item| {
Data::new(
item.reviews.iter().map(|r| r.rating).collect(),
Shape {
dims: [item.reviews.len()],
},
let (delta_t, rating): (Vec<_>, _) =
item.history().map(|r| (r.delta_t, r.rating)).unzip();
let count = delta_t.len();
let delta_t = Tensor::<B, 1>::from_data(
Data::new(delta_t, Shape { dims: [count] }).convert(),
)
.unsqueeze();
let rating =
Tensor::<B, 1>::from_data(Data::new(rating, Shape { dims: [count] }).convert())
.unsqueeze();
(delta_t, rating)
})
.map(|data| Tensor::<B, 1>::from_data(data.convert()))
.map(|tensor| tensor.unsqueeze())
.collect();

let delta_ts = items
.iter()
.map(|item| Tensor::<B, 1, Float>::from_data(Data::from([item.delta_t.elem()])))
.collect();
.unzip();

let labels = items
let (delta_ts, labels) = items
.iter()
.map(|item| {
Tensor::<B, 1, Int>::from_data(Data::from([match item.rating {
let current = item.current();
let delta_t = current.delta_t;
let delta_t = Tensor::<B, 1, Float>::from_data(Data::from([delta_t.elem()]));
let label = match current.rating {
1 => 0.0,
_ => 1.0,
}
.elem()]))
};
let label = Tensor::<B, 1, Int>::from_data(Data::from([label.elem()]));
(delta_t, label)
})
.collect();
.unzip();

let t_historys = Tensor::cat(t_historys, 0)
let t_historys = Tensor::cat(time_histories, 0)
.transpose()
.to_device(&self.device); // [seq_len, batch_size]
let r_historys = Tensor::cat(r_historys, 0)
let r_historys = Tensor::cat(rating_histories, 0)
.transpose()
.to_device(&self.device); // [seq_len, batch_size]
let delta_ts = Tensor::cat(delta_ts, 0).to_device(&self.device);
Expand Down

0 comments on commit a423fe1

Please sign in to comment.