pytorch 神经网络套路 使用Dataset,DataLoader实现多维输入特征的二分类_多维特征二分类-程序员宅基地

技术标签: Pytorch  笔记  pytorch  神经网络  Python  分类  

1.数据集:

传送门:内含刘老师讲课视频PPT及相关数据集,本文所用数据集名为diabetes.cvs.gz

链接:https://pan.baidu.com/s/1vZ27gKp8Pl-qICn_p2PaSw
提取码:cxe4

其中,x1,,,x8表示不同特征,y表示分类

2.模型:

刘老师视频中采用以上模型,本文线性层输出特征改为4,2,1,其他保持不变。

loss:BCELoss

optimizer:SGD

3.python代码:

本文采用pytorch定义的Dataset,DataLoader,以Minibatch的风格加载数据集

import numpy as np
import torch
from torch import nn
from torch.nn import Linear, BCELoss
from torch.optim import SGD
import matplotlib.pyplot as plt

# 准备DataSet,DataLoader
from torch.utils.data import Dataset, DataLoader

# 数据集路径名
path = "diabetes.csv.gz"


# 数据集类
class dataset(Dataset):
    def __init__(self, path):
        # 因为数据集较小,直接加载在内存里
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32)
        self.length = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, item):
        return self.x_data[item], self.y_data[item]

    def __len__(self):
        return self.length


# 数据集类实例化
my_dataset = dataset(path)

train_loader = DataLoader(my_dataset, batch_size=10, shuffle=True)


# 建立模型,3个线性层,3个sigmoid非线性激活函数
class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.linear1 = Linear(8, 4, bias=True)
        self.linear2 = Linear(4, 2, bias=True)
        self.linear3 = Linear(2, 1, bias=True)

    def forward(self, x):
        x = torch.sigmoid(self.linear1(x))
        x = torch.sigmoid(self.linear2(x))
        x = torch.sigmoid(self.linear3(x))
        return x


# 类实例化
my_model = model()

# 二分类问题,继续采用BCELoss
loss_cal = BCELoss(size_average=True)

# 随机梯度下降
optimizer = SGD(my_model.parameters(), lr=0.01)

# 空列表
epoch_list = []
loss_list = []

for epoch in range(100000):
    for data in train_loader:
        x, y = data
        epoch_list.append(epoch)
        # 前向计算
        y_pred = my_model(x)
        loss = loss_cal(y_pred, y)
        loss_list.append(loss.item())
        # 梯度清零
        optimizer.zero_grad()
        # 反向传播
        loss.backward()
        # 参数调整
        optimizer.step()

# 画出loss随epoch变化曲线图
plt.figure()
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
import torch
import numpy as np
from torch import nn
from torch.nn import Module

x_arr=np.squeeze(np.array([np.random.rand(10,1) for i in range(10)]),2)
y_arr=np.array([[np.random.randint(0,2)] for i in range(10)])

x_list_tensor=torch.from_numpy(x_arr).float()
y_list_tensor=torch.from_numpy(y_arr).float()

class model(Module):
 def __init__(self):
  super(model,self).__init__()
  self.linear=nn.Linear(10,1)

 def forward(self,x):
  x=self.linear(x)
  x=torch.sigmoid(x)
  return x

my_model=model()

optimizer=torch.optim.SGD(my_model.parameters(),lr=1e-3,momentum=0.08,weight_decay=0.001)

criterition=torch.nn.BCELoss(size_average=True)

for i in range(1000):
 my_model.train()
 for x,y in zip(x_list_tensor,y_list_tensor):
  y_pred=my_model(x)
  print(x,y,y_pred)
  loss=criterition(y_pred,y)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  print(my_model.linear.weight.data)

4.可视化结果:

随着epoch增加,loss逐渐减小并收敛。

5.以上均为个人学习pytorch基础入门中的基础,浅做记录,如有错误,请各位大佬批评指正!

6.关于问题描述和原理的部分图片参考刘老师的视频课件,本文也是课后作业的一部分,特此附上视频链接,《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili,希望大家都有所进步!

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/kids_budong_c/article/details/123195213

智能推荐

使用nginx解决浏览器跨域问题_nginx不停的xhr-程序员宅基地

