MATLAB Deep Learning toolbox安装,改写后可以画每次迭代的损失值和预测准确率-程序员宅基地

技术标签: matlab  deep learning  MATLAB  Deep Learning  

 一、安装

建议大家更新MATLAB2022b,安装时就可以安装深度学习工具包,如果是之前的版本,可以通过以下方式安装。

1、GitHub下载deep Learning toolbox: 

https://github.com/rasmusbergpalm/DeepLearnToolbox

2、解压后的deep Learning toolbox文件夹(自动命名为DeepLearnToolbox-master)放到matlab安装根目录的toobox文件夹里。
3、添加路径,在命令行输入addpath(genpath(‘D:\MATLAB\toolbox\DeepLearnToolbox-master’)),这个路径要根据自己的安装位置修改。然后点击主页,点击设置路径,点击保存,每次开机就可以直接调用这个工具箱的函数了。

二、改写nntrain

工具箱中的原文件如下:

function [nn, L]  = nntrain(nn, train_x, train_y, opts, val_x, val_y)
%NNTRAIN trains a neural net
% [nn, L] = nnff(nn, x, y, opts) trains the neural network nn with input x and
% output y for opts.numepochs epochs, with minibatches of size
% opts.batchsize. Returns a neural network nn with updated activations,
% errors, weights and biases, (nn.a, nn.e, nn.W, nn.b) and L, the sum
% squared error for each training minibatch.

assert(isfloat(train_x), 'train_x must be a float');
assert(nargin == 4 || nargin == 6,'number ofinput arguments must be 4 or 6')

loss.train.e               = [];
loss.train.e_frac          = [];
loss.val.e                 = [];
loss.val.e_frac            = [];
opts.validation = 0;
if nargin == 6
    opts.validation = 1;
end

fhandle = [];
if isfield(opts,'plot') && opts.plot == 1
    fhandle = figure();
end

m = size(train_x, 1);

batchsize = opts.batchsize;
numepochs = opts.numepochs;

numbatches = m / batchsize;

assert(rem(numbatches, 1) == 0, 'numbatches must be a integer');

L = zeros(numepochs*numbatches,1);
n = 1;
for i = 1 : numepochs
    tic;
    
    kk = randperm(m);
    for l = 1 : numbatches
        batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
        
        %Add noise to input (for use in denoising autoencoder)
        if(nn.inputZeroMaskedFraction ~= 0)
            batch_x = batch_x.*(rand(size(batch_x))>nn.inputZeroMaskedFraction);
        end
        
        batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);
        
        nn = nnff(nn, batch_x, batch_y);
        nn = nnbp(nn);
        nn = nnapplygrads(nn);
        
        L(n) = nn.L;
        
        n = n + 1;
    end
    
    t = toc;

    if opts.validation == 1
        loss = nneval(nn, loss, train_x, train_y, val_x, val_y);
        str_perf = sprintf('; Full-batch train mse = %f, val mse = %f', loss.train.e(end), loss.val.e(end));
    else
        loss = nneval(nn, loss, train_x, train_y);
        str_perf = sprintf('; Full-batch train err = %f', loss.train.e(end));
    end
    if ishandle(fhandle)
        nnupdatefigures(nn, fhandle, loss, opts, i);
    end
        
    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds' '. Mini-batch mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1)))) str_perf]);
    nn.learningRate = nn.learningRate * nn.scaling_learningRate;
end
end

上面文件中,L是训练过程中的损失,每次迭代(epoch),计算每个批样本(batch)的损失值,这里对其进行改写,使其输出每次迭代后所有训练样本的损失值,以及预测测试集的准确率。

1、首先我们改变函数参数个数。

function [nn,Loss,accuracy, L]  = nntrain(nn, train_x, train_y, opts,test_x,test_y,val_x, val_y)

assert(nargin == 4 || nargin == 6|| nargin == 8,'number ofinput arguments must be 4 or 6 or 8')

if nargin == 8   
    opts.validation = 1;
end            

if nargin == 6   
    opts.test = 1;
end

参数中加入了测试集的input和labels,这里修改成输入参数为8个时,opts.validation=1。

增加一个opts.test参数。

给损失值和准确率分配空间。

Loss=zeros(numepochs,1);
accuracy=zeros(numepochs,1);

loss_batch(l)=nn.L;%计算一次迭代过程中,每个batch的损失

Loss(i)=sum(loss_batch)/numbatches;计算所有训练集的损失

下面我们判断测试集输入参数是否为[ ],不为[ ]时,计算预测准确率。

    if opts.test==1
        if isempty(test_x)||isempty(test_y)
            opts.test=0;
        else
            [er, bad] = nntest(nn, test_x, test_y);
            accuracy(i)=1-er;
        end
    end

