python怎么使用预训练的模型_PyTorch使用预训练模型-程序员宅基地

技术标签: python怎么使用预训练的模型  

PyTorch模型加载的时候,有预训练模型,通过使用预训练模型可以给模型使用带来很多的便捷,对于模型的使用以下给出了一些总结,如有错误恳请指正。

一、直接加载预训练模型进行训练

1、加载保存的整个模型

torch.save(model,'model.pkl')

...

model = torch.load('model.pkl')

2、加载保存的模型参数

torch.save(model.state_dict(),'model_state_dict.pkl')

...

model.load_state_dict(torch.load('model_state_dict.pkl'))

关于模型的保存和加载,可以详细参照我的这篇文章:HUST小菜鸡:Pytorch搭建简单神经网络(三)——快速搭建、保存与提取​zhuanlan.zhihu.comv2-5aed9b4858ee329f1de0b9d5ff33ce4a_180x120.jpg

通过对模型参数的保存的解析,我们可以深入的了解

load_dict = torch.load('models/cifar10_statedict.pkl')

print(load_dict.keys())

print(type(load_dict))

输出的结果如下所示:

odict_keys(['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'conv3.0.weight', 'conv3.0.bias', 'conv4.0.weight', 'conv4.0.bias', 'conv5.0.weight', 'conv5.0.bias', 'conv6.0.weight', 'conv6.0.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.5.weight', 'classifier.5.bias'])

可以看出保存的state_dict其实是一个collections.OrderedDict的Object,和普通的dict不同的是,该类别是有着严格的顺序,而dict中的元素是没有严格的顺序。

但是有一个问题值得深入考量——两个网络的结构是一样的,但是结构的命名是不一样的,那么对于这种模型的加载,如果不一样的话会出现报错,该如何解决

参照以上结果的输出,state_dict中key就是网络结构的名称,所以当网络结构一样的时候,只需要修改索引key,就可以解决以上的问题,至于如何修改可以参照如下方式:https://stackoverflow.com/questions/12150872/change-key-in-ordereddict-without-losing-order​stackoverflow.com

二、加载部分预训练模型

我们经常对现有的经典网络进行如下操作,我们不修改网络的主体部分,我们只修改网络的输出,或者在最后加上一些网络层来达到我们想要的输出结果,虽然很难保证网络模型和某些公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

model = cifar10_cnn.CIFAR10_Nettest()

pretrained_dict = torch.load('models/cifar10_statedict.pkl')

model_dict = model.state_dict()

print('随机初始化权重第一层:',model_dict['conv1.0.weight'])

# 将pretrained_dict里不属于model_dict的键剔除掉

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

print('预训练权重第一层:',pretrained_dict['conv1.0.weight'])

# 更新现有的model_dict

model_dict.update(pretrained_dict) #利用预训练模型的参数,更新模型

model.load_state_dict(model_dict)

print('更新后权重第一层:',model_dict['conv1.0.weight'])

输出的部分结果如下所示,为了直观显示我只截取了中间的某一部分

随机初始化权重第一层: tensor([[[[ 0.0142, 0.1039, 0.1260],

[ 0.1805, -0.0533, 0.0007],

[-0.1032, -0.1039, -0.0633]],

[[ 0.0714, -0.0053, 0.0059],

[-0.0528, 0.0438, -0.1108],

[ 0.0544, 0.0157, 0.1265]],

预训练权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],

[-2.3942e-01, -1.5474e-01, 1.3142e-01],

[-9.4602e-02, 6.4120e-02, -9.4336e-02]],

[[ 9.7318e-02, 1.0526e-01, 2.3400e-03],

[-5.8471e-02, -8.8146e-02, -1.6053e-01],

[-1.0788e-01, -5.9083e-02, -9.0651e-02]],

更新后权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],

[-2.3942e-01,

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_39567169/article/details/109940576

智能推荐

WCE Windows hash抓取工具 教程_wce.exe -s aaa:win-9r7tfgsiqkf:0000000000000000000-程序员宅基地

文章浏览阅读6.9k次。WCE 下载地址:链接:https://share.weiyun.com/5MqXW47 密码:bdpqku工具界面_wce.exe -s aaa:win-9r7tfgsiqkf:00000000000000000000000000000000:a658974b892e

各种“网络地球仪”-程序员宅基地

文章浏览阅读4.5k次。Weather Globe(Mackiev)Google Earth(Google)Virtual Earth(Microsoft)World Wind(NASA)Skyline Globe(Skylinesoft)ArcGISExplorer(ESRI)国内LTEarth(灵图)、GeoGlobe(吉奥)、EV-Globe(国遥新天地) 软件名称: 3D Weather Globe(http:/_网络地球仪

程序员的办公桌上,都出现过哪些神奇的玩意儿 ~_程序员展示刀,产品经理展示枪-程序员宅基地

文章浏览阅读1.9w次,点赞113次,收藏57次。我要买这些东西,然后震惊整个办公室_程序员展示刀,产品经理展示枪

霍尔信号、编码器信号与电机转向-程序员宅基地

文章浏览阅读1.6w次,点赞7次,收藏63次。霍尔信号、编码器信号与电机转向从电机出轴方向看去,电机轴逆时针转动,霍尔信号的序列为编码器信号的序列为将霍尔信号按照H3 H2 H1的顺序组成三位二进制数,则霍尔信号翻译成状态为以120°放置霍尔为例如不给电机加电,使用示波器测量三个霍尔信号和电机三相反电动势,按照上面所说的方向用手转动电机得到下图① H1的上升沿对应电机q轴与H1位置电角度夹角为0°,..._霍尔信号

个人微信淘宝客返利机器人搭建教程_怎么自己制作返利机器人-程序员宅基地

文章浏览阅读7.1k次,点赞5次,收藏36次。个人微信淘宝客返利机器人搭建一篇教程全搞定天猫淘宝有优惠券和返利,仅天猫淘宝每年返利几十亿,你知道么?技巧分享:在天猫淘宝京东拼多多上挑选好产品后,按住标题文字后“复制链接”,把复制的淘口令或链接发给机器人,复制机器人返回优惠券口令或链接,再打开天猫或淘宝就能领取优惠券啦下面教你如何搭建一个类似阿可查券返利机器人搭建查券返利机器人前提条件1、注册微信公众号(订阅号、服务号皆可)2、开通阿里妈妈、京东联盟、拼多多联盟一、注册微信公众号https://mp.weixin.qq.com/cgi-b_怎么自己制作返利机器人

【团队技术知识分享 一】技术分享规范指南-程序员宅基地

文章浏览阅读2.1k次,点赞2次,收藏5次。技术分享时应秉持的基本原则:应有团队和个人、奉献者(统筹人)的概念,同时匹配团队激励、个人激励和最佳奉献者激励;团队应该打开工作内容边界,成员应该来自各内容方向;评分标准不应该过于模糊,否则没有意义,应由客观的基础分值以及分团队的主观综合结论得出。应有心愿单激励机制,促进大家共同聚焦到感兴趣的事情上;选题应有规范和框架,具体到某个小类,这样收获才有目标性,发布分享主题时大家才能快速判断是否是自己感兴趣的;流程和分享的模版应该有固定范式,避免随意的格式导致随意的内容,评分也应该部分参考于此;参会原则,应有_技术分享

随便推点

O2OA开源企业办公开发平台:使用Vue-CLI开发O2应用_vue2 oa-程序员宅基地

文章浏览阅读1k次。在模板中,我们使用了标签,将由o2-view组件负责渲染,给o2-view传入了两个参数:app="内容管理数据"和name="所有信息",我们将在o2-view组件中使用这两个参数,用于展现“内容管理数据”这个数据应用下的“所有信息”视图。在o2-view组件中,我们主要做的事是,在vue组件挂载后,将o2的视图组件,再挂载到o2-view组件的根Dom对象。当然,这里我们要在我们的O2服务器上创建好数据应用和视图,对应本例中,就是“内容管理数据”应用下的“所有信息”视图。..._vue2 oa

[Lua]table使用随笔-程序员宅基地

文章浏览阅读222次。table是lua中非常重要的一种类型,有必要对其多了解一些。

JAVA反射机制原理及应用和类加载详解-程序员宅基地

文章浏览阅读549次,点赞30次,收藏9次。我们前面学习都有一个概念,被private封装的资源只能类内部访问,外部是不行的,但这个规定被反射赤裸裸的打破了。反射就像一面镜子,它可以清楚看到类的完整结构信息,可以在运行时动态获取类的信息,创建对象以及调用对象的属性和方法。

Linux-LVM与磁盘配额-程序员宅基地

文章浏览阅读1.1k次,点赞35次,收藏12次。Logical Volume Manager,逻辑卷管理能够在保持现有数据不变的情况下动态调整磁盘容量,从而提高磁盘管理的灵活性/boot分区用于存放引导文件,不能基于LVM创建PV(物理卷):基于硬盘或分区设备创建而来,生成N多个PE,PE默认大小4M物理卷是LVM机制的基本存储设备,通常对应为一个普通分区或整个硬盘。创建物理卷时,会在分区或硬盘的头部创建一个保留区块,用于记录 LVM 的属性,并把存储空间分割成默认大小为 4MB 的基本单元(PE),从而构成物理卷。

车充产品UL2089安规测试项目介绍-程序员宅基地

文章浏览阅读379次,点赞7次,收藏10次。4、Dielecteic voltage-withstand test 介电耐压试验。1、Maximum output voltage test 输出电压试验。6、Resistance to crushing test 抗压碎试验。8、Push-back relief test 阻力缓解试验。7、Strain relief test 应变消除试验。2、Power input test 功率输入试验。3、Temperature test 高低温试验。5、Abnormal test 故障试验。

IMX6ULL系统移植篇-系统烧写原理说明_正点原子 imx6ull nand 烧录-程序员宅基地

文章浏览阅读535次。镜像烧写说明_正点原子 imx6ull nand 烧录