Pytorch – 使用pytorch自带的Resnet作为网络的backbone
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 使用pytorch自带的Resnet作为网络的backbone
原文链接:https://www.stubbornhuang.com/2468/
发布于:2023年01月06日 13:31:19
修改于:2023年01月09日 15:04:01
在使用Pytorch搭建自己的神经网络框架时,经常需要使用Pytorch中内置的torchvision.models
中的模型作为特征提取的Backbone,然后再在这个基础上进行更加复杂的网络搭建。
在这里以使用Pytorch中内置的Resnet18为例,如何作为Backbone层进行使用,看以下示例代码
# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision
class Resnet18Backbone(nn.Module):
def __init__(self):
super(Resnet18Backbone, self).__init__()
self.model = torchvision.models.resnet18(pretrained=True)
self.model.fc = nn.Sequential()
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
return x
使用上述代码,如果输入Tensor的维度为[1,3,244,244],fowward
输出的Tensor的维度为[1,512,1,1],如果我们需要输出的Tensor维度为[1,512],需要squeeze
相应的维度,修改后的代码如下
# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision
class Resnet18Backbone(nn.Module):
def __init__(self):
super(Resnet18Backbone, self).__init__()
self.model = torchvision.models.resnet18(pretrained=True)
self.model.fc = nn.Sequential()
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
x = x.squeeze(2).squeeze(2)
return x
好了,上述代码的Resnet18Backbone
可以作为网络中的一层进行使用,这里都是以ResNet的Adaptive Average Pooling层作为backbone的输出层,如果我们仅仅需要前面的卷积层作为输出层,可以参考以下代码。
比如,如果我们要使用ResNet18的Adaptive Average Pooling作为backbone的输出层,我们可以这样写,
# backbone
if backbone_name == 'resnet_18':
resnet_net = torchvision.models.resnet18(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 512
elif backbone_name == 'resnet_34':
resnet_net = torchvision.models.resnet34(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 512
elif backbone_name == 'resnet_50':
resnet_net = torchvision.models.resnet50(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 2048
elif backbone_name == 'resnet_101':
resnet_net = torchvision.models.resnet101(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 2048
elif backbone_name == 'resnet_152':
resnet_net = torchvision.models.resnet152(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 2048
elif backbone_name == 'resnet_50_modified_stride_1':
resnet_net = resnet50(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 2048
elif backbone_name == 'resnext101_32x8d':
resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)
modules = list(resnet_net.children())[:-1]
backbone = nn.Sequential(*modules)
backbone.out_channels = 2048
如果我们仅仅只是需要前面的卷积层作为backbone,我们可以这样写
# backbone
if backbone_name == 'resnet_18':
resnet_net = torchvision.models.resnet18(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
elif backbone_name == 'resnet_34':
resnet_net = torchvision.models.resnet34(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
elif backbone_name == 'resnet_50':
resnet_net = torchvision.models.resnet50(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
elif backbone_name == 'resnet_101':
resnet_net = torchvision.models.resnet101(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
elif backbone_name == 'resnet_152':
resnet_net = torchvision.models.resnet152(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
elif backbone_name == 'resnet_50_modified_stride_1':
resnet_net = resnet50(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
elif backbone_name == 'resnext101_32x8d':
resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)
modules = list(resnet_net.children())[:-2]
backbone = nn.Sequential(*modules)
参考链接
当前分类随机文章推荐
- Pytorch - 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题 阅读575次,点赞0次
- Pytorch - 训练网络时出现_pickle.UnpicklingError: pickle data was truncated错误 阅读1157次,点赞0次
- Pytorch - 内置的LSTM网络torch.nn.LSTM参数详解与使用示例 阅读2071次,点赞0次
- Pytorch - masked_fill方法参数详解与使用 阅读776次,点赞0次
- Pytorch - 手动调整学习率以及使用torch.optim.lr_scheduler调整学习率 阅读671次,点赞0次
- Pytorch - torch.distributed.init_process_group函数详解 阅读617次,点赞0次
- Pytorch - torch.optim优化器 阅读696次,点赞0次
- Pytorch - 用Pytorch实现ResNet 阅读501次,点赞0次
- Pytorch - 使用torch.matmul()替换torch.einsum('nctw,cd->ndtw',(a,b))算子模式 阅读2266次,点赞1次
- Pytorch - torch.nn.Conv1d参数详解与使用 阅读2530次,点赞0次
全站随机文章推荐
- Windows - 虚拟按键Virtual-Key Codes大全 阅读3567次,点赞0次
- C++ - 拷贝构造函数与拷贝构造函数调用时机 阅读389次,点赞0次
- 资源分享 - Speech and Language Processing - An Introduction to Natural Language Processing, Computational Linguistics, and Speech Recognition , Third Edition draft 英文高清PDF下载 阅读659次,点赞0次
- 资源分享 - Essential Mathematics for Games and Interactive Applications(First Edition) 英文高清PDF下载 阅读1921次,点赞0次
- OpenCV - 使用findContours()查找图片轮廓线,并将轮廓线坐标点输出 阅读4913次,点赞0次
- 资源分享 - OpenGL 4.0 Shading Language Cookbook (First Edition) 英文高清PDF下载 阅读1714次,点赞0次
- 资源分享 - Physically Based Rendering From Theory To Implementation (First Edition)英文高清PDF下载 阅读3007次,点赞0次
- Python – 解决opencv-python使用cv2.imwrite()保存中文路径图片失败的问题 阅读1761次,点赞0次
- C++ - C++实现Python numpy的矩阵维度转置算法,例如(N,H,W,C)转换为(N,C,H,W) 阅读4152次,点赞4次
- Python - glob模块详解以及glob.glob、glob.iglob函数的使用 阅读1113次,点赞0次
评论
169