New Issue Checklist
Issue Description
According to the batch normalization implementation in TL, which can be found at https://github.com/tensorlayer/tensorlayer/blob/v2.2.0/tensorlayer/layers/normalization.py, the mean and var are computed over the whole input rather than channel-wise when the layer is initialized with data_format='channels_last'. You can refer to line 220, self.channel_axis = -1 if data_format == 'channels_last' else 1, and line 282, self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]: since -1 never appears in range(len(inputs.shape)), self.axes ends up containing every axis, including the channel axis.
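A minimal sketch of why this happens (plain Python, no TF needed; ndim = 4 is just an example for a 4-D NHWC input):

    # With channel_axis = -1, the comprehension at line 282 never excludes the
    # channel dimension, because range() only yields non-negative indices.
    ndim = 4  # e.g. a [batch, height, width, channel] input
    channel_axis = -1  # line 220, data_format == 'channels_last'
    axes = [i for i in range(ndim) if i != channel_axis]
    print(axes)  # [0, 1, 2, 3] -- the channel axis (index 3) is NOT excluded

    # With a non-negative channel axis, the comprehension works as intended:
    channel_axis = ndim - 1
    axes = [i for i in range(ndim) if i != channel_axis]
    print(axes)  # [0, 1, 2] -- statistics are now computed per channel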
Reproducible Code
Which OS are you using?
Python 3, TL 2.2 and TF 2.0.
Please provide reproducible code for your issue. Without any reproducible code, you will probably not receive any help.
class BatchNorm(Layer):
    """
    The :class:`BatchNorm` is a batch normalization layer for both fully-connected and
    convolution outputs. See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

    Parameters
    ----------
    decay : float
        A decay factor for `ExponentialMovingAverage`. Suggest to use a large value for
        large datasets.
    epsilon : float
        Epsilon.
    act : activation function
        The activation function of this layer.
    is_train : boolean
        Is being used for training or inference.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what happened.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma. When the batch
        normalization layer is used instead of 'biases', or the next layer is linear,
        this can be disabled since the scaling can be done by the next layer.
        See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__
    moving_mean_init : initializer or None
        The initializer for initializing moving mean, if None, skip moving mean.
    moving_var_init : initializer or None
        The initializer for initializing moving var, if None, skip moving var.
    num_features : int
        Number of features for input tensor. Useful to build layer if using BatchNorm1d,
        BatchNorm2d or BatchNorm3d, but should be left as None if using BatchNorm.
        Default None.
    data_format : str
        'channels_last' (default) or 'channels_first'.
    name : None or str
        A unique layer name.

    Examples
    --------
    With TensorLayer

    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm()(net)

    Notes
    -----
    The :class:`BatchNorm` is universally suitable for 3D/4D/5D input in static model,
    but should not be used in dynamic model where layer is built upon class
    initialization. So the argument 'num_features' should only be used for subclasses
    :class:`BatchNorm1d`, :class:`BatchNorm2d` and :class:`BatchNorm3d`. All the three
    subclasses are suitable under all kinds of conditions.

    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`__
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`__

    """

    def __init__(
        self,
        decay=0.9,
        epsilon=0.00001,
        act=None,
        is_train=False,
        beta_init=tl.initializers.zeros(),
        gamma_init=tl.initializers.random_normal(mean=1.0, stddev=0.002),
        moving_mean_init=tl.initializers.zeros(),
        moving_var_init=tl.initializers.zeros(),
        num_features=None,
        data_format='channels_last',
        name=None,
    ):
        super(BatchNorm, self).__init__(name=name, act=act)
        self.decay = decay
        self.epsilon = epsilon
        self.data_format = data_format
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.moving_mean_init = moving_mean_init
        self.moving_var_init = moving_var_init
        self.num_features = num_features
        # self.channel_axis = -1 if data_format == 'channels_last' else 1  # line 220, deleted
        ## add ##
        self.data_format = data_format
        self.axes = None

        if num_features is not None:
            self.build(None)
            self._built = True

        if self.decay < 0.0 or 1.0 < self.decay:
            raise ValueError("decay should be between 0 to 1")

        logging.info(
            "BatchNorm %s: decay: %f epsilon: %f act: %s is_train: %s" %
            (self.name, decay, epsilon, self.act.__name__ if self.act is not None else 'No Activation', is_train)
        )

    ## skip ##

    def forward(self, inputs):
        self._check_input_shape(inputs)
        ## add ##
        self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1
        if self.axes is None:
            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # update moving_mean and moving_var
            self.moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )
            self.moving_var = moving_averages.assign_moving_average(
                self.moving_var, var, self.decay, zero_debias=False
            )
            outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        else:
            outputs = batch_normalization(
                inputs, self.moving_mean, self.moving_var, self.beta, self.gamma, self.epsilon, self.data_format
            )
        if self.act:
            outputs = self.act(outputs)
        return outputs
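For reference, a hedged sketch of the update that moving_averages.assign_moving_average performs with zero_debias=False (the helper name below is made up for illustration, not part of TL or TF):

    def assign_moving_average_sketch(variable, value, decay):
        # Equivalent to: variable <- decay * variable + (1 - decay) * value,
        # which is how the running mean/var above track the batch statistics.
        return variable - (1.0 - decay) * (variable - value)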
The fix: just delete the line 220 code and add self.data_format = data_format in __init__, then add self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1 in forward, as marked with ## add ## in the code above.
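A quick way to check the difference the fix makes (assuming TF 2.x; the shapes here match the Input example from the docstring):

    import tensorflow as tf

    x = tf.random.normal([8, 50, 50, 32])  # a batch of NHWC inputs
    # Buggy axes [0, 1, 2, 3]: one scalar mean/var for the whole input.
    mean_all, var_all = tf.nn.moments(x, axes=[0, 1, 2, 3])
    # Fixed axes [0, 1, 2]: one mean/var per channel, shape [32].
    mean_ch, var_ch = tf.nn.moments(x, axes=[0, 1, 2])
    print(mean_all.shape, mean_ch.shape)  # () (32,)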