Feat/cosine_annealing_lr
L-M-Sherlock committed Aug 24, 2023
1 parent f1106cb commit 76d4ac9
Showing 4 changed files with 90 additions and 4 deletions.
75 changes: 75 additions & 0 deletions src/cosine_annealing.rs
@@ -0,0 +1,75 @@
use burn::{lr_scheduler::LRScheduler, LearningRate};

/// Cosine annealing schedule, modeled on PyTorch's `CosineAnnealingLR`:
/// the learning rate decays from `init_lr` to `eta_min` over `t_max` steps
/// along a cosine curve, then rises back, repeating with period `2 * t_max`.
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    t_max: f64,
    eta_min: f64,
    init_lr: f64,
    step_count: f64,
    current_lr: LearningRate,
}

impl CosineAnnealingLR {
    pub fn init(t_max: f64, init_lr: f64) -> CosineAnnealingLR {
        CosineAnnealingLR {
            t_max,
            eta_min: 0.0,
            init_lr,
            step_count: 0.0,
            current_lr: init_lr,
        }
    }
}

impl LRScheduler for CosineAnnealingLR {
    type Record = usize;

    fn step(&mut self) -> LearningRate {
        self.step_count += 1.0;
        use std::f64::consts::PI;
        // Chained (recursive) form of cosine annealing, as in PyTorch's
        // `CosineAnnealingLR.get_lr`: each value is derived from the previous
        // one rather than from the closed-form expression.
        fn cosine_annealing_lr(
            init_lr: f64,
            lr: f64,
            step_count: f64,
            t_max: f64,
            eta_min: f64,
        ) -> f64 {
            let cosine_arg = PI * step_count / t_max;
            if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 {
                // Restart boundary (steps t_max + 1, 3 * t_max + 1, ...):
                // climb back up from eta_min, mirroring PyTorch's
                // `lr + (base_lr - eta_min) * (1 - cos(pi / T_max)) / 2`.
                lr + (init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0
            } else {
                (1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max))
                    * (lr - eta_min)
                    + eta_min
            }
        }
        self.current_lr = cosine_annealing_lr(
            self.init_lr,
            self.current_lr,
            self.step_count,
            self.t_max,
            self.eta_min,
        );
        self.current_lr
    }

    fn to_record(&self) -> Self::Record {
        self.step_count as usize
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        // Only the step count is recorded; `current_lr` stays at the value set
        // in `init` and is recomputed on the next call to `step`.
        self.step_count = record as f64;
        self
    }
}

#[test]
fn test_lr_scheduler() {
    // Smoke test: print the schedule every 5000 steps across two full cosine
    // periods (run with `cargo test -- --nocapture` to see the output).
    let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);
    for i in 0..400000 {
        if i % 5000 == 0 {
            println!("{}", lr_scheduler.current_lr);
        }
        lr_scheduler.step();
    }
    println!("{}", lr_scheduler.current_lr);
}
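
A quick way to sanity-check the chained recurrence above: for eta_min = 0 it telescopes to the closed form eta_t = init_lr * (1 + cos(pi * t / t_max)) / 2 documented for PyTorch's CosineAnnealingLR. The test below is an illustration added for this writeup, not part of the commit; it assumes it sits next to the code above in src/cosine_annealing.rs, so the imports at the top of the file apply.

#[test]
fn chained_recurrence_matches_closed_form() {
    use std::f64::consts::PI;
    let (t_max, init_lr) = (1000.0, 1.0e-1);
    let mut scheduler = CosineAnnealingLR::init(t_max, init_lr);
    for t in 1..=1000 {
        let lr = scheduler.step();
        // Closed form over the first period (eta_min = 0).
        let expected = init_lr * (1.0 + f64::cos(PI * t as f64 / t_max)) / 2.0;
        assert!((lr - expected).abs() < 1e-7, "step {t}: {lr} vs {expected}");
    }
}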
4 changes: 4 additions & 0 deletions src/dataset.rs
@@ -139,6 +139,10 @@ impl FSRSDataset {
         Self::new()
     }
 
+    pub fn len(&self) -> usize {
+        self.dataset.len()
+    }
+
     fn new() -> Self {
         let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
         Self { dataset }
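An aside, not part of the commit: Clippy's len_without_is_empty lint normally fires on a public len() with no matching is_empty(), so a companion like the sketch below may be worth adding next to it inside impl FSRSDataset.

    pub fn is_empty(&self) -> bool {
        self.dataset.len() == 0
    }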
1 change: 1 addition & 0 deletions src/lib.rs
@@ -1,4 +1,5 @@
 pub mod convertor;
+mod cosine_annealing;
 pub mod dataset;
 pub mod model;
 pub mod training;
14 changes: 10 additions & 4 deletions src/training.rs
@@ -1,3 +1,4 @@
+use crate::cosine_annealing::CosineAnnealingLR;
 use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
 use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
@@ -75,13 +76,13 @@ pub struct TrainingConfig {
     pub optimizer: AdamConfig,
     #[config(default = 10)]
     pub num_epochs: usize,
-    #[config(default = 2)]
+    #[config(default = 1)]
     pub batch_size: usize,
     #[config(default = 4)]
     pub num_workers: usize,
     #[config(default = 42)]
     pub seed: u64,
-    #[config(default = 1.0e-4)]
+    #[config(default = 1.0e-3)]
     pub learning_rate: f64,
 }

@@ -113,18 +114,23 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
         .num_workers(config.num_workers)
         .build(FSRSDataset::test());
 
+    let lr_scheduler = CosineAnnealingLR::init(
+        (FSRSDataset::train().len() * config.num_epochs) as f64,
+        config.learning_rate,
+    );
+
     let learner = LearnerBuilder::new(artifact_dir)
         // .metric_train_plot(AccuracyMetric::new())
         // .metric_valid_plot(AccuracyMetric::new())
         // .metric_train_plot(LossMetric::new())
         // .metric_valid_plot(LossMetric::new())
-        .with_file_checkpointer(1, PrettyJsonFileRecorder::<FullPrecisionSettings>::new())
+        .with_file_checkpointer(10, PrettyJsonFileRecorder::<FullPrecisionSettings>::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)
         .build(
             config.model.init::<B>(),
             config.optimizer.init(),
-            config.learning_rate,
+            lr_scheduler,
         );
 
     let mut model_trained = learner.fit(dataloader_train, dataloader_test);
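Note how t_max is sized: with batch_size now defaulting to 1, each training item is one optimizer step, so FSRSDataset::train().len() * config.num_epochs is the total step count and the cosine decay spans exactly one full training run, reaching eta_min at the end without re-entering the rising half of the period. With a larger batch size the step count would shrink accordingly; here is a hedged sketch of that sizing (the names and numbers are hypothetical, not from this commit).

fn main() {
    let items = 24_000_usize; // stand-in for FSRSDataset::train().len()
    let batch_size = 512;     // stand-in for config.batch_size
    let num_epochs = 10;      // config.num_epochs default in this commit
    // One optimizer step per batch; round the per-epoch count up.
    let steps_per_epoch = (items + batch_size - 1) / batch_size;
    let t_max = (steps_per_epoch * num_epochs) as f64;
    // CosineAnnealingLR::init(t_max, learning_rate) would then decay over
    // exactly one training run.
    println!("t_max = {t_max}");
}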
