Skip to content

Commit

Permalink
Change metrics to global loss (default) and also allow summary-net or…
Browse files Browse the repository at this point in the history
… inference-net specific metrics for future
  • Loading branch information
marvinschmitt committed May 28, 2024
1 parent d1b9845 commit ad6c7be
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions bayesflow/experimental/amortizers/amortizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,36 @@ def reset_metrics(self):
def metrics(self):
return [self.loss_tracker]

def compute_metrics(self, x: dict, y: dict, y_pred: dict, **kwargs):
inferred_variables = self.configure_inferred_variables(x)
observed_variables = self.configure_observed_variables(x)

if self.summary_network:
summary_conditions = self.configure_summary_conditions(x)
summary_metrics = self.summary_network.compute_metrics(
x=(observed_variables, summary_conditions),
y=y.get("summary_targets"),
y_pred=y_pred.get("summary_outputs")
)
else:
summary_metrics = {}

inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs"))
inference_metrics = self.inference_network.compute_metrics(
x=(inferred_variables, inference_conditions),
y=y.get("inference_targets"),
y_pred=y_pred.get("inference_outputs")
)

summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()}
inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()}

out_metrics = {'loss': self.loss_tracker.total.value}
out_metrics = out_metrics | summary_metrics
out_metrics = out_metrics | inference_metrics

return out_metrics

def sample(self, data: dict, num_samples: int, sample_summaries=False, **kwargs):

# Configure everything -> inference conditions / summary conditions
Expand Down

0 comments on commit ad6c7be

Please sign in to comment.