diff --git a/src/training.rs b/src/training.rs index 8fcfbb8..41eee0a 100644 --- a/src/training.rs +++ b/src/training.rs @@ -35,8 +35,7 @@ impl> Model { delta_ts.clone().unsqueeze::<2>().transpose(), stability.clone(), ); - let logits = - Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 1); + let logits = Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 1); info!("stability: {}", &stability); info!( "delta_ts: {}",