1 显存爆炸的问题

最近使用以下示例代码测试自定义深度学习网络时耗光了所有显存,出现了梯度爆炸的问题。

model.eval()
for batch_idx, data in enumerate(tqdm(data_loader)):
    image = data[0].to('cuda:0')
    ......

经过排查原因是没有加上with torch.no_grad()语句停止梯度更新,从而导致了显存爆炸的问题,正确的示例代码如下

model.eval()
with torch.no_grad()
    for batch_idx, data in enumerate(tqdm(data_loader)):
        image = data[0].to('cuda:0')
        ......

2 model.train、model.eval和with torch.no_grad

model.train()会将网络中的模块设置为训练模式,此时,如果神经网络中BN(batch normalization)层和Dropout层,那么这两个层将会起作用,防止网络出现过拟合的问题。

model.eval()则会将网络设置为测试模式,此时,不会启用神经网络中的BN(batch normalization)层和Dropout层,model.eval()是保证BN层直接使用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout层,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

with torch.no_grad()会将网络中Tensor的属性全部设置为False,并停止Autograd引擎,禁止梯度反向传播,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。因此,测试的时候加上此语句也不会影响测试精度的,只是停止了梯度更新而已。在测试和验证阶段,使用with torch.no_grad()会使得网络有更快的推理速度和内存使用,这使得我们在验证和测试网络时可以使用更大的batch size。