Pytorch – 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0
1 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0
Pytroch在实现断点续训功能时,在保存模型文件时,需要同时保存model、optimizer、lr_scheduler的state_dict,比如
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.lr_scheduler.state_dict(),
}, model_save_path)
然后在加载模型时,除了加载模型的权重之外,还需要同时加载optimizer和lr_scheduler的权重,比如
model_weights = modified_weights(check_point_state_dict['model_state_dict'])
optimizer.load_state_dict(check_point_state_dict["optimizer_state_dict"])
lr_scheduler.load_state_dict(check_point_state_dict["scheduler_state_dict"])
这个时候比较容易犯的错误是,optimizer默认是在cpu上加载权重的,而我们之后继续训练模型时都是在GPU上进行了,所以如果optimizer没有任何修改,则会出在optimizer.step()
执行时出现
RuntimeError: Expected all tensors to be on the same device, but found cuda:0
其实际上就是optimizer的权重没有在GPU上,所以解决方法就是将optimizer的权重转移到GPU上,示例代码如下
optimizer.load_state_dict(check_point_state_dict["optimizer_state_dict"])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.output_device)
其中self.output_device
就是项目中的GPU索引号。
修改完成之后,错误解决。
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0
原文链接:https://www.stubbornhuang.com/2603/
发布于:2023年05月08日 11:03:42
修改于:2023年05月08日 11:04:20
当前分类随机文章推荐
- Pytorch - 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题 阅读817次,点赞0次
- Pytorch - 训练网络时出现_pickle.UnpicklingError: pickle data was truncated错误 阅读2232次,点赞1次
- Pytorch - 一文搞懂如何使用Pytorch构建与训练自定义深度学习网络(数据集自定义与加载,模型训练,模型测试,模型保存与加载) 阅读1378次,点赞2次
- Pytorch - 内置的CTC损失函数torch.nn.CTCLoss参数详解与使用示例 阅读2198次,点赞1次
- Pytorch - 使用torch.matmul()替换torch.einsum('nctw,cd->ndtw',(a,b))算子模式 阅读3222次,点赞1次
- Pytorch - masked_fill方法参数详解与使用 阅读1301次,点赞0次
- Pytorch - torch.nn.Conv1d参数详解与使用 阅读4017次,点赞0次
- Pytorch - Pytoch结合Tensorboard实现数据可视化 阅读442次,点赞0次
- Pytorch – 使用torch.matmul()替换torch.einsum('bhxyd,md->bhxym',(a,b))算子模式 阅读1664次,点赞0次
- Pytorch - torch.optim优化器 阅读1146次,点赞0次
全站随机文章推荐
- C++ – UTF8编码下的全角字符转半角字符 阅读2431次,点赞0次
- C++ – 字节数组byte[]或者unsigned char[]与bool的相互转换 阅读1407次,点赞1次
- Python - 运行YOLOv5出现AttributeError: module 'torchvision' has no attribute 'ops' 阅读2663次,点赞1次
- 资源分享 - Ray Tracing Gems - High-Quality and Real-Time Rendering with DXR and Other APIs 英文高清PDF下载 阅读2984次,点赞0次
- Pytorch - 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0 阅读313次,点赞0次
- 资源分享 - 深度学习的数学 (涌井良幸 涌井贞美著) 高清PDF下载 阅读5683次,点赞3次
- 资源分享 - GPU Pro 360 - Guide to Image Space 英文高清PDF下载 阅读2391次,点赞1次
- 深度学习 - 动作识别Action Recognition最重要的问题 阅读682次,点赞1次
- WordPress - 网站加载自定义字体的最佳方式 阅读176次,点赞1次
- 资源分享 - Computational Geometry:An Introduction(Franco P.Preparata, and Michael Shamos)英文高清PDF下载 阅读3884次,点赞0次
评论
169