基于matlab的dbn算法实现


下载deeplearningtoolbox或者本人提交的zip文件包即可直接运行。command window运行以下程序:

%function test_example_DBN
load mnist_uint8;

train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);

%%  ex1 train a 100 hidden unit RBM and visualize its weights
rand('state',0)
dbn.sizes = [100];
opts.numepochs =   1;
opts.batchsize = 100;
opts.momentum  =   0;
opts.alpha     =   1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);
figure; visualize(dbn.rbm{1}.W');   %  Visualize the RBM weights

%%  ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
rand('state',0)
%train dbn
dbn.sizes = [100 100];
opts.numepochs =   1;
opts.batchsize = 100;
opts.momentum  =   0;
opts.alpha     =   1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

%train nn
opts.numepoc
hs =  1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');

```

抑或者,下载mnist文件包,之后转成二进制文件,再根据本人上传或者其他转换文件,生成.mat文件,确保matlab可以运行。

之后参见githud.com上面的代码介绍也可实现dbn功能。

转载请注明出处


 

你可能感兴趣的:(dbn)