以下是修改后的nntrain函数。

function [nn,Loss,accuracy, L]  = nntrain(nn, train_x, train_y, opts,test_x,test_y,val_x, val_y)
%NNTRAIN trains a neural net
% [nn, L] = nnff(nn, x, y, opts) trains the neural network nn with input x and
% output y for opts.numepochs epochs, with minibatches of size
% opts.batchsize. Returns a neural network nn with updated activations,
% errors, weights and biases, (nn.a, nn.e, nn.W, nn.b) and L, the sum
% squared error for each training minibatch.

assert(isfloat(train_x), 'train_x must be a float');
assert(nargin == 4 || nargin == 6|| nargin == 8,'number ofinput arguments must be 4 or 6 or 8')

loss.train.e               = [];
loss.train.e_frac          = [];
loss.val.e                 = [];
loss.val.e_frac            = [];
opts.validation = 0;
opts.test = 0;
if nargin == 8   
    opts.validation = 1;
end

if nargin == 6   
    opts.test = 1;
end
fhandle = [];
if isfield(opts,'plot') && opts.plot == 1
    fhandle = figure();
end

m = size(train_x, 1);

batchsize = opts.batchsize;
numepochs = opts.numepochs;

numbatches = m / batchsize;

assert(rem(numbatches, 1) == 0, 'numbatches must be a integer');

L = zeros(numepochs*numbatches,1);
n = 1;
Loss=zeros(numepochs,1);
accuracy=zeros(numepochs,1);
for i = 1 : numepochs
    tic;
    loss_batch=zeros(numbatches,1);
    kk = randperm(m);
    for l = 1 : numbatches
        batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
        
        %Add noise to input (for use in denoising autoencoder)
        if(nn.inputZeroMaskedFraction ~= 0)
            batch_x = batch_x.*(rand(size(batch_x))>nn.inputZeroMaskedFraction);
        end
        
        batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);
        
        nn = nnff(nn, batch_x, batch_y);
        nn = nnbp(nn);
        nn = nnapplygrads(nn);
        
        L(n) = nn.L;
        n = n + 1;
        loss_batch(l)=nn.L;
    end
    t = toc;
    
    Loss(i)=sum(loss_batch)/numbatches;
    
    if opts.test==1
        if isempty(test_x)||isempty(test_y)
            opts.test=0;
        else
            [er, bad] = nntest(nn, test_x, test_y);
            accuracy(i)=1-er;
        end
    end
    
    if opts.validation == 1
        loss = nneval(nn, loss, train_x, train_y, val_x, val_y);
        str_perf = sprintf('; Full-batch train mse = %f, val mse = %f', loss.train.e(end), loss.val.e(end));
    else
        loss = nneval(nn, loss, train_x, train_y);
        str_perf = sprintf('; Full-batch train err = %f', loss.train.e(end));
    end
    if ishandle(fhandle)
        nnupdatefigures(nn, fhandle, loss, opts, i);
    end
    
        
    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds' '. Mini-batch mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1)))) str_perf]);
    nn.learningRate = nn.learningRate * nn.scaling_learningRate;
end
end

 

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_60587058/article/details/123718619

智能推荐

[基于harbor部署私有仓库] 5 k8s使用harbor私有镜像仓库_安装kubsphere用什么镜像源-程序员宅基地

文章浏览阅读762次。上一篇,已经讲解了如何给harbor镜像仓库推送镜像。这一篇分享下,在k8s里头_安装kubsphere用什么镜像源

用java代码读取dbf文件_java 使用poi读取dbf文件-程序员宅基地

文章浏览阅读2.2k次。import com.linuxense.javadbf.DBFReader;import org.apache.poi.hssf.usermodel.HSSFWorkbook;import org.apache.poi.ss.usermodel.Cell;import org.apache.poi.ss.usermodel.Row;import org.apache.poi.ss.use..._java 使用poi读取dbf文件

TCP通信丢包主要问题及具体问题分析_tcp的客户端发送报文给服务器,不产生丢包或网络阻塞,但是数据不一致是因为-程序员宅基地

文章浏览阅读7.2k次。今天在公司问老大,公司的项目底层,是使用的TCP,因为可靠,自动断线重连,在底层都实现了,但是我记得TCP也会有掉包的问题,所以这文章就诞生了——关于TCP掉包的问题,TCP是基于不可靠的网络实现可靠的传输,肯定也会存在掉包的情况。 如果通信中发现缺少数据或者丢包,那么,最大的可能在于程序发送的过程或者接收的过程出现问题。 例如服务器给客户端发大量数据,Send的频率很高,_tcp的客户端发送报文给服务器,不产生丢包或网络阻塞,但是数据不一致是因为

