Add BCELoss
L-M-Sherlock committed Aug 22, 2023
1 parent d0b0239 commit 4ee0bd0
Showing 1 changed file with 12 additions and 8 deletions.

src/training.rs
@@ -2,7 +2,6 @@ use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
 use burn::module::Module;
-use burn::nn::loss::CrossEntropyLoss;
 use burn::optim::AdamConfig;
 use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder, Recorder};
 use burn::tensor::backend::Backend;
@@ -15,26 +14,31 @@ use burn::{
 use log::info;
 
 impl<B: Backend<FloatElem = f32>> Model<B> {
+    fn bceloss(&self, retentions: Tensor<B, 1>, labels: Tensor<B, 1>) -> Tensor<B, 1> {
+        let loss: Tensor<B, 1> =
+            labels.clone() * retentions.clone().log() + (-labels + 1) * (-retentions + 1).log();
+        loss.mean().neg()
+    }
+
     pub fn forward_classification(
         &self,
         t_historys: Tensor<B, 2>,
         r_historys: Tensor<B, 2>,
         delta_ts: Tensor<B, 1>,
         labels: Tensor<B, 1, Int>,
     ) -> ClassificationOutput<B> {
-        // dbg!(&t_historys);
-        // dbg!(&r_historys);
+        // info!("t_historys: {}", &t_historys);
+        // info!("r_historys: {}", &r_historys);
         let (stability, _difficulty) = self.forward(t_historys, r_historys);
         let retention = self.power_forgetting_curve(delta_ts.clone(), stability.clone());
-        // dbg!(&retention);
         let logits =
-            Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]);
+            Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).reshape([1, -1]);
         info!("stability: {}", &stability);
         info!("delta_ts: {}", &delta_ts);
         info!("retention: {}", &retention);
         info!("logits: {}", &logits);
         info!("labels: {}", &labels);
-        let loss = CrossEntropyLoss::new(None).forward(logits.clone(), labels.clone());
+        let loss = self.bceloss(retention.clone(), labels.clone().float());
         ClassificationOutput::new(loss, logits, labels)
     }
 }
@@ -124,7 +128,9 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
     );
 
     let mut model_trained = learner.fit(dataloader_train, dataloader_test);
+    info!("trained weights: {}", &model_trained.w.val());
     model_trained.w = Param::from(weight_clipper(model_trained.w.val()));
+    info!("clipped weights: {}", &model_trained.w.val());
 
     config
         .save(format!("{ARTIFACT_DIR}/config.json").as_str())
@@ -136,8 +142,6 @@
             format!("{ARTIFACT_DIR}/model").into(),
         )
         .expect("Failed to save trained model");
-
-    info!("trained weights: {}", &model_trained.w.val());
 }
 
 #[test]
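For context, the binary cross-entropy that the new bceloss method computes is BCE = -mean(y * ln(p) + (1 - y) * ln(1 - p)), where p is the predicted retention and y is the 0/1 recall label (the diff converts the integer labels with .float()). A minimal, dependency-free sketch of the same arithmetic, using plain f32 slices in place of burn tensors; the function and variable names here are illustrative, not part of the crate:

fn bce(p: &[f32], y: &[f32]) -> f32 {
    // Mean of y*ln(p) + (1 - y)*ln(1 - p), negated: the same terms bceloss
    // builds with tensor ops and then reduces via .mean().neg().
    let sum: f32 = p
        .iter()
        .zip(y)
        .map(|(&p, &y)| y * p.ln() + (1.0 - y) * (1.0 - p).ln())
        .sum();
    -(sum / p.len() as f32)
}

fn main() {
    // One recalled card predicted at 0.9, one forgotten card predicted at 0.2.
    println!("{}", bce(&[0.9, 0.2], &[1.0, 0.0])); // ~0.164
}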
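The flipped Tensor::cat order keeps the logits consistent with the labels: for a single prediction the row reads [1 - p, p], so the class index matches the label (0 = forgotten, 1 = recalled), and indexing by label reproduces the two BCE terms above. A small sketch of that correspondence, again with illustrative names:

// Categorical cross-entropy for one row of probabilities [1 - p, p].
fn cross_entropy_row(probs: [f32; 2], label: usize) -> f32 {
    -probs[label].ln()
}

fn main() {
    let p = 0.9_f32; // predicted retention
    let row = [1.0 - p, p]; // [P(forget), P(recall)], matching the new cat order
    println!("{}", cross_entropy_row(row, 1)); // label 1 (recalled): -ln(0.9) ~ 0.105
    println!("{}", cross_entropy_row(row, 0)); // label 0 (forgotten): -ln(0.1) ~ 2.303
}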