Skip to content

Commit

Permalink
Reflect refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jun 9, 2024
1 parent 9846263 commit 340f2a6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bayesflow/experimental/utils/dispatch/find_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def _(name: str, **kwargs):
case "min":
pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))
case "learnable" | "pma" | "attention":
from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention
pooling = PoolingByMultiheadAttention(**kwargs)
from bayesflow.experimental.networks.transformers.pma import PoolingByMultiHeadAttention
pooling = PoolingByMultiHeadAttention(**kwargs)
case other:
raise ValueError(f"Unsupported pooling name: '{other}'.")

Expand Down

0 comments on commit 340f2a6

Please sign in to comment.