Pytorch – 修改Pytoch中torchvision.models预置模型的方法
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 修改Pytoch中torchvision.models预置模型的方法
原文链接:https://www.stubbornhuang.com/2572/
发布于:2023年03月28日 9:47:48
修改于:2023年03月28日 9:47:48
转载自https://chenglu.me/blogs/pytorch-model-modification-part1,少量修改,如侵权,请联系我进行删除。
在深度学习网络构建时,我们可能需要对Pytorch中的torchvision.models中的模型进行一些修改,比如说,将torchvision.models中的Resnet18作为主网络,但是需要修改其最后的全连接层的分类数;再者将torchvision.models中的某个网络作为另一个模型的backbone或者特征提取器,这个时候需要删除全连接或者网络中其他的一些层。
在Pytorch中,模型是torch.nn.Model
的子类的对象,修改模型本质上就是修改子类,对类做修改我们可以使用继承或者组合的方法。
1 通过继承修改模型
首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,再在子类中实现魔改的逻辑。
其大致的框架代码如下所示:
from torchvision.models import ResNet
class CustomizedResNet(ResNet):
def __init__(self):
super().__init__()
...
def forward(self, x):
...
下面这个例子,我们将对 ResNet 进行魔改,把 ResNet 4 个 stage 输出的特征连起来,最后过一个全链接后输出一个标量。
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet
class CustomizedResNet(ResNet):
def __init__(self, block, layers, num_classes=2):
super().__init__(block, layers, num_classes)
self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x1 = self.layer1(x)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x = torch.cat(
[
self.avgpool(x1),
self.avgpool(x2),
self.avgpool(x3),
self.avgpool(x4),
],
dim=1,
)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)
new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)
new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)
new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)
2 通过组合修改模型
在面向对象编程中我们可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。
首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module
,然后将需要魔改的类作为类变量融入到模型中。
下面是大致的框架代码:
from torchvision.models import resnet18
class CustomizedResNet(torch.nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
...
def forward(self, x):
...
my_resnet18 = CustomizedResNet(resnet18)
我们也使用这种来实现与第1节中相同的模型修改:
from torchvision.models import resnet50
class CustomizedResNet(torch.nn.Module):
def __init__(self, backbone, num_classes=2):
super().__init__()
self.backbone = backbone
self.fc = torch.nn.Linear(3840, num_classes)
def forward(self, x):
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x1 = self.backbone.layer1(x)
x2 = self.backbone.layer2(x1)
x3 = self.backbone.layer3(x2)
x4 = self.backbone.layer4(x3)
x = torch.cat(
[
self.backbone.avgpool(x1),
self.backbone.avgpool(x2),
self.backbone.avgpool(x3),
self.backbone.avgpool(x4),
],
dim=1,
)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
new_resnet50 = CustomizedResNet(resnet50())
3 通过monkey patch的方式修改模型
monkey patch允许在运行期间动态替换,修改一个类的方法或者模块。
比如说:
class A:
def func(self):
print('hi')
def monkey(self):
print('hi,monkey')
a = A()
a.func()
#运行结果:hi
a.func = a.monkey
a.func()
#运行结果:hi,monkey
这种方式在程序设计的角度是具有破坏性,不是推荐使用的方法。
猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改,下面是把 ResNet34 的输出从 1000 改为 1 的例子:
from torchvision.models import resnet50
model = resnet50()
model.fc = torch.nn.Linear(2048, 1)
此外这种方法也仅能实现一些简单的需求,对于复杂的需求还是推荐使用组合的方法来完成。
当前分类随机文章推荐
- Pytorch - reshape和view的用法和区别 阅读322次,点赞0次
- Pytorch - 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型 阅读358次,点赞0次
- Pytorch - torch.topk参数详解与使用 阅读260次,点赞0次
- Pytorch - 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0 阅读61次,点赞0次
- 深度学习 - 我的深度学习项目代码文件组织结构 阅读1354次,点赞3次
- Pytorch - 使用pytorch自带的Resnet作为网络的backbone 阅读353次,点赞0次
- Pytorch - .to()和.cuda()的区别 阅读806次,点赞0次
- Pytorch - nn.Transformer、nn.TransformerEncoderLayer、nn.TransformerEncoder、nn.TransformerDecoder、nn.TransformerDecoder参数详解 阅读2642次,点赞1次
- Pytorch - 内置的LSTM网络torch.nn.LSTM参数详解与使用示例 阅读2005次,点赞0次
- Pytorch - 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题 阅读554次,点赞0次
全站随机文章推荐
- 资源分享 - GPU Pro 360 - Guide to Image Space 英文高清PDF下载 阅读2017次,点赞1次
- 资源分享 - Geometric tools for computer graphics(Philip J. Schneider, and David H. Eberly)英文高清PDF下载 阅读3442次,点赞0次
- 资源分享 - GPU Pro 4 - Advanced Rendering Techniques 英文高清PDF下载 阅读2675次,点赞1次
- Pytorch - RuntimeError: No rendezvous handler for env://错误 阅读1018次,点赞0次
- WordPress - 获取某个用户发表的文章数量 阅读1874次,点赞0次
- 资源分享 - Computer Facial Animation , Second Edition 英文高清PDF下载 阅读1850次,点赞0次
- Youtube运营 - Youtube做哪些视频内容可以快速通过人工审核 阅读287次,点赞1次
- Python - 使用onnxruntime加载和推理onnx模型 阅读346次,点赞0次
- 资源分享 - 3D Graphics Rendering Cookbook - A comprehensive guide to exploring rendering algorithms in modern OpenGL and Vulkan 英文高清PDF下载 阅读4177次,点赞1次
- 深度学习 - 卷积神经网络CNN简介 阅读531次,点赞0次
评论
169