1 batch size对模型训练的影响

小的batch size引入的数据集的数据量较小,随机性越大,在部分情况下模型难以收敛,影响模型训练效率。

而在合理的范围内,越大的batch size本质上是对训练数据更优的一种选择,能够是梯度下降的方向更加准确,震荡越小,有利于收敛的稳定性。

但是如果batch size过大,超出了一个合理的范围,会限制模型的探索能力,出现局部最优的情况。

所以,在模型训练的过程中,bacth size太小和太大都不是一个好的选择。

2 为什么要使用梯度累积(gradient accumulation)的方案?

深度学习发展到今天,数据已经从图片,文本逐渐发展到了音频、视频这种更复杂,更需要高维度表示的数据。

试想一下,如果我们需要基于一个超百万个的视频数据集进行模型训练,需要使用Resnet50或者Resnet101作为backbone对每一个视频的视频帧提取特征并进行组合,假设每个视频有250帧,那么我们使用batch size=2进行训练,就等于我们需要对500张图片都进行Resnet50或者Resnet101进行计算,而如果我们只有一张3090(24G),那么在这种情况下24G的显存显然是不够用的。

在显卡的显存不够多,不足以支撑大的batch_size数据训练的情况下,而我们又想使用大的batch_size进行模型的训练,那么这个时候我们就可以使用梯度累计的方式进行优化,以防止显存爆炸。

3 在Pytorch中使用梯度累积

在Pytorch中反向传播梯度是不清零的,所以要实现梯度累积是比较简单的。

3.1 常用的训练模式

在Pytorch中,训练一个epoch训练的常用代码如下

for batch_idx, (input_id, label) in enumerate(train_loader):
    # 1. 模型输出
    pred = model(input_id)
    loss = criterion(pred, label)

    # 2. 反向传播
    optimizer.zero_grad() # 梯度清空   

    loss.backward() # 反向传播,计算梯度

    optimizer.step() # 根据梯度,更新网络参数

总结步骤如下:

  1. 计算loss,获取batch输入,计算model输出,通过损失函数计算loss
  2. optimizer.zero_grad() 清空之前的梯度
  3. loss.backward()反向传播,计算当前梯度
  4. optimizer.step()根据梯度更新网络参数

简单来说,就是进来一个batch的数据,计算一次梯度,更新一次网络。

3.2 梯度累积

我们将上述代码修改为梯度累积的方式

for batch_idx ,(input_id, label) in enumerate(train_loader):
    # 1. 模型输出
    pred = model(input_id)
    loss = criterion(pred, label)

    # 2.1 损失规范
    loss = loss / accumulation_steps  

    # 2.2 反向传播,计算梯度
    loss.backward()

    if (batch_idx+1) % accumulation_steps == 0:
        optimizer.step() # 更新参数

        optimizer.zero_grad() # 梯度清空,为下一次反向传播做准备

总结步骤如下:

  1. 计算loss,获取batch输入,计算model输出,通过损失函数计算loss
  2. loss.backward()反向传播,计算当前梯度
  3. 多次循环步骤 1-2,不清空梯度,使梯度累加在已有梯度上;
  4. 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。

在合理的范围内,batch size越大训练效果越好,梯度累积变相实现了batch_size的扩大,如果accumulation_steps为8则batch size变相扩大了8倍,这是解决显存受限的很好的trick方式,不过在使用梯度累积的时候,学习率也要适当的放大。

参考链接