Pytorch分布式训练DataParallel和DistributedDataParallel详解-程序员宅基地

技术标签: 工作积累  Pytorch  python  深度学习  pytorch  编程语言  

最近工作涉及到修改分布式训练代码,以前半懂非懂,这次改的时候漏了一些细节,带来不必要的麻烦,索性花点时间搞明白。

Pytorch 分布式训练主要有两种方式:

torch.nn.DataParallel ==> 简称 DP
torch.nn.parallel.DistributedDataParallel ==> 简称DDP

其中 DP 只用于单机多卡,DDP 可以用于单机多卡也可用于多机多卡,后者现在也是Pytorch训练的主流用法,DP写法比较简单,但即使在单机多卡情况下也比 DDP 慢。

可参考:https://pytorch.org/docs/stable/nn.html#dataparallel-layers-multi-gpu-distributed 。

本文主要介绍DP和DDP的使用方式。

DP

import torch
import torch.nn as nn

# 构造模型
net = model(imput_size, output_size)

# 模型放在GPU上
net = net.cuda()
net=nn.DataParallel(net)

# 数据放在GPU上
inputs, labels = inputs.cuda(), labels.cuda()

result = net(inputs)

# 其他和正常模型训练无差别

关于Dataparallel, 摘取主要源码:

class DataParallel(Module):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()
        
        # 如果没有GPU可用,直接返回
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        # 如果有GPU,但没有指定的话,device_ids为所有可用GPU
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
            
        # 默认输出在0号卡上
        if output_device is None:
            output_device = device_ids[0]

总结

如果不设定好要使用的device_ids的话, 程序会自动找到这个机器上面可以用的所有的显卡用于训练。

如果想要限制使用的显卡数,怎么办呢?

在代码最前面使用:

os.environ['CUDA_VISIBLE_DEVICES'] == '0,5'
# 限制代码能看到的GPU个数,这里表示指定只使用实际的0号和5号GPU
# 注意:这里的赋值必须是字符串,list会报错

# 这时候device_count = 2
device_ids = range(torch.cuda.device_count()) 

# device_ids = [0,1] 这里的0就是上述指定的'0'号卡,1对应'5'号卡。
net = nn.DataParallel(net,device_ids)

# !!!模型和数据都由主gpu(0号卡)分发。

值得注意的是,在使用os.environ['CUDA_VISIBLE_DEVICES']对可以使用的显卡进行限定之后, 显卡的实际编号和程序看到的编号应该是不一样的

例如上面我们设定的是os.environ['CUDA_VISIBLE_DEVICES']="0,5", 但是程序看到的显卡编号应该被改成了'0,1'

也就是说程序所使用的显卡编号实际上是经过了一次映射之后才会映射到真正的显卡编号上面的, 例如这里的程序看到的1对应实际的5。

但是Dataparallel会带来显存的使用不平衡,具体分析见参考链接[2],而且碰到大的任务,时间和能力上都很受限。

DDP

为了弥补Dataparallel的不足,有torch.nn.parallel.DistributedDataParallel,这也是现在Pytorch分布式训练主推的。

DDP支持单机多卡和多机多卡,每张卡都有一个进程,这就涉及到进程通信,多进程通信初始化,是DDP使用最复杂的地方。

具体看下:

torch.distributed.init_process_group( )

详见:https://pytorch.org/docs/stable/distributed.html

常用参数:

  • backend: 后端, 实际上是多个机器之间交换数据的协议,官方和很多用户都强烈推荐’nccl’作为backend。

  • init_method: 机器之间交换数据需要指定一个主节点, 这个参数用来指定主节点的。

  • world_size: 参与job的进程数, 实际就是GPU的个数;

  • rank: 进程组中每个进程的唯一标识符。比如一个节点8张卡,world_size为8,每张卡的rank是对应的0-7的连续整数。

  • 顺便解释下local_rank: 假设有两个节点/机器,每个节点有8张卡,总共16张卡,对应16个进程。global rank是指0-15,对于节点1,local_rank为0-7,对于节点2,local_rank也是0-7。

初始化init_method的方法有两种, 一种是使用TCP进行初始化, 另外一种是使用共享文件系统进行初始化。

Pytorch作者推荐了这种初始化方式,来源见水印和参考链接,
https://pic3.zhimg.com/v2-44d292f6a84f99a107cd2fe7c0bb7a56_r.jpg
我们平常在集群上操作,可以通过os.environ获取每个进程的节点ip信息,全局rank以及local rank。

