Pytorch中,unsqueezesqueeze为两个对应的反操作函数,其中,unsqueeze主要用于为输入张量升维,squeeze主要用于给张量降维,两者的具体用法可以参考下文。

1 unsqueeze

形式

torch.unsqueeze(input, dim)

或者

Tensor.unsqueeze(dim)

功能

在输入张量input的指定索引插入1维度。

dim的数值处于[-input.dim()-1,input.dim()+1]​的区间内,如果dim的值为负,则计算为dim=dim+input.dim()+1。

参数

  • input:输入的张量
  • dim:插入维度的索引

使用示例:

import torch

if __name__ == '__main__':
    a = torch.randn(size=(5, 3, 224, 224))
    print(a.shape)

    b = torch.unsqueeze(a, 1)
    print(b.shape)

    c = b.unsqueeze(3)
    print(c.shape)

输出

torch.Size([5, 3, 224, 224])
torch.Size([5, 1, 3, 224, 224])
torch.Size([5, 1, 3, 1, 224, 224])

2 squeeze

形式

torch.squeeze(input, dim=None)

或者

Tensor.squeeze(dim=None)

功能

删除input输入张量中所有大小为1的维度。

比如说,输入张量的维度为(A \times 1 \times B \times C \times 1 \times D),那么经过squeeze函数处理之后,输出张量的维度为(A \times B \times C \times D)

如果指定了dim,那么squeeze算子仅仅只会在给定的dim上操作,比如输入张量维度为(A \times 1 \times B)squeeze(input,0)不会改变输入张量,只有使用squeeze(input,1)才会将输出张量(A \times B)

参数

  • input:输入张量
  • dim:需要操作的维度索引,如果被指定,则只会在该维度上应用squeeze操作

使用示例:

import torch

if __name__ == '__main__':
    a = torch.randn(size=(5, 1, 224, 224, 1))
    print(a.shape)

    b = torch.squeeze(a)
    print(b.shape)

    c = b.unsqueeze(3)
    print(c.shape)

    d = c.squeeze(3)
    print(d.shape)

输出

torch.Size([5, 1, 224, 224, 1])
torch.Size([5, 224, 224])
torch.Size([5, 224, 224, 1])
torch.Size([5, 224, 224])