Remove unnecessary Sequential in ConfigurableHiddenBlock
stefanradev93 committed Jun 3, 2024
1 parent 9a0fc43 commit de0b91b
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions bayesflow/experimental/networks/resnet/hidden_block.py
@@ -5,6 +5,8 @@
     register_keras_serializable,
 )
 
+from bayesflow.experimental.types import Tensor
+
 @register_keras_serializable(package="bayesflow.networks.resnet")
 class ConfigurableHiddenBlock(keras.layers.Layer):
     def __init__(
@@ -24,21 +26,19 @@ def __init__(
         self.activation_fn = keras.activations.get(activation)
         self.residual = residual
         self.spectral_norm = spectral_norm
-        self.dense_with_dropout = keras.Sequential()
-        dense = layers.Dense(
+        self.dense = layers.Dense(
             units=units,
             kernel_regularizer=kernel_regularizer,
             kernel_initializer=kernel_initializer,
             bias_regularizer=bias_regularizer
         )
         if spectral_norm:
-            self.dense_with_dropout.add(layers.SpectralNormalization(dense))
-        else:
-            self.dense_with_dropout.add(dense)
-        self.dense_with_dropout.add(keras.layers.Dropout(dropout_rate))
+            self.dense = layers.SpectralNormalization(self.dense)
+        self.dropout = keras.layers.Dropout(dropout_rate)
 
-    def call(self, inputs, training=False):
-        x = self.dense_with_dropout(inputs, training=training)
+    def call(self, inputs: Tensor, **kwargs):
+        x = self.dense(inputs, **kwargs)
+        x = self.dropout(x, **kwargs)
         if self.residual:
             x = x + inputs
         return self.activation_fn(x)
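For orientation, here is a minimal, self-contained sketch of the pattern this commit lands on: the Dense layer, its optional SpectralNormalization wrapper, and the Dropout layer become separate attributes, and call forwards **kwargs so the training flag reaches dropout. This is an illustrative sketch, not the repository's code: the class name HiddenBlockSketch, the reduced constructor arguments, and the usage lines are hypothetical, and it assumes Keras 3 is installed.

import keras
from keras import layers


class HiddenBlockSketch(keras.layers.Layer):
    """Dense -> optional spectral norm -> dropout -> optional residual -> activation."""

    def __init__(self, units=256, activation="relu", residual=True,
                 dropout_rate=0.05, spectral_norm=False, **kwargs):
        super().__init__(**kwargs)
        self.activation_fn = keras.activations.get(activation)
        self.residual = residual
        self.dense = layers.Dense(units=units)
        if spectral_norm:
            # The wrapper replaces the Dense attribute directly instead of
            # being added to a Sequential, mirroring the refactor above.
            self.dense = layers.SpectralNormalization(self.dense)
        self.dropout = keras.layers.Dropout(dropout_rate)

    def call(self, inputs, **kwargs):
        # **kwargs forwards the training flag, so dropout is active only
        # during training, matching the new call signature in the diff.
        x = self.dense(inputs, **kwargs)
        x = self.dropout(x, **kwargs)
        if self.residual:
            x = x + inputs  # requires units == input feature dimension
        return self.activation_fn(x)


# Hypothetical usage: residual blocks need matching input/output widths.
block = HiddenBlockSketch(units=8, spectral_norm=True)
out = block(keras.ops.ones((4, 8)), training=True)
print(out.shape)  # (4, 8)

Dropping the Sequential also removes the need to thread training explicitly: Keras routes the flag through each layer's __call__, and only the layers that use it (here, Dropout and SpectralNormalization) react to it.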
