From 340f2a61092710880bec08369fc276018cfcefc5 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Sun, 9 Jun 2024 05:16:19 -0400 Subject: [PATCH] Reflect refactorings --- bayesflow/experimental/utils/dispatch/find_pooling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/experimental/utils/dispatch/find_pooling.py b/bayesflow/experimental/utils/dispatch/find_pooling.py index 398dbbb4..408cc283 100644 --- a/bayesflow/experimental/utils/dispatch/find_pooling.py +++ b/bayesflow/experimental/utils/dispatch/find_pooling.py @@ -19,8 +19,8 @@ def _(name: str, **kwargs): case "min": pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2)) case "learnable" | "pma" | "attention": - from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention - pooling = PoolingByMultiheadAttention(**kwargs) + from bayesflow.experimental.networks.transformers.pma import PoolingByMultiHeadAttention + pooling = PoolingByMultiHeadAttention(**kwargs) case other: raise ValueError(f"Unsupported pooling name: '{other}'.")