-
Notifications
You must be signed in to change notification settings - Fork 1
/
nnet_L2.m
67 lines (45 loc) · 1.78 KB
/
nnet_L2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
function [lsizeOUT,ltypeOUT] = nnet_L2(alldata,trainIND,testIND,runName,weightcost,varargin)
%NNET_L2 Train a reconstruction network (targets == inputs) with nnet_train_2.
%
%   [lsizeOUT, ltypeOUT] = nnet_L2(alldata, trainIND, testIND, runName, weightcost)
%   [lsizeOUT, ltypeOUT] = nnet_L2(..., layersizes)
%   [lsizeOUT, ltypeOUT] = nnet_L2(..., layersizes, layertypes)
%
%   Columns of ALLDATA are examples and rows are features. The columns
%   selected by TRAININD / TESTIND form the training / test sets. Each set
%   is shuffled, used as both input and target, and handed to nnet_train_2
%   together with the fixed hyperparameters below.
%
%   Inputs:
%     alldata    - data matrix, one example per column.
%     trainIND   - column indices (or logical mask) of training examples.
%     testIND    - column indices (or logical mask) of test examples.
%     runName    - run identifier, passed through to nnet_train_2.
%     weightcost - regularization weight (presumably an L2 weight cost),
%                  passed through to nnet_train_2.
%     layersizes - (optional) vector of hidden-layer sizes. If omitted or
%                  empty, defaults to [100 50 20 2 20 50 100].
%     layertypes - (optional) cell array with numel(layersizes)+1 entries:
%                  one unit type per hidden layer plus the output layer.
%                  If omitted or empty, every hidden layer is 'tanh' except
%                  the smallest one (the code layer), which is 'linear', and
%                  the output layer is 'linear'.
%
%   Outputs:
%     lsizeOUT - full size vector [size(alldata,1) layersizes size(alldata,1)].
%     ltypeOUT - copy of layertypes in which the smallest hidden layer (the
%                first one, on ties) is relabelled 'linearSTORE'.

seed = 1234;
% Legacy seeding syntax. The call order below fixes the shuffles and
% presumably any random initialization done later inside nnet_train_2.
randn('state', seed );
rand('twister', seed+1 );

%you will NEVER need more than a few hundred epochs unless you are doing
%something very wrong. Here 'epoch' means parameter update, not 'pass over
%the training set'.
maxepoch = 100;

% Split by column, then shuffle the examples within each set.
indata = alldata(:,trainIND);
intest = alldata(:,testIND);
indata = indata(:,randperm(size(indata,2)));
intest = intest(:,randperm(size(intest,2)));

% Targets equal inputs: the network is trained to reconstruct its input.
outdata = indata;
outtest = intest;

% Free-text run description. Any PCA preprocessing it mentions happens
% upstream; this function uses alldata as given.
runDesc = ['seed = ' num2str(seed) 'd is reduced to PC scores with 95 of the variance' ];

%next try using autodamp = 0 for rho computation. both for version 6 and
%versions with rho and cg-backtrack computed on the training set

% Hidden-layer sizes: caller-supplied if given and non-empty, else the default.
if numel(varargin) >= 1 && ~isempty(varargin{1})
    layersizes = varargin{1};
else
    layersizes = [100 50 20 2 20 50 100];
end
% Force a row vector so it concatenates into lsizeOUT below.
layersizes = layersizes(:).';

% Index of the smallest hidden layer (the code layer). The ", 1" keeps a
% single index when several layers share the minimum size; the cell
% assignment to ltypeOUT below needs exactly one index.
codeLayer = find(layersizes == min(layersizes), 1);

% Unit types: one per hidden layer plus one for the output layer. The
% derived default matches layersizes in length and puts the linear unit on
% the same layer that is tagged 'linearSTORE'. For the default sizes it is
% {'tanh','tanh','tanh','linear','tanh','tanh','tanh','linear'}.
if numel(varargin) >= 2 && ~isempty(varargin{2})
    layertypes = varargin{2};
else
    layertypes = repmat({'tanh'}, 1, numel(layersizes) + 1);
    layertypes{codeLayer} = 'linear';
    layertypes{end} = 'linear';
end
if numel(layertypes) ~= numel(layersizes) + 1
    error('nnet_L2:layertypes', ...
        'layertypes must have numel(layersizes)+1 entries (got %d, expected %d).', ...
        numel(layertypes), numel(layersizes) + 1);
end

resumeFile = [];

% Full architecture including input and output layers (both the data dimension).
lsizeOUT = [size(alldata,1) layersizes size(alldata,1)];
ltypeOUT = layertypes;
ltypeOUT{codeLayer} = 'linearSTORE';

% No pretrained parameters: start from scratch.
paramsp = [];
Win = [];
bin = [];
%[Win, bin] = loadPretrainedNet_curves;

% Chunking of the training / test sets inside nnet_train_2.
% NOTE(review): may require the set sizes to be divisible by these
% values — confirm against nnet_train_2.
numchunks = 4;
numchunks_test = 1;

mattype = 'gn'; %Gauss-Newton. The other choices probably won't work for whatever you're doing
%mattype = 'hess';
%mattype = 'empfish';

% The remaining optimizer settings are opaque here and passed through
% unchanged to nnet_train_2 (semantics defined there).
rms = 0;
hybridmode = 1;
%decay = 1.0;
decay = 0.95;

%jacket = 0;
%this enables Jacket mode for the GPU
jacket = 1;

errtype = 'L2'; %report the L2-norm error (in addition to the quantity actually being optimized, i.e. the log-likelihood)

nnet_train_2( runName, runDesc, paramsp, Win, bin, resumeFile, maxepoch, indata, outdata, numchunks, intest, outtest, numchunks_test, layersizes, layertypes, mattype, rms, errtype, hybridmode, weightcost, decay, jacket);