• 问题反馈可发送邮件到stubbornhuang@qq.com

  • 工资「喂饱肚子」,副业「养活灵魂」!

  • 本站会放置Google广告用于维持域名以及网站服务器费用。

  • 在本站开通年度VIP,无限制下载本站资源和阅读本站文章

  • 如果觉得本站的内容有帮助,可以考虑打赏博主哦!

  • 感谢大家访问本站,希望本站的内容可以帮助到大家!

  • 欢迎大家交换友链,可在https://www.stubbornhuang.com/申请友情链接进行友链交换申请!

  • 计算机图形学与计算几何经典必备书单整理,下载链接可参考:https://www.stubbornhuang.com/1256/

  • 本站由于前段时间遭受到大量临时和国外邮箱注册,所以对可注册的邮箱类型进行了限制!

Pytorch – 修改Pytoch中torchvision.models预置模型的方法

Pytorch 发布于2023-03-28 阅读 3,394次 0次评论 0次点赞 本文共2821个字,阅读需要8分钟。

转载自https://chenglu.me/blogs/pytorch-model-modification-part1,少量修改,如侵权,请联系我进行删除。

在深度学习网络构建时,我们可能需要对Pytorch中的torchvision.models中的模型进行一些修改,比如说,将torchvision.models中的Resnet18作为主网络,但是需要修改其最后的全连接层的分类数;再者将torchvision.models中的某个网络作为另一个模型的backbone或者特征提取器,这个时候需要删除全连接或者网络中其他的一些层。

在Pytorch中,模型是torch.nn.Model的子类的对象,修改模型本质上就是修改子类,对类做修改我们可以使用继承或者组合的方法。

1 通过继承修改模型

首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,再在子类中实现魔改的逻辑。

其大致的框架代码如下所示:

from torchvision.models import ResNet


class CustomizedResNet(ResNet):

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        ...

下面这个例子,我们将对 ResNet 进行魔改,把 ResNet 4 个 stage 输出的特征连起来,最后过一个全链接后输出一个标量。

from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet


class CustomizedResNet(ResNet):
    def __init__(self, block, layers, num_classes=2):
        super().__init__(block, layers, num_classes)
        self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        x = torch.cat(
            [
                self.avgpool(x1),
                self.avgpool(x2),
                self.avgpool(x3),
                self.avgpool(x4),
            ],
            dim=1,
        )

        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)
new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)
new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)
new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)

2 通过组合修改模型

在面向对象编程中我们可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。

首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module,然后将需要魔改的类作为类变量融入到模型中。

下面是大致的框架代码:

from torchvision.models import resnet18


class CustomizedResNet(torch.nn.Module):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        ...

    def forward(self, x):
        ...


my_resnet18 = CustomizedResNet(resnet18)

我们也使用这种来实现与第1节中相同的模型修改:

from torchvision.models import resnet50


class CustomizedResNet(torch.nn.Module):
    def __init__(self, backbone, num_classes=2):
        super().__init__()
        self.backbone = backbone
        self.fc = torch.nn.Linear(3840, num_classes)

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x1 = self.backbone.layer1(x)
        x2 = self.backbone.layer2(x1)
        x3 = self.backbone.layer3(x2)
        x4 = self.backbone.layer4(x3)

        x = torch.cat(
            [
                self.backbone.avgpool(x1),
                self.backbone.avgpool(x2),
                self.backbone.avgpool(x3),
                self.backbone.avgpool(x4),
            ],
            dim=1,
        )

        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


new_resnet50 = CustomizedResNet(resnet50())

3 通过monkey patch的方式修改模型

monkey patch允许在运行期间动态替换,修改一个类的方法或者模块。

比如说:

class A:
    def func(self):
        print('hi')
    def monkey(self):
        print('hi,monkey')

a = A()
a.func()
#运行结果:hi
a.func = a.monkey
a.func()
#运行结果:hi,monkey

这种方式在程序设计的角度是具有破坏性,不是推荐使用的方法

猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改,下面是把 ResNet34 的输出从 1000 改为 1 的例子:

from torchvision.models import resnet50

model = resnet50()
model.fc = torch.nn.Linear(2048, 1)

此外这种方法也仅能实现一些简单的需求,对于复杂的需求还是推荐使用组合的方法来完成。

欢迎扫码关注我的微信公众号,及时获取文章更新

微信公众号二维码

本文作者:StubbornHuang

版权声明:本文为站长原创文章,如果转载请注明原文链接!

原文标题:Pytorch – 修改Pytoch中torchvision.models预置模型的方法

原文链接:https://www.stubbornhuang.com/2572/

发布于:2023年03月28日 9:47:48

修改于:2023年03月28日 9:47:48

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。

文章末尾
上一篇
OnnxRuntime – 模型部署笔记3,总结OnnxRuntime模型推理流程
OnnxRuntime
下一篇
OpenCV - 在图像处理教程使用最广的测试美女图片Lenna,Lenna原图
OpenCV
当前分类随机文章推荐

发表评论

您必须 [ 登录 ] 才能发表留言!

关注我们的公众号

微信公众号