1 torch.cat参数详解与使用

1.1 torch.cat

1.函数形式

torch.cat(tensors, dim=0, *, out=None) → Tensor

2.函数功能

在指定的维度串联指定Tensor序列,所有Tensor都必须具有相同的形状(连接维度除外),或者Tensor为空。

torch.cat可以看作是torch.split或者torch.chunk的反操作。

3.函数参数

  • tensors:Tensor序列,除cat的维度形状可以不同之外,输入的Tensor必须类型相同且形状相同
  • dim:int类型,Tensor序列连接的维度

4.函数返回值

返回串联好的Tensor

1.2 torch.cat的使用

1.2.1 torch.cat串联一维Tensor序列

import torch

if __name__ == '__main__':
    tensor1 = torch.tensor([1,2,3,4])
    tensor2 = torch.tensor([5,6,7,8])
    tensor3 = torch.tensor([9,10,11,12])
    tensor4 = torch.tensor([13,14,15,16])

    output0 = torch.cat([tensor1,tensor2,tensor3,tensor4],dim=0)

    print(output0)

输出

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])

1.2.2 torch.cat串联二维Tensor序列

import torch

if __name__ == '__main__':
    tensor1 = torch.tensor([1,2,3,4]).view(2,2)
    tensor2 = torch.tensor([5,6,7,8]).view(2,2)
    tensor3 = torch.tensor([9,10,11,12]).view(2,2)
    tensor4 = torch.tensor([13,14,15,16]).view(2,2)

    output0 = torch.cat([tensor1,tensor2,tensor3,tensor4],dim=0)
    output1 = torch.cat([tensor1, tensor2, tensor3, tensor4], dim=1)

    print('torch.cat dim=0:{0},{1}'.format(output0,output0.shape))
    print('torch.cat dim=1:{0},{1}'.format(output1,output1.shape))

输出

torch.cat dim=0:tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12],
        [13, 14],
        [15, 16]]),torch.Size([8, 2])
torch.cat dim=1:tensor([[ 1,  2,  5,  6,  9, 10, 13, 14],
        [ 3,  4,  7,  8, 11, 12, 15, 16]]),torch.Size([2, 8])

1.2.3 torch.cat串联三维Tensor序列

import torch

if __name__ == '__main__':
    tensor1 = torch.arange(0,8).view(2,2,2)
    tensor2 = torch.arange(8,16).view(2,2,2)
    tensor3 = torch.arange(16,24).view(2,2,2)
    tensor4 = torch.arange(24,32).view(2,2,2)

    output0 = torch.cat([tensor1,tensor2,tensor3,tensor4],dim=0)
    output1 = torch.cat([tensor1, tensor2, tensor3, tensor4], dim=1)
    output2 = torch.cat([tensor1, tensor2, tensor3, tensor4], dim=2)

    print('torch.cat dim=0:{0},{1}'.format(output0,output0.shape))
    print('torch.cat dim=1:{0},{1}'.format(output1,output1.shape))
    print('torch.cat dim=2:{0},{1}'.format(output2,output2.shape))

输出

torch.cat dim=0:tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15]],

        [[16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23]],

        [[24, 25],
         [26, 27]],

        [[28, 29],
         [30, 31]]]),torch.Size([8, 2, 2])
torch.cat dim=1:tensor([[[ 0,  1],
         [ 2,  3],
         [ 8,  9],
         [10, 11],
         [16, 17],
         [18, 19],
         [24, 25],
         [26, 27]],

        [[ 4,  5],
         [ 6,  7],
         [12, 13],
         [14, 15],
         [20, 21],
         [22, 23],
         [28, 29],
         [30, 31]]]),torch.Size([2, 8, 2])
torch.cat dim=2:tensor([[[ 0,  1,  8,  9, 16, 17, 24, 25],
         [ 2,  3, 10, 11, 18, 19, 26, 27]],

        [[ 4,  5, 12, 13, 20, 21, 28, 29],
         [ 6,  7, 14, 15, 22, 23, 30, 31]]]),torch.Size([2, 2, 8])