Skip to content

Commit

Permalink
correct shape of logits
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Aug 25, 2023
1 parent e25a7a8 commit 06a9863
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
stability.clone(),
);
let logits =
Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).reshape([2, -1]);
Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 1);
info!("stability: {}", &stability);
info!(
"delta_ts: {}",
Expand Down

0 comments on commit 06a9863

Please sign in to comment.