diff --git a/bayesflow/simulators/composite_simulator.py b/bayesflow/simulators/composite_simulator.py index f8c0eb4b..c6b7a5de 100644 --- a/bayesflow/simulators/composite_simulator.py +++ b/bayesflow/simulators/composite_simulator.py @@ -1,8 +1,8 @@ from collections.abc import Sequence import numpy as np -from bayesflow.types import Shape -from bayesflow.utils import validate_batch_shape +from bayesflow.types import ShapeLike +from bayesflow.utils import validate_shape from .simulator import Simulator @@ -14,8 +14,8 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = False self.simulators = simulators self.expand_outputs = expand_outputs - def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: - batch_shape = validate_batch_shape(batch_shape) + def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]: + batch_shape = validate_shape(batch_shape) data = {} for simulator in self.simulators: diff --git a/bayesflow/simulators/hierarchical_simulator.py b/bayesflow/simulators/hierarchical_simulator.py index 6657034c..fc53847e 100644 --- a/bayesflow/simulators/hierarchical_simulator.py +++ b/bayesflow/simulators/hierarchical_simulator.py @@ -2,18 +2,18 @@ import keras import numpy as np -from bayesflow.types import Shape +from bayesflow.types import ShapeLike from .simulator import Simulator -from bayesflow.utils import validate_batch_shape +from bayesflow.utils import validate_shape class HierarchicalSimulator(Simulator): def __init__(self, hierarchy: Sequence[Simulator]): self.hierarchy = hierarchy - def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: - batch_shape = validate_batch_shape(batch_shape) + def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]: + batch_shape = validate_shape(batch_shape) input_data = {} output_data = {} diff --git a/bayesflow/simulators/lambda_simulator.py b/bayesflow/simulators/lambda_simulator.py index 3f919bbb..d762d635 100644 --- a/bayesflow/simulators/lambda_simulator.py +++ b/bayesflow/simulators/lambda_simulator.py @@ -4,8 +4,8 @@ from bayesflow.utils import batched_call, filter_kwargs, tree_stack from .simulator import Simulator -from bayesflow.utils import validate_batch_shape -from ..types import Shape +from bayesflow.utils import validate_shape +from ..types import Shape, ShapeLike class LambdaSimulator(Simulator): @@ -43,8 +43,9 @@ def __init__( self.reserved_arguments = reserved_arguments - def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: - batch_shape = validate_batch_shape(batch_shape) + def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]: + batch_shape = validate_shape(batch_shape) + # add reserved arguments kwargs = self.reserved_arguments | kwargs diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index 8bc76998..9f5293de 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -1,13 +1,13 @@ from collections.abc import Sequence import numpy as np -from bayesflow.types import Shape +from bayesflow.types import ShapeLike from bayesflow.utils import tree_stack from bayesflow.utils import numpy_utils as npu from .simulator import Simulator -from bayesflow.utils import validate_batch_shape +from bayesflow.utils import validate_shape class ModelComparisonSimulator(Simulator): @@ -41,8 +41,8 @@ def __init__( self.logits = logits self.use_mixed_batches = use_mixed_batches - def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: - batch_shape = validate_batch_shape(batch_shape) + def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]: + batch_shape = validate_shape(batch_shape) if not self.use_mixed_batches: # draw one model index for the whole batch (faster) diff --git a/bayesflow/simulators/simulator.py b/bayesflow/simulators/simulator.py index f0fdddba..97d31d20 100644 --- a/bayesflow/simulators/simulator.py +++ b/bayesflow/simulators/simulator.py @@ -1,23 +1,25 @@ from collections.abc import Callable import numpy as np -from bayesflow.types import Shape -from bayesflow.utils import tree_concatenate +from bayesflow.types import ShapeLike +from bayesflow.utils import tree_concatenate, validate_shape class Simulator: - def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: + def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]: raise NotImplementedError def rejection_sample( self, - batch_shape: Shape, + batch_shape: ShapeLike, predicate: Callable[[dict[str, np.ndarray]], np.ndarray], *, axis: int = 0, sample_size: int = None, **kwargs, ) -> dict[str, np.ndarray]: + batch_shape = validate_shape(batch_shape) + if sample_size is None: sample_shape = batch_shape else: diff --git a/bayesflow/types/__init__.py b/bayesflow/types/__init__.py index 16253178..d954bae6 100644 --- a/bayesflow/types/__init__.py +++ b/bayesflow/types/__init__.py @@ -1,2 +1,2 @@ -from .shape import Shape +from .shape import Shape, ShapeLike from .tensor import Tensor diff --git a/bayesflow/types/shape.py b/bayesflow/types/shape.py index 32b630b4..850cd512 100644 --- a/bayesflow/types/shape.py +++ b/bayesflow/types/shape.py @@ -1 +1,2 @@ -Shape = int | tuple[int, ...] +Shape = tuple[int, ...] +ShapeLike = int | Shape diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index fc1ec300..6cc822ea 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -36,7 +36,7 @@ from .optimal_transport import optimal_transport -from .validators import validate_batch_shape +from .validators import validate_shape from .tensor_utils import ( expand_left, diff --git a/bayesflow/utils/validators.py b/bayesflow/utils/validators.py index 54aab3d1..2a4b275d 100644 --- a/bayesflow/utils/validators.py +++ b/bayesflow/utils/validators.py @@ -1,8 +1,11 @@ -from bayesflow.types import Shape +from bayesflow.types import Shape, ShapeLike -def validate_batch_shape(batch_shape: Shape) -> tuple: - if isinstance(batch_shape, int): - batch_shape = (batch_shape,) +def validate_shape(shape: ShapeLike) -> Shape: + if isinstance(shape, int): + return (shape,) - return batch_shape + if not isinstance(shape, tuple) or not all(isinstance(dim, int) for dim in shape): + raise ValueError(f"Invalid shape: {shape}") + + return shape