Pytorch – 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型
原文链接:https://www.stubbornhuang.com/2444/
发布于:2022年12月09日 14:13:29
修改于:2022年12月09日 14:13:29
1 batch size对模型训练的影响
小的batch size引入的数据集的数据量较小,随机性越大,在部分情况下模型难以收敛,影响模型训练效率。
而在合理的范围内,越大的batch size本质上是对训练数据更优的一种选择,能够是梯度下降的方向更加准确,震荡越小,有利于收敛的稳定性。
但是如果batch size过大,超出了一个合理的范围,会限制模型的探索能力,出现局部最优的情况。
所以,在模型训练的过程中,bacth size太小和太大都不是一个好的选择。
2 为什么要使用梯度累积(gradient accumulation)的方案?
深度学习发展到今天,数据已经从图片,文本逐渐发展到了音频、视频这种更复杂,更需要高维度表示的数据。
试想一下,如果我们需要基于一个超百万个的视频数据集进行模型训练,需要使用Resnet50或者Resnet101作为backbone对每一个视频的视频帧提取特征并进行组合,假设每个视频有250帧,那么我们使用batch size=2进行训练,就等于我们需要对500张图片都进行Resnet50或者Resnet101进行计算,而如果我们只有一张3090(24G),那么在这种情况下24G的显存显然是不够用的。
在显卡的显存不够多,不足以支撑大的batch_size数据训练的情况下,而我们又想使用大的batch_size进行模型的训练,那么这个时候我们就可以使用梯度累计的方式进行优化,以防止显存爆炸。
3 在Pytorch中使用梯度累积
在Pytorch中反向传播梯度是不清零的,所以要实现梯度累积是比较简单的。
3.1 常用的训练模式
在Pytorch中,训练一个epoch训练的常用代码如下
for batch_idx, (input_id, label) in enumerate(train_loader):
# 1. 模型输出
pred = model(input_id)
loss = criterion(pred, label)
# 2. 反向传播
optimizer.zero_grad() # 梯度清空
loss.backward() # 反向传播,计算梯度
optimizer.step() # 根据梯度,更新网络参数
总结步骤如下:
- 计算loss,获取batch输入,计算model输出,通过损失函数计算loss
- optimizer.zero_grad() 清空之前的梯度
- loss.backward()反向传播,计算当前梯度
- optimizer.step()根据梯度更新网络参数
简单来说,就是进来一个batch的数据,计算一次梯度,更新一次网络。
3.2 梯度累积
我们将上述代码修改为梯度累积的方式
for batch_idx ,(input_id, label) in enumerate(train_loader):
# 1. 模型输出
pred = model(input_id)
loss = criterion(pred, label)
# 2.1 损失规范
loss = loss / accumulation_steps
# 2.2 反向传播,计算梯度
loss.backward()
if (batch_idx+1) % accumulation_steps == 0:
optimizer.step() # 更新参数
optimizer.zero_grad() # 梯度清空,为下一次反向传播做准备
总结步骤如下:
- 计算loss,获取batch输入,计算model输出,通过损失函数计算loss
- loss.backward()反向传播,计算当前梯度
- 多次循环步骤 1-2,不清空梯度,使梯度累加在已有梯度上;
- 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;
总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
在合理的范围内,batch size越大训练效果越好,梯度累积变相实现了batch_size的扩大,如果accumulation_steps
为8则batch size变相扩大了8倍,这是解决显存受限的很好的trick方式,不过在使用梯度累积的时候,学习率也要适当的放大。
参考链接
当前分类随机文章推荐
- Pytorch - torch.nn.Conv1d参数详解与使用 阅读2414次,点赞0次
- Pytorch - torch.distributed.init_process_group函数详解 阅读587次,点赞0次
- Pytorch - torch.stack参数详解与使用 阅读728次,点赞0次
- Pytorch - 内置的LSTM网络torch.nn.LSTM参数详解与使用示例 阅读2005次,点赞0次
- Pytorch - 手动调整学习率以及使用torch.optim.lr_scheduler调整学习率 阅读641次,点赞0次
- Pytorch - torch.nn.Module的parameters()和named_parameters() 阅读573次,点赞0次
- Pytorch - torch.cat函数 阅读294次,点赞0次
- Pytorch - 多GPU训练方式nn.DataParallel与nn.parallel.DistributedDataParallel的区别 阅读886次,点赞0次
- Pytorch - torch.nn.Conv2d参数详解与使用 阅读466次,点赞0次
- Pytorch - 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型 阅读359次,点赞0次
全站随机文章推荐
- 书籍翻译 – Fundamentals of Computer Graphics, Fourth Edition,第2章 Miscellaneous Math中文翻译 阅读2705次,点赞14次
- 网站个性化 - 添加人形时钟 honehone_clock.js 阅读3072次,点赞0次
- C++ - 在某一天某个时间点定时执行任务,比如2022年9月19日晚上9点准点执行发送邮件函数 阅读420次,点赞0次
- 工具网站推荐 - 在线的数学公式、几何绘图网站推荐 阅读2355次,点赞0次
- 资源分享 - Computational Geometry:An Introduction(Franco P.Preparata, and Michael Shamos)英文高清PDF下载 阅读3565次,点赞0次
- Pytorch - torch.topk参数详解与使用 阅读260次,点赞0次
- 深度学习 - 图像标准化与归一化方法 阅读625次,点赞0次
- 资源分享 - Interactive Computer Graphics - A top-down approach with WebGL(Seven 7th Edition)英文高清PDF下载 阅读2716次,点赞0次
- 资源分享 - Beginning DirectX 11 Game Programming 英文高清PDF下载 阅读1318次,点赞0次
- 资源分享 - C++程序设计语言(第4部分 标准库),原书第4版 高清PDF下载 阅读2682次,点赞2次
评论
169