diff --git a/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py b/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py index ae02b92fd4..078b0aa035 100644 --- a/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py +++ b/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py @@ -219,7 +219,7 @@ def _step(self) -> Tuple[List, float]: self.X ) self._update_sigma(self.y - self._predict_step()) - return self.all_trees, self.sigma + return self.all_trees, self.sigma.val def _update_leaf_mean(self, tree: Tree, partial_residual: torch.Tensor): """