文章浏览阅读1k次。通过使用ajax方法跨域请求是浏览器所不允许的,浏览器出于安全考虑是禁止的。警告信息如下:不过jQuery对跨域问题也有解决方案,使用jsonp的方式解决,方法如下:$.ajax({ async:false, url: 'http://www.mysite.com/demo.do', // 跨域URL ty..._nginx不停的xhr

在 Oracle 中配置 extproc 以访问 ST_Geometry-程序员宅基地

文章浏览阅读2k次。关于在 Oracle 中配置 extproc 以访问 ST_Geometry,也就是我们所说的 使用空间SQL 的方法,官方文档链接如下。http://desktop.arcgis.com/zh-cn/arcmap/latest/manage-data/gdbs-in-oracle/configure-oracle-extproc.htm其实简单总结一下,主要就分为以下几个步骤。..._extproc

Linux C++ gbk转为utf-8_linux c++ gbk->utf8-程序员宅基地

文章浏览阅读1.5w次。linux下没有上面的两个函数,需要使用函数 mbstowcs和wcstombsmbstowcs将多字节编码转换为宽字节编码wcstombs将宽字节编码转换为多字节编码这两个函数,转换过程中受到系统编码类型的影响,需要通过设置来设定转换前和转换后的编码类型。通过函数setlocale进行系统编码的设置。linux下输入命名locale -a查看系统支持的编码_linux c++ gbk->utf8

IMP-00009: 导出文件异常结束-程序员宅基地

文章浏览阅读750次。今天准备从生产库向测试库进行数据导入,结果在imp导入的时候遇到“ IMP-00009:导出文件异常结束” 错误,google一下,发现可能有如下原因导致imp的数据太大,没有写buffer和commit两个数据库字符集不同从低版本exp的dmp文件,向高版本imp导出的dmp文件出错传输dmp文件时,文件损坏解决办法:imp时指定..._imp-00009导出文件异常结束

python程序员需要深入掌握的技能_Python用数据说明程序员需要掌握的技能-程序员宅基地

文章浏览阅读143次。当下是一个大数据的时代,各个行业都离不开数据的支持。因此,网络爬虫就应运而生。网络爬虫当下最为火热的是Python,Python开发爬虫相对简单,而且功能库相当完善,力压众多开发语言。本次教程我们爬取前程无忧的招聘信息来分析Python程序员需要掌握那些编程技术。首先在谷歌浏览器打开前程无忧的首页,按F12打开浏览器的开发者工具。浏览器开发者工具是用于捕捉网站的请求信息,通过分析请求信息可以了解请..._初级python程序员能力要求

Spring @Service生成bean名称的规则(当类的名字是以两个或以上的大写字母开头的话,bean的名字会与类名保持一致)_@service beanname-程序员宅基地

文章浏览阅读7.6k次,点赞2次,收藏6次。@Service标注的bean,类名:ABDemoService查看源码后发现,原来是经过一个特殊处理:当类的名字是以两个或以上的大写字母开头的话,bean的名字会与类名保持一致public class AnnotationBeanNameGenerator implements BeanNameGenerator { private static final String C..._@service beanname

随便推点

二叉树的各种创建方法_二叉树的建立-程序员宅基地

文章浏览阅读6.9w次,点赞73次,收藏463次。1.前序创建#include<stdio.h>#include<string.h>#include<stdlib.h>#include<malloc.h>#include<iostream>#include<stack>#include<queue>using namespace std;typed_二叉树的建立

解决asp.net导出excel时中文文件名乱码_asp.net utf8 导出中文字符乱码-程序员宅基地

文章浏览阅读7.1k次。在Asp.net上使用Excel导出功能,如果文件名出现中文,便会以乱码视之。 解决方法: fileName = HttpUtility.UrlEncode(fileName, System.Text.Encoding.UTF8);_asp.net utf8 导出中文字符乱码

笔记-编译原理-实验一-词法分析器设计_对pl/0作以下修改扩充。增加单词-程序员宅基地

文章浏览阅读2.1k次,点赞4次,收藏23次。第一次实验 词法分析实验报告设计思想词法分析的主要任务是根据文法的词汇表以及对应约定的编码进行一定的识别,找出文件中所有的合法的单词,并给出一定的信息作为最后的结果,用于后续语法分析程序的使用;本实验针对 PL/0 语言 的文法、词汇表编写一个词法分析程序,对于每个单词根据词汇表输出: (单词种类, 单词的值) 二元对。词汇表:种别编码单词符号助记符0beginb..._对pl/0作以下修改扩充。增加单词

android adb shell 权限,android adb shell权限被拒绝-程序员宅基地

文章浏览阅读773次。我在使用adb.exe时遇到了麻烦.我想使用与bash相同的adb.exe shell提示符,所以我决定更改默认的bash二进制文件(当然二进制文件是交叉编译的,一切都很完美)更改bash二进制文件遵循以下顺序> adb remount> adb push bash / system / bin /> adb shell> cd / system / bin> chm..._adb shell mv 权限

投影仪-相机标定_相机-投影仪标定-程序员宅基地

文章浏览阅读6.8k次,点赞12次,收藏125次。1. 单目相机标定引言相机标定已经研究多年,标定的算法可以分为基于摄影测量的标定和自标定。其中,应用最为广泛的还是张正友标定法。这是一种简单灵活、高鲁棒性、低成本的相机标定算法。仅需要一台相机和一块平面标定板构建相机标定系统,在标定过程中,相机拍摄多个角度下(至少两个角度,推荐10~20个角度)的标定板图像(相机和标定板都可以移动),即可对相机的内外参数进行标定。下面介绍张氏标定法(以下也这么称呼)的原理。原理相机模型和单应矩阵相机标定,就是对相机的内外参数进行计算的过程,从而得到物体到图像的投影_相机-投影仪标定

Wayland架构、渲染、硬件支持-程序员宅基地

文章浏览阅读2.2k次。文章目录Wayland 架构Wayland 渲染Wayland的 硬件支持简 述: 翻译一篇关于和 wayland 有关的技术文章, 其英文标题为Wayland Architecture .Wayland 架构若是想要更好的理解 Wayland 架构及其与 X (X11 or X Window System) 结构;一种很好的方法是将事件从输入设备就开始跟踪, 查看期间所有的屏幕上出现的变化。这就是我们现在对 X 的理解。 内核是从一个输入设备中获取一个事件,并通过 evdev 输入_wayland

推荐文章

热门文章

相关标签