Pytorch – torch.unsqueeze和torch.squeeze函数
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – torch.unsqueeze和torch.squeeze函数
原文链接:https://www.stubbornhuang.com/2440/
发布于:2022年12月08日 15:01:13
修改于:2022年12月08日 15:04:11
Pytorch中,unsqueeze
和squeeze
为两个对应的反操作函数,其中,unsqueeze
主要用于为输入张量升维,squeeze
主要用于给张量降维,两者的具体用法可以参考下文。
1 unsqueeze
形式
torch.unsqueeze(input, dim)
或者
Tensor.unsqueeze(dim)
功能
在输入张量input的指定索引插入1维度。
dim的数值处于[-input.dim()-1,input.dim()+1]的区间内,如果dim的值为负,则计算为dim=dim+input.dim()+1。
参数
- input:输入的张量
- dim:插入维度的索引
使用示例:
import torch
if __name__ == '__main__':
a = torch.randn(size=(5, 3, 224, 224))
print(a.shape)
b = torch.unsqueeze(a, 1)
print(b.shape)
c = b.unsqueeze(3)
print(c.shape)
输出
torch.Size([5, 3, 224, 224])
torch.Size([5, 1, 3, 224, 224])
torch.Size([5, 1, 3, 1, 224, 224])
2 squeeze
形式
torch.squeeze(input, dim=None)
或者
Tensor.squeeze(dim=None)
功能
删除input输入张量中所有大小为1的维度。
比如说,输入张量的维度为(A \times 1 \times B \times C \times 1 \times D),那么经过squeeze函数处理之后,输出张量的维度为(A \times B \times C \times D)。
如果指定了dim,那么squeeze算子仅仅只会在给定的dim上操作,比如输入张量维度为(A \times 1 \times B),squeeze(input,0)
不会改变输入张量,只有使用squeeze(input,1)
才会将输出张量(A \times B)。
参数
- input:输入张量
- dim:需要操作的维度索引,如果被指定,则只会在该维度上应用squeeze操作
使用示例:
import torch
if __name__ == '__main__':
a = torch.randn(size=(5, 1, 224, 224, 1))
print(a.shape)
b = torch.squeeze(a)
print(b.shape)
c = b.unsqueeze(3)
print(c.shape)
d = c.squeeze(3)
print(d.shape)
输出
torch.Size([5, 1, 224, 224, 1])
torch.Size([5, 224, 224])
torch.Size([5, 224, 224, 1])
torch.Size([5, 224, 224])
当前分类随机文章推荐
- 深度学习 - 我的深度学习项目代码文件组织结构 阅读1355次,点赞3次
- Pytorch - 训练网络时出现_pickle.UnpicklingError: pickle data was truncated错误 阅读1114次,点赞0次
- Pytorch - 为什么要设置随机数种子? 阅读531次,点赞0次
- Pytorch - 使用pytorch自带的Resnet作为网络的backbone 阅读353次,点赞0次
- Pytorch - 用Pytorch实现ResNet 阅读460次,点赞0次
- Pytorch - torch.nn.Conv1d参数详解与使用 阅读2414次,点赞0次
- Pytorch - torch.chunk参数详解与使用 阅读1095次,点赞0次
- Pytorch - torch.stack参数详解与使用 阅读728次,点赞0次
- Pytorch - 内置的CTC损失函数torch.nn.CTCLoss参数详解与使用示例 阅读1247次,点赞1次
- Python - list/numpy/pytorch tensor相互转换 阅读1697次,点赞0次
全站随机文章推荐
- Mediapipe - 将Mediapipe handtracking封装成动态链接库dll/so,实现在桌面应用中嵌入手势识别功能 阅读10039次,点赞21次
- UnrealEngine4 - C++层打印信息到屏幕 阅读2481次,点赞0次
- C++读取Shp文件并将Shp转化为DXF 阅读3175次,点赞1次
- 宝塔面板 - 安装Php扩展如memcached失败的解决方案 阅读1670次,点赞0次
- Python - 不依赖第三方库对类对象进行json序列化与反序列化 阅读1400次,点赞0次
- 资源分享 - 用Python写网络爬虫(第2版 Katharine Jarmul,Richard Lawson著 李斌译) 阅读1969次,点赞0次
- 资源分享 - Visualizing Quaternions 英文高清PDF下载 阅读1794次,点赞0次
- 深度学习 - 通俗理解Beam Search Algorithm算法 阅读794次,点赞0次
- failed to find an available destination > EOF 阅读18261次,点赞32次
- ThreeJS - three.moudle.js报Uncaught SyntaxError:Unexpected token ‘export‘错误 阅读1922次,点赞0次
评论
169