-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
…3/BayesFlow into streamlined-backend
- Loading branch information
Showing
18 changed files
with
292 additions
and
702 deletions.
There are no files selected for viewing
59 changes: 0 additions & 59 deletions
59
bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
This file was deleted.
Oops, something went wrong.
60 changes: 60 additions & 0 deletions
60
bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
|
||
import keras | ||
from keras.saving import ( | ||
register_keras_serializable | ||
) | ||
|
||
from bayesflow.experimental.types import Tensor | ||
from bayesflow.experimental.utils import find_network | ||
from ..invertible_layer import InvertibleLayer | ||
from ..transforms import find_transform | ||
|
||
|
||
@register_keras_serializable(package="bayesflow.networks.coupling_flow") | ||
class SingleCoupling(InvertibleLayer): | ||
""" | ||
Implements a single coupling layer as a composition of a subnet and a transform. | ||
Subnet output tensors are linearly mapped to the correct dimension. | ||
""" | ||
def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs): | ||
super().__init__(**kwargs) | ||
self.dense = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros") | ||
self.network = find_network(network) | ||
self.transform = find_transform(transform) | ||
|
||
# noinspection PyMethodOverriding | ||
def build(self, x1_shape, x2_shape): | ||
self.dense.units = self.transform.params_per_dim * x2_shape[-1] | ||
|
||
def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor): | ||
if inverse: | ||
return self._inverse(x1, x2, conditions=conditions) | ||
return self._forward(x1, x2, conditions=conditions) | ||
|
||
def _forward(self, x1: Tensor, x2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor): | ||
""" Transform (x1, x2) -> (x1, f(x2; x1)) """ | ||
z1 = x1 | ||
parameters = self.get_parameters(x1, conditions) | ||
z2, log_det = self.transform(x2, parameters=parameters) | ||
|
||
return (z1, z2), log_det | ||
|
||
def _inverse(self, z1: Tensor, z2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor): | ||
""" Transform (x1, f(x2; x1)) -> (x1, x2) """ | ||
x1 = z1 | ||
parameters = self.get_parameters(x1, conditions) | ||
x2, log_det = self.transform(z2, parameters=parameters, inverse=True) | ||
|
||
return (x1, x2), log_det | ||
|
||
def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]: | ||
# TODO: pass conditions to subnet via kwarg if possible | ||
if keras.ops.is_tensor(conditions): | ||
x = keras.ops.concatenate([x, conditions], axis=-1) | ||
|
||
parameters = self.dense(self.network(x)) | ||
parameters = self.transform.split_parameters(parameters) | ||
parameters = self.transform.constrain_parameters(parameters) | ||
|
||
return parameters |
1 change: 1 addition & 0 deletions
1
bayesflow/experimental/networks/coupling_flow/permutations/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
|
||
from .orthogonal import OrthogonalPermutation | ||
from .fixed_permutation import FixedPermutation | ||
from .swap import Swap |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
bayesflow/experimental/networks/coupling_flow/permutations/swap.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 0 additions & 23 deletions
23
bayesflow/experimental/networks/coupling_flow/subnets/__init__.py
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.