关于获取节点信息的详细代码:

import os

os.environ['SLURM_NTASKS']          # 可用作world size
os.environ['SLURM_NODEID']          # node id
os.environ['SLURM_PROCID']          # 可用作全局rank
os.environ['SLURM_LOCALID']         # local_rank
os.environ['SLURM_NODELIST']   # 从中取得一个ip作为通讯ip

单机多卡,主机只有一个相对没有那么复杂,按照官网推荐的设置就好。

因此,torch中DDP的使用如下方式:

import os
import re

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 1. 获取环境信息
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NTASKS'])
local_rank = int(os.environ['SLURM_LOCALID'])
node_list = str(os.environ['SLURM_NODELIST'])       

# 对ip进行操作
node_parts = re.findall('[0-9]+', node_list)
host_ip = '{}.{}.{}.{}'.format(node_parts[1], node_parts[2], node_parts[3], node_parts[4])

 # 注意端口一定要没有被使用
port = "23456"                                         

 # 使用TCP初始化方法
init_method = 'tcp://{}:{}'.format(host_ip, port)      

# 多进程初始化,初始化通信环境
dist.init_process_group("nccl", init_method=init_method,
                        world_size=world_size, rank=rank) 

# 指定每个节点上的device
torch.cuda.set_device(local_rank)
                     
model = model.cuda()

# 当前模型所在local_rank
model = DDP(model, device_ids=[local_rank])             # 指定当前卡上的GPU号

input = input.cuda()
output = model(input)

# 此后训练流程与普通模型无异

最近官方表述中加了一个store参数,更新了下使用方法,大差不差。
具体参考:https://pytorch.org/docs/stable/distributed.html

使用TCP进行初始化,需要读取ip,我们在集群上通过os.environ可以很方便完成初始化。

我平常提交任务的slurm指令这样写:

# 单机多卡
# 8个任务对应8个进程,每个节点上跑8个任务

srun -n8 --gres=gpu:8 --ntasks-per-node=8 python train.py
# 多机多卡
# 16个任务对应16个进程,每个节点最多跑8个任务/进程,每张卡占满8个GPU
# 因此这里是申请了16/8=2个节点,即在两个机器上跑。

srun -n16 --gres=gpu:8 --ntasks-per-node=8 python train.py

参考:
[1]https://blog.csdn.net/weixin_40087578/article/details/87186613
[2]https://zhuanlan.zhihu.com/p/86441879
[3]https://zhuanlan.zhihu.com/p/68717029

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_38278334/article/details/105605994

智能推荐

Multi-Scale Guided Concurrent Reflection Removal Network_reflection network-程序员宅基地

文章浏览阅读364次。gradient inference network(GiN):输入是4通道张量,它是输入混合图像及其对应梯度的组合.The image inference network (IiN):以混合图像为输入,提取描述全局结构和高层语义信息的背景特征表示来估计B和R。GIN网络用的是一个镜像框架结构,即首尾结构对称(分别对应编码和解码结构)。编码结构由五个卷积层构成,先一个步长1..._reflection network

select2 下拉选择后首次未触发change事件_select2 change-程序员宅基地

文章浏览阅读7.5k次。问题现象: select2 下拉选择框,首次切换到“全部”选项不会触发change事件。问题背景: select2 下拉选择框,有设置默认值(非全部),在加载数据时,改动后端返回数据,加了一条“全部”的下拉选择内容:list.unshift({'id':'','text':'全部'}); 问题分析: 首次切换到“全部”选项后,并未触发change事件; 而首次切换到其他..._select2 change

python升序降序_python 根据两个字段排序, 一个升序, 一个降序-程序员宅基地

文章浏览阅读867次。给定一个字符串, 输出出现次数最多的前三个字符, 若两字符出现次数相同, 则按字典顺序排列.# 样例输入aabbbccde# 样例输出b 3a 2c 2就是先将第二字段降序排序, 再将第一字段升序排序, 关键就是sorted函数key的指定, 可以用 lambda 或operator.itemgetter开始我是这样做的:from collections import Counterc = Cou..._py 多字段排序先升序再降序

红帽 RHEL power8 rhel-server-7.2-ppc64le-dvd.iso-程序员宅基地

