Skip to content

Commit

Permalink
allow for scalar simulator return values even in batched contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Aug 28, 2024
1 parent 3ca8905 commit 137b78d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
13 changes: 8 additions & 5 deletions bayesflow/simulators/lambda_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def __init__(self, sample_fn: callable, *, is_batched: bool = False, cast_dtypes
:param sample_fn: The sampling function.
If in batched format, must accept a batch_shape argument as the first positional argument.
If in unbatched format (the default), may accept any keyword arguments.
Must return a dictionary of string keys and numpy array values.
Must return a dictionary of string keys and numpy array (or scalar) values.
:param is_batched: Whether the sampling function is in batched format.
:param cast_dtypes: Output data types to cast to.
:param cast_dtypes: Output data types to cast arrays to.
By default, we convert float64 (the default for numpy on x64 systems)
to float32 (the default for deep learning on any system).
"""
Expand All @@ -29,7 +29,7 @@ def __init__(self, sample_fn: callable, *, is_batched: bool = False, cast_dtypes

self.cast_dtypes = cast_dtypes

def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, any]:
# try to use only valid keyword arguments
kwargs = filter_kwargs(kwargs, self.sample_fn)

Expand All @@ -42,18 +42,21 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:

return data

def _sample_batch(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
def _sample_batch(self, batch_shape: Shape, **kwargs) -> dict[str, any]:
"""Samples a batch of data from an otherwise unbatched sampling function."""
data = batched_call(self.sample_fn, batch_shape, kwargs=kwargs, flatten=True)

data = tree_stack(data, axis=0, numpy=True)

return data

def _cast_dtypes(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
def _cast_dtypes(self, data: dict[str, any]) -> dict[str, any]:
data = data.copy()

for key, value in data.items():
if not isinstance(value, np.ndarray):
continue

dtype = str(value.dtype)
if dtype in self.cast_dtypes:
data[key] = value.astype(self.cast_dtypes[dtype])
Expand Down
7 changes: 6 additions & 1 deletion bayesflow/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def batched_call(
if map_predicate is None:

def map_predicate(arg):
return isinstance(arg, np.ndarray) or keras.ops.is_tensor(arg)
if isinstance(arg, np.ndarray):
return arg.ndim > len(batch_shape)
if keras.ops.is_tensor(arg):
return keras.ops.ndim(arg) > len(batch_shape)

return False

outputs = np.empty(batch_shape, dtype="object")

Expand Down

0 comments on commit 137b78d

Please sign in to comment.