diff --git a/bayesflow/experimental/utils/dictutils.py b/bayesflow/experimental/utils/dictutils.py index adee00fb..c8d54711 100644 --- a/bayesflow/experimental/utils/dictutils.py +++ b/bayesflow/experimental/utils/dictutils.py @@ -19,7 +19,7 @@ def concatenate_tensors(tensor_dict: dict[str, Tensor], filter_list: list, axis: An optional axis can be specified (default: last axis). """ - return ops.concatenate([v for k, v in tensor_dict.items() if k in filter_list]) + return ops.concatenate([v for k, v in tensor_dict.items() if k in filter_list], axis=axis) def keras_kwargs(kwargs: dict):