Skip to content

Commit

Permalink
rename validate_batch_shape -> validate_shape
Browse files Browse the repository at this point in the history
introduce ShapeLike, restoring old value of Shape
add actual validation to validate_shape
  • Loading branch information
LarsKue committed Oct 22, 2024
1 parent ff2b36c commit d0f23fd
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 28 deletions.
8 changes: 4 additions & 4 deletions bayesflow/simulators/composite_simulator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections.abc import Sequence
import numpy as np

from bayesflow.types import Shape
from bayesflow.utils import validate_batch_shape
from bayesflow.types import ShapeLike
from bayesflow.utils import validate_shape

from .simulator import Simulator

Expand All @@ -14,8 +14,8 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = False
self.simulators = simulators
self.expand_outputs = expand_outputs

def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_batch_shape(batch_shape)
def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_shape(batch_shape)

data = {}
for simulator in self.simulators:
Expand Down
8 changes: 4 additions & 4 deletions bayesflow/simulators/hierarchical_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import keras
import numpy as np

from bayesflow.types import Shape
from bayesflow.types import ShapeLike

from .simulator import Simulator
from bayesflow.utils import validate_batch_shape
from bayesflow.utils import validate_shape


class HierarchicalSimulator(Simulator):
def __init__(self, hierarchy: Sequence[Simulator]):
self.hierarchy = hierarchy

def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_batch_shape(batch_shape)
def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_shape(batch_shape)

input_data = {}
output_data = {}
Expand Down
9 changes: 5 additions & 4 deletions bayesflow/simulators/lambda_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from bayesflow.utils import batched_call, filter_kwargs, tree_stack

from .simulator import Simulator
from bayesflow.utils import validate_batch_shape
from ..types import Shape
from bayesflow.utils import validate_shape
from ..types import Shape, ShapeLike


class LambdaSimulator(Simulator):
Expand Down Expand Up @@ -43,8 +43,9 @@ def __init__(

self.reserved_arguments = reserved_arguments

def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_batch_shape(batch_shape)
def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_shape(batch_shape)

# add reserved arguments
kwargs = self.reserved_arguments | kwargs

Expand Down
8 changes: 4 additions & 4 deletions bayesflow/simulators/model_comparison_simulator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from collections.abc import Sequence
import numpy as np

from bayesflow.types import Shape
from bayesflow.types import ShapeLike
from bayesflow.utils import tree_stack

from bayesflow.utils import numpy_utils as npu

from .simulator import Simulator
from bayesflow.utils import validate_batch_shape
from bayesflow.utils import validate_shape


class ModelComparisonSimulator(Simulator):
Expand Down Expand Up @@ -41,8 +41,8 @@ def __init__(
self.logits = logits
self.use_mixed_batches = use_mixed_batches

def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_batch_shape(batch_shape)
def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]:
batch_shape = validate_shape(batch_shape)

if not self.use_mixed_batches:
# draw one model index for the whole batch (faster)
Expand Down
10 changes: 6 additions & 4 deletions bayesflow/simulators/simulator.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from collections.abc import Callable
import numpy as np

from bayesflow.types import Shape
from bayesflow.utils import tree_concatenate
from bayesflow.types import ShapeLike
from bayesflow.utils import tree_concatenate, validate_shape


class Simulator:
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
def sample(self, batch_shape: ShapeLike, **kwargs) -> dict[str, np.ndarray]:
raise NotImplementedError

def rejection_sample(
self,
batch_shape: Shape,
batch_shape: ShapeLike,
predicate: Callable[[dict[str, np.ndarray]], np.ndarray],
*,
axis: int = 0,
sample_size: int = None,
**kwargs,
) -> dict[str, np.ndarray]:
batch_shape = validate_shape(batch_shape)

if sample_size is None:
sample_shape = batch_shape
else:
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .shape import Shape
from .shape import Shape, ShapeLike
from .tensor import Tensor
3 changes: 2 additions & 1 deletion bayesflow/types/shape.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
Shape = int | tuple[int, ...]
Shape = tuple[int, ...]
ShapeLike = int | Shape
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from .optimal_transport import optimal_transport

from .validators import validate_batch_shape
from .validators import validate_shape

from .tensor_utils import (
expand_left,
Expand Down
13 changes: 8 additions & 5 deletions bayesflow/utils/validators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from bayesflow.types import Shape
from bayesflow.types import Shape, ShapeLike


def validate_batch_shape(batch_shape: Shape) -> tuple:
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
def validate_shape(shape: ShapeLike) -> Shape:
if isinstance(shape, int):
return (shape,)

return batch_shape
if not isinstance(shape, tuple) or not all(isinstance(dim, int) for dim in shape):
raise ValueError(f"Invalid shape: {shape}")

return shape

0 comments on commit d0f23fd

Please sign in to comment.