文章浏览阅读3.1k次。红帽 RHEL power8 服务器小端版本,找了很久才找到,官方不提供下载了,放这里给大家对于没有HMC,没有显卡的小机运维,不想搭一堆环境的人来说是福音,SUSE11没有小端版本,12用的引导界面在SMS下全是乱码,只能用吐血两字来形容,centos 7 装上会出现IOA口驱动失败的情况,目前找不到原因,这个是官方支持带驱动的,rhel-server-7.2-ppc64le-dvd.i_rhel-server-7.2-ppc64le-dvd.iso

android studio 配置Kotlin环境_kotlin怎么指定jdk-程序员宅基地

文章浏览阅读9.9k次,点赞2次,收藏10次。2017年随着google发布了Kotlin作为android的一级语言,与java100%互通。开发者就陆陆续续从java转到Kotlin中了,我现在有学习了Kotlin几天,的确感觉Kotlin写起来非常简洁,下面我介绍一下如何在android studio配置Kotlin环境。步骤1.在android studio中下载插件(windows)点击File->Setting->..._kotlin怎么指定jdk

解决腾讯ACE-Guard进程导致的英雄联盟掉帧问题_lol 屏蔽 anticheatexpert_aceguard-程序员宅基地

文章浏览阅读108次。REM 如果不想该程序一次性查杀后就退出,可以注释掉16行的"exit",这样程序就会一直后台监控。_aceguard

随便推点

Servlet--Request生命周期_tomcat中request的生命周期-程序员宅基地

文章浏览阅读5k次,点赞4次,收藏13次。Servlet--Request生命周期一、Request、Response对象的生命周期1、浏览器像servlet发送请求2、tomcat收到请求后,创建Request和Response两个对象的生命周期,并且将浏览器请求的参数传递给Servlet3、Servlet接收到请求后,调用doget或者dopost方法。处理浏览器的请求信息,然后通过Response返回_tomcat中request的生命周期

解决GitHub不能访问的几个办法_github打不开-程序员宅基地

文章浏览阅读7.9w次,点赞16次,收藏80次。GitHub页面时而能访问,时而不能。不是慢,而是不能访问。当然,下载它的比如仓库Release下的压缩包比较慢则是另一回事。蛋疼的影响不限于打不开页面,更多的在于不能git pull和git push等操作。范围方面,凡国内不管是家宽、移动网络还是云上的,都受到一致的影响。_github打不开

经纬恒润测试开发面经_经纬恒润 面试-程序员宅基地

文章浏览阅读4.5k次,点赞3次,收藏21次。9.24 15:00 电话一面 35min面试官是一个声音巨好听的小哥哥......,迷恋ing,而且也超级温柔,嘻嘻嘻嘻嘻嘻1.自我介绍2.讲项目是不是自己做的 怎么做的 项目分工 担任角色 项目测试(全程死抠测试,单元测试死抠.....) 为什么做这个项目3.对软件测试的理解4.针对我的专业有疑问,主修课程有哪些,5.为什么做测试6.你觉得互..._经纬恒润 面试

19款最好用的免费数据挖掘工具大汇总(干货)_)好用(19)-程序员宅基地

文章浏览阅读6.7k次,点赞2次,收藏14次。数据在当今世界意味着金钱。随着向基于app的世界的过渡,数据呈指数增长。然而,大多数数据是非结构化的,因此需要一个过程和方法从数据中提取有用的信息,并将其转换为可理解的和可用的形式。数据挖掘或“数据库中的知识发现”是通过人工智能、机器学习、统计和数据库系统发现大数据集中的模式的过程。免费的数据挖掘工具包括从完整的模型开发环境如Knime和Orange,到各种用Java、c++编写的库,最常..._)好用(19)

java-net-php-python-jspm米兰酒店管理系统计算机毕业设计程序_米兰酒店管理系统登录-程序员宅基地

文章浏览阅读191次。java-net-php-python-jspm米兰酒店管理系统计算机毕业设计程序。springboot基于B_S模式的后勤管理系统-在线报修系统。springcloud基于微服务架构的乐居租房网的设计与实现。springboot基于springboot的社会公益平台。ssm基于web的考试资料交易系统的设计与实现。ssm基于JEE的人才招聘系统的智能化管理。springboot中国民航酒店分销系统。_米兰酒店管理系统登录

成本中心通过利润中心来和公司代码对应_sap 成本中心关联公司-程序员宅基地

文章浏览阅读7.4k次,点赞2次,收藏2次。成本中心是无法直接和公司代码进行配对的。但是利润中心能够绑定公司代码再通过利润中心的对应公司代码可以进行成本中心对应公司代码的对应_sap 成本中心关联公司

推荐文章

热门文章

相关标签