From 26529f821fed83149db6ac849631d600e43bbadd Mon Sep 17 00:00:00 2001 From: Himaghna Bhattacharjee Date: Wed, 6 Jul 2022 16:13:34 -0700 Subject: [PATCH] Only store Sigma values in BART samples instead of object (#1527) Summary: Pull Request resolved: https://github.com/facebookresearch/beanmachine/pull/1527 Background: We are building Bayesian Additive Regression Trees (BART) as an experimental causal inference model in beanmachine. Details of the project can be found in https://docs.google.com/document/d/11nkB6UTGpvQBEC2yBjfgwAr8VabTlD7R9XufGQG0EvI/edit?usp=sharing and the proposed design can be found in the draft design document: https://docs.google.com/document/d/1o3J7yobDF0M9E27Y0tP2889fycmemXUZbHE5cebRqzs/edit?usp=sharing. In this diff: The noise standard deviation (sigma) parameter is never really used in the prediction tasks. While we would like to retain them for diagnostic purposes, there is no reason to store the NoiseStandardDeviation object in the sample trace. In this diff, we are modifying the BART class to only store float samples of the noise standard deviation. Reviewed By: feynmanliang Differential Revision: D37635208 fbshipit-source-id: be2f53b61b666fe9d50d2504a57351c91bd24915 --- .../ppl/experimental/causal_inference/models/bart/bart_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py b/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py index ae02b92fd4..078b0aa035 100644 --- a/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py +++ b/src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py @@ -219,7 +219,7 @@ def _step(self) -> Tuple[List, float]: self.X ) self._update_sigma(self.y - self._predict_step()) - return self.all_trees, self.sigma + return self.all_trees, self.sigma.val def _update_leaf_mean(self, tree: Tree, partial_residual: torch.Tensor): """