Skip to content

Commit

Permalink
Cover cases with no conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Aug 26, 2024
1 parent 38baad2 commit 31d1c03
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ def get_config(self):

return base_config | config

def sample(self, data: Mapping[str, Tensor], num_samples: int = 1, numpy: bool = True) -> dict[str, Tensor]:
def sample(
self, data: Mapping[str, Tensor], num_samples: int = 1, numpy: bool = True, batch_shape: Shape = None
) -> dict[str, Tensor]:
data = self.data_adapter.configure(data)
data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
data = {"inference_variables": self._sample(num_samples, **data)}
data = {"inference_variables": self._sample(num_samples, batch_shape=batch_shape, **data)}
data = self.data_adapter.deconfigure(data)

if numpy:
Expand All @@ -133,7 +135,11 @@ def sample(self, data: Mapping[str, Tensor], num_samples: int = 1, numpy: bool =
return data

def _sample(
self, num_samples: Shape, inference_conditions: Tensor = None, summary_variables: Tensor = None
self,
num_samples: int,
batch_shape: Shape = None,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
) -> Tensor:
if self.summary_network is not None:
summary_outputs = self.summary_network(summary_variables)
Expand All @@ -143,8 +149,12 @@ def _sample(
else:
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

inference_conditions = expand_tile(inference_conditions, axis=1, n=num_samples)
batch_shape = (keras.ops.shape(inference_conditions)[0], num_samples)
if batch_shape is None:
if inference_conditions is not None:
inference_conditions = expand_tile(inference_conditions, axis=1, n=num_samples)
batch_shape = (keras.ops.shape(inference_conditions)[0], num_samples)
else:
batch_shape = (num_samples,)

return self.inference_network.sample(batch_shape, conditions=inference_conditions)

Expand Down

0 comments on commit 31d1c03

Please sign in to comment.