Skip to content

Commit

Permalink
Merge pull request #173 from Chase-Grajeda/summary-lstnet
Browse files Browse the repository at this point in the history
LSTNet Implementation
  • Loading branch information
stefanradev93 authored Jun 13, 2024
2 parents 73602de + a604d99 commit 2f87322
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 2 deletions.
5 changes: 3 additions & 2 deletions bayesflow/experimental/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
from .inference_network import InferenceNetwork
from .mlp import MLP
from .resnet import ResNet
from .transformers import SetTransformer
from .summary_network import SummaryNetwork

from .inference_network import InferenceNetwork
from .summary_network import SummaryNetwork
1 change: 1 addition & 0 deletions bayesflow/experimental/networks/lstnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .lstnet import LSTNet
64 changes: 64 additions & 0 deletions bayesflow/experimental/networks/lstnet/lstnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import keras
from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import keras_kwargs
from keras import layers, Sequential, regularizers
from keras.saving import (register_keras_serializable)
from .skip_gru import SkipGRU
from ...networks.resnet import ResNet

@register_keras_serializable(package="bayesflow.networks.lstnet")
class LSTNet(keras.Model):
    """
    Implements a LSTNet Architecture as described in [1]

    The network is a ``keras.Sequential`` stack of four stages applied in
    order: a ``Conv1D`` layer, a ``BatchNormalization`` layer, a ``SkipGRU``
    block and a ``ResNet``.

    [1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
    2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
    Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
    """

    def __init__(
        self,
        cnn_out: int = 128,
        kernel_size: int = 4,
        kernel_initializer: str = "glorot_uniform",
        kernel_regularizer: regularizers.Regularizer | None = None,
        activation: str = "relu",
        gru_out: int = 64,
        skip_outs: list[int] | None = None,
        skip_steps: list[int] | None = None,
        resnet_out: int = 32,
        **kwargs
    ):
        """
        Parameters
        ----------
        cnn_out            : number of filters of the ``Conv1D`` layer.
        kernel_size        : kernel size of the ``Conv1D`` layer.
        kernel_initializer : kernel initializer of the ``Conv1D`` layer.
        kernel_regularizer : optional kernel regularizer of the ``Conv1D`` layer.
        activation         : activation of the ``Conv1D`` layer.
        gru_out            : first argument forwarded to ``SkipGRU``.
        skip_outs          : per-skip output sizes forwarded to ``SkipGRU``;
                             ``None`` (default) means ``[32]``.
        skip_steps         : per-skip step sizes forwarded to ``SkipGRU``;
                             ``None`` (default) means ``[2]``.
        resnet_out         : ``width`` argument of the final ``ResNet``.
        **kwargs           : filtered through ``keras_kwargs`` and passed to
                             ``keras.Model``.

        Raises
        ------
        ValueError
            If ``skip_outs`` and ``skip_steps`` differ in length.
        """
        # None sentinels avoid sharing one mutable default list between
        # instances; copies keep the caller's lists decoupled from the model.
        skip_outs = [32] if skip_outs is None else list(skip_outs)
        skip_steps = [2] if skip_steps is None else list(skip_steps)

        if len(skip_outs) != len(skip_steps):
            raise ValueError("skip_outs must have same length as skip_steps")

        super().__init__(**keras_kwargs(kwargs))

        # Define model
        self.model = Sequential()
        self.conv1 = layers.Conv1D(
            filters=cnn_out,
            kernel_size=kernel_size,
            activation=activation,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer
        )
        self.bnorm = layers.BatchNormalization()
        self.skip_gru = SkipGRU(gru_out, skip_outs, skip_steps)
        self.resnet = ResNet(width=resnet_out)

        # Aggregate layers In: (batch, time steps, num series)
        self.model.add(self.conv1)     # -> (batch, reduced time steps, cnn_out)
        self.model.add(self.bnorm)     # -> (batch, reduced time steps, cnn_out)
        self.model.add(self.skip_gru)  # -> (batch, _)
        self.model.add(self.resnet)    # -> (batch, resnet_out)

    def call(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the sequential stack and return the result."""
        x = self.model(x)
        return x

    def build(self, input_shape):
        """Build the layers by running ``call`` once on zeros of ``input_shape``."""
        self.call(keras.ops.zeros(input_shape))
44 changes: 44 additions & 0 deletions bayesflow/experimental/networks/lstnet/skip_gru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import keras
from keras.saving import register_keras_serializable
from keras import layers
from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import keras_kwargs

@register_keras_serializable(package="bayesflow.networks.skip_gru")
class SkipGRU(keras.Model):
    """
    Implements a Skip GRU layer as described in [1]

    A plain GRU is run over the full input sequence. In addition, for every
    configured skip step ``p`` the trailing ``L * p`` time points (with
    ``L = time_steps // p``) are split into ``p`` interleaved sub-sequences of
    length ``L``; each sub-sequence is fed through a dedicated GRU and the
    resulting states are concatenated to the plain GRU output along axis 1.

    [1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
    2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
    Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
    """

    def __init__(self, gru_out: int, skip_outs: list[int], skip_steps: list[int], **kwargs):
        """
        Parameters
        ----------
        gru_out    : number of units of the GRU applied to the full sequence.
        skip_outs  : number of units of each skip GRU, one entry per skip step.
        skip_steps : step sizes, one entry per skip GRU.
        **kwargs   : filtered through ``keras_kwargs`` and passed to
                     ``keras.Model``.

        Raises
        ------
        ValueError
            If ``skip_outs`` and ``skip_steps`` differ in length.
        """
        if len(skip_outs) != len(skip_steps):
            raise ValueError("skip_outs must have same length as skip_steps")

        super().__init__(**keras_kwargs(kwargs))
        self.gru_out = gru_out
        # Copy so later mutation of the caller's list cannot change this layer.
        self.skip_steps = list(skip_steps)
        self.gru = layers.GRU(gru_out)
        self.skip_grus = [layers.GRU(units) for units in skip_outs]

    def call(self, x: Tensor) -> Tensor:
        """
        Apply the plain GRU and every skip GRU to ``x``.

        ``x`` is indexed as ``(batch, time steps, channels)``. The result has
        shape ``(batch, gru_out + sum(skip_step * skip_out))``.

        Raises
        ------
        ValueError
            If a skip step is larger than the number of time steps in ``x``.
        """
        sgru = self.gru(x)
        time_steps, channels = x.shape[1], x.shape[2]

        for skip_gru, skip_step in zip(self.skip_grus, self.skip_steps):
            skip_length = time_steps // skip_step
            if skip_length == 0:
                # Without this guard `-0` would slice the whole sequence and
                # the reshape below would fail with an obscure message.
                raise ValueError(
                    f"skip step {skip_step} exceeds the number of time steps {time_steps}"
                )

            # Keep only the trailing time points that fill whole skip periods
            s = x[:, -skip_length * skip_step:, :]

            # Time is axis 1 (channels last), so the time axis is the one to
            # split into (period index, offset within period). Moving the
            # offset next to the batch axis yields one sub-sequence per offset.
            s = keras.ops.reshape(s, (-1, skip_length, skip_step, channels))
            s = keras.ops.transpose(s, [0, 2, 1, 3])
            s = keras.ops.reshape(s, (-1, skip_length, channels))

            # Reapply GRU, add to working tensor
            s = skip_gru(s)
            s = keras.ops.reshape(s, (-1, skip_step * s.shape[1]))
            sgru = keras.ops.concatenate([sgru, s], axis=1)

        return sgru

    def build(self, input_shape):
        """Build the layers by running ``call`` once on zeros of ``input_shape``."""
        self.call(keras.ops.zeros(input_shape))

0 comments on commit 2f87322

Please sign in to comment.