1 torch.stack参数详解与使用

1.1 torch.stack

1.函数形式

torch.stack(tensors, dim=0, *, out=None) → Tensor

2.函数功能

沿一个新的维度堆叠(连接)Tensor序列,所有的Tensor必须形状完全相同。与torch.cat沿已有维度拼接不同,torch.stack会在dim位置插入一个新维度,因此输出Tensor比输入Tensor多一维,且新维度的大小等于输入Tensor的个数。

3.函数参数

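  • tensors:需要堆叠的Tensor序列(如list或tuple),序列中所有Tensor的形状必须相同
  • dim:int类型,可选,默认为0,表示新插入的维度在输出Tensor中的位置。若输入Tensor的维数为D,则dim的取值范围为0到D(包含D,因为输出有D+1个维度);也可以使用负数索引,范围为-(D+1)到-1
  • out:Tensor类型,可选(仅限关键字参数),默认为None,用于存放结果的输出Tensor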

4.函数返回值

返回堆叠后的Tensor:若输入为N个形状相同、维数为D的Tensor,则输出为D+1维,且第dim维的大小为N。例如4个形状为(2, 2)的Tensor在dim=1上堆叠,得到形状为(2, 4, 2)的Tensor。

1.2 torch.stack的使用

1.2.1 torch.stack连接一维Tensor序列

import torch

if __name__ == '__main__':
    # Four 1-D tensors of length 4 holding the values 1..16 in order.
    rows = [
        torch.tensor(values)
        for values in ([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16])
    ]

    # Inputs are 1-D, so the new axis can go at dim 0 (each input becomes a
    # row) or dim 1 (each input becomes a column); both results are 4x4.
    templates = (
        'torch.stack dim=0:{0},{1}',
        'torch.stack dim=1:{0},{1}',
    )
    for axis, template in enumerate(templates):
        stacked = torch.stack(rows, dim=axis)
        print(template.format(stacked, stacked.shape))

输出

torch.stack dim=0:tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]]),torch.Size([4, 4])
torch.stack dim=1:tensor([[ 1,  5,  9, 13],
        [ 2,  6, 10, 14],
        [ 3,  7, 11, 15],
        [ 4,  8, 12, 16]]),torch.Size([4, 4])

1.2.2 torch.stack连接二维Tensor序列

import torch

if __name__ == '__main__':
    # Four 2x2 tensors, each reshaped from a consecutive run of 1..16.
    matrices = [
        torch.tensor(values).view(2, 2)
        for values in ([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16])
    ]

    # Inputs are 2-D, so the new axis may be inserted at dim 0, 1 or 2;
    # the 3-D result has size 4 (the number of inputs) along that axis.
    templates = (
        'torch.stack dim=0:{0},{1}',
        'torch.stack dim=1:{0},{1}',
        'torch.stack dim=2:{0},{1}',
    )
    for axis, template in enumerate(templates):
        stacked = torch.stack(matrices, dim=axis)
        print(template.format(stacked, stacked.shape))

输出

torch.stack dim=0:tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16]]]),torch.Size([4, 2, 2])
torch.stack dim=1:tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10],
         [13, 14]],

        [[ 3,  4],
         [ 7,  8],
         [11, 12],
         [15, 16]]]),torch.Size([2, 4, 2])
torch.stack dim=2:tensor([[[ 1,  5,  9, 13],
         [ 2,  6, 10, 14]],

        [[ 3,  7, 11, 15],
         [ 4,  8, 12, 16]]]),torch.Size([2, 2, 4])

1.2.3 torch.stack连接三维Tensor序列

import torch

if __name__ == '__main__':
    # Four 2x2x2 tensors holding 0..7, 8..15, 16..23 and 24..31 respectively.
    cubes = [torch.arange(start, start + 8).view(2, 2, 2) for start in range(0, 32, 8)]

    # Inputs are 3-D, so the new axis may be inserted at dim 0..3;
    # the 4-D result has size 4 (the number of inputs) along that axis.
    templates = (
        'torch.stack dim=0:{0},{1}',
        'torch.stack dim=1:{0},{1}',
        'torch.stack dim=2:{0},{1}',
        'torch.stack dim=3:{0},{1}',
    )
    for axis, template in enumerate(templates):
        stacked = torch.stack(cubes, dim=axis)
        print(template.format(stacked, stacked.shape))

输出

torch.stack dim=0:tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15]]],


        [[[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]],


        [[[24, 25],
          [26, 27]],

         [[28, 29],
          [30, 31]]]]),torch.Size([4, 2, 2, 2])
torch.stack dim=1:tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 8,  9],
          [10, 11]],

         [[16, 17],
          [18, 19]],

         [[24, 25],
          [26, 27]]],


        [[[ 4,  5],
          [ 6,  7]],

         [[12, 13],
          [14, 15]],

         [[20, 21],
          [22, 23]],

         [[28, 29],
          [30, 31]]]]),torch.Size([2, 4, 2, 2])
torch.stack dim=2:tensor([[[[ 0,  1],
          [ 8,  9],
          [16, 17],
          [24, 25]],

         [[ 2,  3],
          [10, 11],
          [18, 19],
          [26, 27]]],


        [[[ 4,  5],
          [12, 13],
          [20, 21],
          [28, 29]],

         [[ 6,  7],
          [14, 15],
          [22, 23],
          [30, 31]]]]),torch.Size([2, 2, 4, 2])
torch.stack dim=3:tensor([[[[ 0,  8, 16, 24],
          [ 1,  9, 17, 25]],

         [[ 2, 10, 18, 26],
          [ 3, 11, 19, 27]]],


        [[[ 4, 12, 20, 28],
          [ 5, 13, 21, 29]],

         [[ 6, 14, 22, 30],
          [ 7, 15, 23, 31]]]]),torch.Size([2, 2, 2, 4])