Pytorch – 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题
原文链接:https://www.stubbornhuang.com/2322/
发布于:2022年08月23日 10:29:14
修改于:2022年08月23日 10:29:14

1 显存爆炸的问题
最近使用以下示例代码测试自定义深度学习网络时耗光了所有显存,出现了梯度爆炸的问题。
model.eval()
for batch_idx, data in enumerate(tqdm(data_loader)):
image = data[0].to('cuda:0')
......
经过排查原因是没有加上with torch.no_grad()
语句停止梯度更新,从而导致了显存爆炸的问题,正确的示例代码如下
model.eval()
with torch.no_grad()
for batch_idx, data in enumerate(tqdm(data_loader)):
image = data[0].to('cuda:0')
......
2 model.train、model.eval和with torch.no_grad
model.train()
会将网络中的模块设置为训练模式,此时,如果神经网络中BN(batch normalization)层和Dropout层,那么这两个层将会起作用,防止网络出现过拟合的问题。
model.eval()
则会将网络设置为测试模式,此时,不会启用神经网络中的BN(batch normalization)层和Dropout层,model.eval()是保证BN层直接使用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout层,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
with torch.no_grad()
会将网络中Tensor的属性全部设置为False,并停止Autograd引擎,禁止梯度反向传播,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。因此,测试的时候加上此语句也不会影响测试精度的,只是停止了梯度更新而已。在测试和验证阶段,使用with torch.no_grad()
会使得网络有更快的推理速度和内存使用,这使得我们在验证和测试网络时可以使用更大的batch size。
当前分类随机文章推荐
- Pytorch - 使用pytorch自带的Resnet作为网络的backbone 阅读210次,点赞0次
- Pytorch - 使用torchsummary/torchsummaryX/torchinfo库打印模型结构、输出维度和参数信息 阅读1263次,点赞1次
- Pytorch - 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题 阅读434次,点赞0次
- Pytorch - torch.cat函数 阅读191次,点赞0次
- Pytorch - 训练网络时出现_pickle.UnpicklingError: pickle data was truncated错误 阅读748次,点赞0次
- Pytorch - torch.chunk参数详解与使用 阅读888次,点赞0次
- Pytorch - RuntimeError: No rendezvous handler for env://错误 阅读791次,点赞0次
- Pytorch - reshape和view的用法和区别 阅读209次,点赞0次
- Pytorch - torch.optim优化器 阅读546次,点赞0次
- Pytorch - 内置的LSTM网络torch.nn.LSTM参数详解与使用示例 阅读1483次,点赞0次
全站随机文章推荐
- 资源分享 - 用Python写网络爬虫(第2版 Katharine Jarmul,Richard Lawson著 李斌译) 阅读1850次,点赞0次
- WordPress - get_footer函数,加载主题底部页脚footer模板 阅读753次,点赞0次
- C++ - 在Windows/Linux上创建单级目录以及多级目录的跨平台方法 阅读922次,点赞0次
- FFmpge - Ubuntu编译FFmpeg出现WARNING: pkg-config not found, library detection may fail警告 阅读3603次,点赞0次
- Pytorch - 一文搞懂如何使用Pytorch构建与训练自定义深度学习网络(数据集自定义与加载,模型训练,模型测试,模型保存与加载) 阅读1019次,点赞2次
- 工具API推荐 - 通过QQ号获取QQ头像 阅读1186次,点赞0次
- Modern OpenGL从零开始 - 从茫茫多的OpenGL第三方库讲起 阅读3415次,点赞1次
- 宝塔面板 - 安装Php扩展如memcached失败的解决方案 阅读1467次,点赞0次
- Pytorch – 使用torch.matmul()替换torch.einsum('bhxyd,md->bhxym',(a,b))算子模式 阅读915次,点赞0次
- WordPress - get_sidebar函数,加载主题侧边栏模板 阅读812次,点赞0次
评论
167