Skip to content

Commit

Permalink
Merge pull request #92 from elseml/Development
Browse files Browse the repository at this point in the history
Fix train_from_presimulation for model comparison
  • Loading branch information
elseml authored Jul 28, 2023
2 parents 1d7cb0f + d32ef29 commit 26de450
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions bayesflow/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import logging
import os
from pickle import load as pickle_load
import tensorflow as tf

import numpy as np
import tensorflow as tf
from tqdm.autonotebook import tqdm

from bayesflow.amortizers import (
Expand Down Expand Up @@ -737,7 +737,10 @@ def train_from_presimulation(
input_dict = self.configurator(epoch_data[index])

# Like the number of iterations, the batch size is inferred from presimulated dictionary or list
batch_size = epoch_data[index][DEFAULT_KEYS["sim_data"]].shape[0]
if isinstance(self.amortizer, AmortizedModelComparison):
batch_size = input_dict[DEFAULT_KEYS["summary_conditions"]].shape[0]
else:
batch_size = epoch_data[index][DEFAULT_KEYS["sim_data"]].shape[0]
loss = self._train_step(batch_size, _backprop_step, input_dict, **kwargs)

# Store returned loss
Expand Down

0 comments on commit 26de450

Please sign in to comment.