diff --git a/src/cosine_annealing.rs b/src/cosine_annealing.rs
new file mode 100644
index 0000000..0e910c4
--- /dev/null
+++ b/src/cosine_annealing.rs
@@ -0,0 +1,79 @@
+use burn::{lr_scheduler::LRScheduler, LearningRate};
+#[derive(Clone, Debug)]
+pub struct CosineAnnealingLR {
+    t_max: f64,
+    eta_min: f64,
+    init_lr: f64,
+    step_count: f64,
+    current_lr: LearningRate,
+}
+
+impl CosineAnnealingLR {
+    pub fn init(t_max: f64, init_lr: f64) -> CosineAnnealingLR {
+        CosineAnnealingLR {
+            t_max,
+            eta_min: 0.0,
+            init_lr,
+            step_count: 0.0,
+            current_lr: init_lr,
+        }
+    }
+}
+
+impl LRScheduler for CosineAnnealingLR {
+    type Record = usize;
+
+    fn step(&mut self) -> LearningRate {
+        self.step_count += 1.0;
+        use std::f64::consts::PI;
+        // Recursive form of PyTorch's CosineAnnealingLR: the next LR is
+        // derived from the current one rather than from the closed form.
+        fn cosine_annealing_lr(
+            init_lr: f64,
+            lr: f64,
+            step_count: f64,
+            t_max: f64,
+            eta_min: f64,
+        ) -> f64 {
+            let cosine_arg = PI * step_count / t_max;
+            // At a restart boundary the general recurrence below divides by
+            // zero, so take PyTorch's special-cased step instead.
+            if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 {
+                lr + (init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0
+            } else {
+                (1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max))
+                    * (lr - eta_min)
+                    + eta_min
+            }
+        }
+        self.current_lr = cosine_annealing_lr(
+            self.init_lr,
+            self.current_lr,
+            self.step_count,
+            self.t_max,
+            self.eta_min,
+        );
+        self.current_lr
+    }
+
+    fn to_record(&self) -> Self::Record {
+        self.step_count as usize
+    }
+
+    fn load_record(mut self, record: Self::Record) -> Self {
+        self.step_count = record as f64;
+        self
+    }
+}
+
+#[test]
+fn test_lr_scheduler() {
+    let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);
+    for i in 0..400000 {
+        if i % 5000 == 0 {
+            println!("{}", lr_scheduler.current_lr);
+        }
+        lr_scheduler.step();
+    }
+    println!("{}", lr_scheduler.current_lr);
+}
diff --git a/src/dataset.rs b/src/dataset.rs
index c5d2f22..a9db688 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -139,6 +139,10 @@ impl FSRSDataset {
         Self::new()
     }
 
+    pub fn len(&self) -> usize {
+        self.dataset.len()
+    }
+
     fn new() -> Self {
         let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
         Self { dataset }
diff --git a/src/lib.rs b/src/lib.rs
index 9b4d539..f9319eb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod convertor;
+mod cosine_annealing;
 pub mod dataset;
 pub mod model;
 pub mod training;
diff --git a/src/training.rs b/src/training.rs
index 879b286..cdbaa4a 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -1,3 +1,4 @@
+use crate::cosine_annealing::CosineAnnealingLR;
 use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
@@ -75,13 +76,13 @@ pub struct TrainingConfig {
     pub optimizer: AdamConfig,
     #[config(default = 10)]
     pub num_epochs: usize,
-    #[config(default = 2)]
+    #[config(default = 1)]
     pub batch_size: usize,
     #[config(default = 4)]
     pub num_workers: usize,
     #[config(default = 42)]
     pub seed: u64,
-    #[config(default = 1.0e-4)]
+    #[config(default = 1.0e-3)]
     pub learning_rate: f64,
 }
 
@@ -113,18 +114,23 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
         .num_workers(config.num_workers)
         .build(FSRSDataset::test());
 
+    let lr_scheduler = CosineAnnealingLR::init(
+        (FSRSDataset::train().len() * config.num_epochs) as f64,
+        config.learning_rate,
+    );
+
     let learner = LearnerBuilder::new(artifact_dir)
         // .metric_train_plot(AccuracyMetric::new())
         // .metric_valid_plot(AccuracyMetric::new())
         // .metric_train_plot(LossMetric::new())
         // .metric_valid_plot(LossMetric::new())
-        .with_file_checkpointer(1, PrettyJsonFileRecorder::<FullPrecisionSettings>::new())
+        .with_file_checkpointer(10, PrettyJsonFileRecorder::<FullPrecisionSettings>::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)
         .build(
             config.model.init::<B>(),
             config.optimizer.init(),
-            config.learning_rate,
+            lr_scheduler,
         );
 
     let mut model_trained = learner.fit(dataloader_train, dataloader_test);
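
The `test_lr_scheduler` test above only prints values, so a quick way to sanity-check the recursive update is to compare it against the closed-form cosine schedule from the PyTorch docs, `eta_t = eta_min + (init_lr - eta_min) * (1 + cos(PI * t / t_max)) / 2`, which the recurrence should reproduce within the first cycle. A minimal sketch of such a check, assuming it sits next to the code in src/cosine_annealing.rs (the test name and tolerance are illustrative, not part of this diff):

```rust
// Hypothetical companion test (not in the diff): within the first cycle
// (t <= t_max) the recursive scheduler should track the closed form
//     eta_min + (init_lr - eta_min) * (1 + cos(PI * t / t_max)) / 2.
#[test]
fn cosine_annealing_matches_closed_form() {
    use std::f64::consts::PI;
    let (t_max, init_lr, eta_min) = (100000.0, 1.0e-1, 0.0);
    let mut scheduler = CosineAnnealingLR::init(t_max, init_lr);
    for t in 1..=(t_max as usize) {
        // step() advances step_count and returns the new learning rate.
        let actual = scheduler.step();
        let expected =
            eta_min + (init_lr - eta_min) * (1.0 + f64::cos(PI * t as f64 / t_max)) / 2.0;
        assert!(
            (actual - expected).abs() < 1e-9,
            "step {t}: recursive {actual} vs closed form {expected}"
        );
    }
}
```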
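
One note on the training.rs wiring: `t_max` is set to `FSRSDataset::train().len() * config.num_epochs`, which equals the total number of optimizer steps only because this diff also drops `batch_size` to 1. If the batch size is raised again, steps per epoch become roughly `len / batch_size`, the scheduler is stepped far fewer times than `t_max`, and the learning rate never completes its descent to `eta_min` within the run.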