-
Notifications
You must be signed in to change notification settings - Fork 18
/
SoftmaxDiffLoss.m
56 lines (48 loc) · 1.74 KB
/
SoftmaxDiffLoss.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
classdef SoftmaxDiffLoss < dagnn.ElementWise
  % SOFTMAXDIFFLOSS  DagNN loss layer wrapping VL_NNSOFTMAXDIFF.
  %   The layer takes two inputs. inputs{1} receives a derivative in
  %   BACKWARD; inputs{2} receives none (presumably it is the target /
  %   reference - confirm against vl_nnsoftmaxdiff).
  %
  %   The option struct handed to vl_nnsoftmaxdiff is rebuilt on every
  %   FORWARD/BACKWARD call from OPTS_VL with its mode, temperature and
  %   origstyle fields overwritten by the current MODE, TEMPERATURE and
  %   ORIGSTYLE properties. Changing those properties after construction
  %   (e.g. annealing TEMPERATURE during training) therefore takes effect.
  %   Any other fields placed in OPTS_VL are forwarded unchanged.

  properties
    % Not read anywhere in this class; kept so existing code that sets it
    % keeps working.
    opts = {}
    % Forwarded to vl_nnsoftmaxdiff as the 'mode' option field.
    mode = 'MI'
    % Forwarded to vl_nnsoftmaxdiff as the 'temperature' option field
    % (presumably a softmax temperature - confirm in vl_nnsoftmaxdiff).
    temperature = 2
    % Forwarded to vl_nnsoftmaxdiff as the 'origstyle' option field.
    origstyle = 'multiclass'
    % Option struct for vl_nnsoftmaxdiff. Its mode/temperature/origstyle
    % fields are refreshed from the properties above on every call.
    opts_vl = struct()
  end

  properties (Transient)
    % Running average of the layer output across FORWARD calls since the
    % last RESET, weighted by size(inputs{1},4).
    average = 0
    % Total weight (sum of size(inputs{1},4)) folded into AVERAGE so far.
    numAveraged = 0
  end

  methods
    function outputs = forward(obj, inputs, params)
      % Evaluate vl_nnsoftmaxdiff and fold the result into the running
      % average.
      outputs{1} = vl_nnsoftmaxdiff(inputs{1}, inputs{2}, [], obj.vlOpts()) ;
      % Weight each call by size(inputs{1},4) (assumed to be the batch
      % size, per the usual MatConvNet H x W x C x N layout - confirm).
      n = obj.numAveraged ;
      m = n + size(inputs{1},4) ;
      obj.average = bsxfun(@plus, n * obj.average, gather(outputs{1})) / m ;
      obj.numAveraged = m ;
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
      % Back-propagate derOutputs{1} through vl_nnsoftmaxdiff. Only
      % inputs{1} receives a derivative; the layer has no parameters.
      derInputs{1} = vl_nnsoftmaxdiff(inputs{1}, inputs{2}, derOutputs{1}, obj.vlOpts()) ;
      derInputs{2} = [] ;
      derParams = {} ;
    end

    function reset(obj)
      % Clear the running average accumulated by FORWARD.
      obj.average = 0 ;
      obj.numAveraged = 0 ;
    end

    function outputSizes = getOutputSizes(obj, inputSizes, paramSizes)
      % Declared output size is 1 x 1 x 1 x size(inputs{1},4).
      outputSizes{1} = [1 1 1 inputSizes{1}(4)] ;
    end

    function rfs = getReceptiveFields(obj)
      % The receptive field depends on the dimension of the variables,
      % which is not known until the network is run, so report NaN for
      % both inputs.
      rfs(1,1).size = [NaN NaN] ;
      rfs(1,1).stride = [NaN NaN] ;
      rfs(1,1).offset = [NaN NaN] ;
      rfs(2,1) = rfs(1,1) ;
    end

    function obj = SoftmaxDiffLoss(varargin)
      % Construct the layer. Arguments are handed to the inherited LOAD
      % method. OPTS_VL is then populated from the current properties so
      % code that reads OPTS_VL directly still sees the chosen options.
      obj.load(varargin) ;
      obj.opts_vl = obj.vlOpts() ;
    end
  end

  methods (Access = protected)
    function vlopts = vlOpts(obj)
      % Return OPTS_VL with its mode/temperature/origstyle fields replaced
      % by the current property values. Post-construction edits to those
      % properties are therefore honored by FORWARD and BACKWARD.
      vlopts = obj.opts_vl ;
      vlopts.mode = obj.mode ;
      vlopts.temperature = obj.temperature ;
      vlopts.origstyle = obj.origstyle ;
    end
  end
end