1 torch.cat函数

形式

torch.cat(tensors, dim=0, *, out=None)

功能

在指定的维度连接给定序列的张量,所有张量必须具有相同的形状(连接维度除外)或者为空。

参数

  • tensors:相同形状的张量序列,非空张量必须具有相同形状(连接维度除外)
  • dim:张量连接的维度

使用示例

import torch

if __name__ == '__main__':
    a = torch.randn(size=(2, 2))
    print(a)

    b = torch.randn(size=(2, 2))
    print(b)

    c = torch.cat((a, b), dim=0)
    print(c)

    d = torch.cat((a, b), dim=1)
    print(d)

输出

tensor([[ 0.8793,  0.3727],
        [-2.3334, -1.4567]])
tensor([[-1.0906, -1.2683],
        [ 0.7161, -0.5843]])
tensor([[ 0.8793,  0.3727],
        [-2.3334, -1.4567],
        [-1.0906, -1.2683],
        [ 0.7161, -0.5843]])
tensor([[ 0.8793,  0.3727, -1.0906, -1.2683],
        [-2.3334, -1.4567,  0.7161, -0.5843]])