centos7 配置本地yum源_yum clean all 已加载插件:fastestmirror 正在清理软件源: clouder-程序员宅基地

文章浏览阅读232次。[base-local] #唯一标识,不能重复name=CentOS-local #名字(随便)baseurl=file:///mnt/cdrom #上方步骤一挂载镜像创建的目录enabled=1 #yum源是否启用 1-启用 0-不启用gpgcheck=1 #对源进行检测,安全检测 1-开启 0-不开启,本地源一般不检测,网络源一般检测gpgkey=file:///etc/pki/r..._yum clean all 已加载插件:fastestmirror 正在清理软件源: cloudera-manager os7_

echarts基础语法_echarts splitnumber-程序员宅基地

文章浏览阅读3.4k次,点赞2次,收藏28次。一.首页知识点推荐:点击首页->可视化实验室里面有很多意想不到的宝藏二.名词解析1.基本名词xAxis 横坐标yAxis 纵坐标grid 整个坐标系是基于grid这个网格去定位的legend 图例dataRange 值域选择,常用于展现地域数据时选择值域范围dataZoom 数据区域缩放,常用于展现大量数据时选择可视范围toolbox 工具箱tooltip 气泡提示框,常用于展现更详细的数据timeline 时间轴series 存放数据的大数组.._echarts splitnumber

汽车厂商的摘星指南:我们能从如祺出行身上学到什么?-程序员宅基地

文章浏览阅读528次。自从Uber、滴滴、神州等等一系列企业之间的战争偃旗息鼓之后,网约车市场已经安静许久了。但不论任何市场,都会有变量的存在,就当人们认为网约车市场趋于稳固时,新的变量又出现..._但无论汽车厂商

随便推点

Kafka 消息监控 - Kafka Eagle_kafka有消息轨迹功能吗-程序员宅基地

文章浏览阅读1.4k次。1.概述  在开发工作当中,消费 Kafka 集群中的消息时,数据的变动是我们所关心的,当业务并不复杂的前提下,我们可以使用 Kafka 提供的命令工具,配合 Zookeeper 客户端工具,可以很方便的完成我们的工作。随着业务的复杂化,Group 和 Topic 的增加,此时我们使用 Kafka 提供的命令工具,已预感到力不从心,这时候 Kafka 的监控系统此刻便尤为显得重要,我们需要_kafka有消息轨迹功能吗

Kendall’s tau-b,pearson、spearman三种相关性的区别(有空整理信息检索评价指标)-程序员宅基地

文章浏览阅读1.6k次。同样可参考:http://blog.csdn.net/wsywl/article/details/5889419http://wenku.baidu.com/link?url=pEBtVQFzTx0I9T9vr01WS6_NmOY7EylNwa-suKpx3ab1YZfL4QvYsPt2chXyvXOvU3bBa_CrTOaZ0QV_KmcMCmTrqXvZQNKy-cLHQ8J2Y0q..._kendall tau 和线性相关系数 区别

HttpURLConnection上传文件(图片)小试-程序员宅基地

文章浏览阅读662次。需求:用HttpURLConnection模拟上传图片并把图片的名称也要传递过去.简单分析:写入流的时候依次写入 图片名称 + "|" 分隔符 + 图片流然后服务器接收的再处理流.分别取出图片名和图片./** *//** * 上传方法 * 返回上传完毕的文件名 * * */ public String upload(File f) { ..._httpurlconnection上传图片

docker windows10中安装node.js_windows docker 安装nodejs-程序员宅基地

文章浏览阅读1.3k次。docker windows10 中安装node.js启动docker服务获取node最新镜像运行镜像常用参数查看node版本号启动docker服务获取node最新镜像运行命令docker search node获取node镜像接着拉取node最新镜像,运行命令docker pull node当出现图中提示时,表明拉取镜像成功我们可以通过命令查看镜像来确认node是否拉取成功..._windows docker 安装nodejs

yii framework学习笔记-程序员宅基地

文章浏览阅读105次。一、验证和授权 1、基本验证授权方式 在控制器重重写filters方法,这个filter指定的是过滤器,可以是当前控制的方法,必须是以filter开头。 我们来看看通过yii 中示例中博客的例子。 <?phpclass TblPostController extends Controller{ /** * @return array 过滤器列表..._yii regeframe

springmvc 配置文件_配置文件bluej-程序员宅基地

文章浏览阅读296次。1、在同级目录下。默认就ok springMVC1 index.html index.htm index.jsp default.html default.htm default.jsp springMVC org.springframework.web.servlet.DispatcherSe_配置文件bluej

推荐文章

热门文章

相关标签