From 4ee0bd0a26220ac8f2d32c742d2978005e4f3827 Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Tue, 22 Aug 2023 14:50:08 +0800
Subject: [PATCH] Add BCELoss

---
 src/training.rs | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/training.rs b/src/training.rs
index bea993f..36fd6dd 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -2,7 +2,6 @@ use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
 use burn::module::Module;
-use burn::nn::loss::CrossEntropyLoss;
 use burn::optim::AdamConfig;
 use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder, Recorder};
 use burn::tensor::backend::Backend;
@@ -15,6 +14,12 @@ use burn::{
 use log::info;
 
 impl<B: Backend<FloatElem = f32>> Model<B> {
+    fn bceloss(&self, retentions: Tensor<B, 1>, labels: Tensor<B, 1>) -> Tensor<B, 1> {
+        let loss: Tensor<B, 1> =
+            labels.clone() * retentions.clone().log() + (-labels + 1) * (-retentions + 1).log();
+        loss.mean().neg()
+    }
+
     pub fn forward_classification(
         &self,
         t_historys: Tensor<B, 2>,
@@ -22,19 +27,18 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
         delta_ts: Tensor<B, 1>,
         labels: Tensor<B, 1, Int>,
     ) -> ClassificationOutput<B> {
-        // dbg!(&t_historys);
-        // dbg!(&r_historys);
+        // info!("t_historys: {}", &t_historys);
+        // info!("r_historys: {}", &r_historys);
         let (stability, _difficulty) = self.forward(t_historys, r_historys);
         let retention = self.power_forgetting_curve(delta_ts.clone(), stability.clone());
-        // dbg!(&retention);
         let logits =
-            Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]);
+            Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).reshape([1, -1]);
         info!("stability: {}", &stability);
         info!("delta_ts: {}", &delta_ts);
         info!("retention: {}", &retention);
         info!("logits: {}", &logits);
         info!("labels: {}", &labels);
-        let loss = CrossEntropyLoss::new(None).forward(logits.clone(), labels.clone());
+        let loss = self.bceloss(retention.clone(), labels.clone().float());
         ClassificationOutput::new(loss, logits, labels)
     }
 }
@@ -124,7 +128,9 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
     );
 
     let mut model_trained = learner.fit(dataloader_train, dataloader_test);
+    info!("trained weights: {}", &model_trained.w.val());
     model_trained.w = Param::from(weight_clipper(model_trained.w.val()));
+    info!("clipped weights: {}", &model_trained.w.val());
 
     config
         .save(format!("{ARTIFACT_DIR}/config.json").as_str())
@@ -136,8 +142,6 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
         format!("{ARTIFACT_DIR}/model").into(),
     )
     .expect("Failed to save trained model");
-
-    info!("trained weights: {}", &model_trained.w.val());
 }
 
 #[test]
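
Note: the `bceloss` added above computes the standard binary cross-entropy,
loss = -mean(y * ln(p) + (1 - y) * ln(1 - p)), with p = predicted retention and
y = 1 for a recalled card, 0 for a forgotten one. Below is a minimal standalone
sketch of the same computation on plain f32 slices; the function and variable
names are illustrative only, not from the crate, and it assumes retentions lie
strictly inside (0, 1), which the power forgetting curve yields for positive
stability, so no clamping is done here.

    /// Binary cross-entropy over predicted retentions `p` in (0, 1) and
    /// binary labels `y` (1.0 = recalled, 0.0 = forgotten):
    ///     loss = -mean(y * ln(p) + (1 - y) * ln(1 - p))
    fn bce_loss(retentions: &[f32], labels: &[f32]) -> f32 {
        assert_eq!(retentions.len(), labels.len());
        let sum: f32 = retentions
            .iter()
            .zip(labels)
            .map(|(&p, &y)| y * p.ln() + (1.0 - y) * (1.0 - p).ln())
            .sum();
        -sum / retentions.len() as f32
    }

    fn main() {
        // A recalled card predicted at retention 0.9 contributes -ln(0.9) ~= 0.105;
        // a forgotten card at the same prediction contributes -ln(0.1) ~= 2.303.
        let loss = bce_loss(&[0.9, 0.9], &[1.0, 0.0]);
        assert!((loss - 1.204).abs() < 1e-3);
        println!("bce = {loss}");
    }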
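
On the `logits` reordering: with `CrossEntropyLoss` removed, `logits` feed only
the accuracy metric via `ClassificationOutput`. Placing `retention` in column 1
makes argmax(logits) equal 1 exactly when retention > 0.5, matching the label
encoding implied by `bceloss` (1 = recalled); the previous order would have
inverted the reported accuracy. A likely motivation for the loss swap itself is
that burn's `CrossEntropyLoss` applies log-softmax to its inputs, whereas
`retention` is already a probability. Both points are inferred from the diff
rather than stated in the commit message.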