Pytorch – torch.nn.Conv1d参数详解与使用
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – torch.nn.Conv1d参数详解与使用
原文链接:https://www.stubbornhuang.com/2185/
发布于:2022年06月28日 16:25:31
修改于:2022年08月31日 15:21:05
1 torch.nn.Conv1d
torch.nn.Conv1d主要是对一维输入Tensor应用一维卷积。
如果一维卷积输入为(N,C_{in},L),输出为(N,C_{out},L_{out}),那么这两者的关系可描述为
其中\star为cross-correlation算子,N为batch size,C为输入通道数,L为输入序列的长度。
1.1 torch.nn.Conv1d
形式
torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
参数
- in_channels(int):输入的特征维度
- out_channels(int):输出的特征维度
- kernel_size(int或者tuple):卷积核的大小
- stride(int或者tuple):默认值为1,卷积的步幅
- padding(int或者tuple):默认值为0,添加到输入两侧的零填充数量
- padding_mode(字符串):默认值为"zeros",可选值为"zeros"、"reflect"、"replicate"、“circular”
- dilation(int或者tuple):内核元素之间的间距
- groups(int):默认值为1,从输入通道到输出通道的阻塞连接数
- bias(bool):默认值为True,如果为True,则向输出添加可学习的偏差。
可以通过这个链接https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md查看stride、padding、dilation等参数对卷积过程的影响。
输入与输出维度
一般,输入输出具有以下维度
- Input:(N,C_{in},L_{in})
- Output:(N,C_{out},L_{out})
其中,L_{out}可通过以下公式计算
1.2 torch.nn.Conv1d的简单使用
假设有batch_size为8,input_channels 特征维度为16,长度为50的输入序列,然后使用卷积核大小为3,卷积步幅为2,padding为0的一维卷积层对该输入序列进行一维卷积。
从1.1节中我们知道,torch.nn.Conv1d的输入输出的维度为:
- Input:(N,C_{in},L_{in})
- Output:(N,C_{out},L_{out})
那么按照上述描述,N为batch size的大小即为8,C_{out}被指定为33,而
L_{out} & = \left \lfloor \frac{50 + 2 \times 0 - 1 \times (2 - 1) -1 }{2} + 1 \right \rfloor \\
L_{out} & = \left \lfloor 24.5 \right \rfloor \\
L_{out} & = 24
\end{aligned}
从而得出Output的输出维度为(8,33,24)
pytorch代码实现如下:
import torch
if __name__ == '__main__':
input = torch.randn(8, 16, 50)
conv1d = torch.nn.Conv1d(16, 33, 3, stride=2)
output = conv1d(input)
print(output.shape)
输出
torch.Size([8, 33, 24])
当前分类随机文章推荐
- 深度学习 - 我的深度学习项目代码文件组织结构 阅读1401次,点赞3次
- Pytorch - 用Pytorch实现ResNet 阅读507次,点赞0次
- Pytorch - torch.chunk参数详解与使用 阅读1127次,点赞0次
- Pytorch - 模型微调时删除原有模型中的某一层的方法 阅读2130次,点赞0次
- Python - list/numpy/pytorch tensor相互转换 阅读1731次,点赞0次
- Pytorch - torch.stack参数详解与使用 阅读766次,点赞0次
- Pytorch - Pytoch结合Tensorboard实现数据可视化 阅读231次,点赞0次
- Pytorch - RuntimeError: No rendezvous handler for env://错误 阅读1058次,点赞0次
- Pytorch - reshape和view的用法和区别 阅读356次,点赞0次
- Pytorch - nn.Transformer、nn.TransformerEncoderLayer、nn.TransformerEncoder、nn.TransformerDecoder、nn.TransformerDecoder参数详解 阅读2785次,点赞1次
全站随机文章推荐
- Python3 - 正则表达式去除字符串中的特殊符号 阅读13467次,点赞1次
- 资源分享 - 全局光照算法技术 第2版 , Advanced Global Illumination 2nd Edition 中文版PDF下载 阅读2055次,点赞2次
- VTK - 冠脉重建点匹配坐标数据下载 阅读3910次,点赞5次
- 工具推荐 - 一些好用的DNS服务器 阅读866次,点赞0次
- 资源分享 - Real-Time 3D Character Animation with Visual C++ 英文高清PDF下载 阅读1631次,点赞0次
- C++ - Jni中的GetByteArrayElements和GetByteArrayRegion的区别和使用示例 阅读3156次,点赞0次
- C++ – 字节数组byte[]或者unsigned char[]与long double的相互转换 阅读1105次,点赞0次
- MindSpore - LeNet5的MindSpore实现 阅读1103次,点赞0次
- ThreeJS - 直接设置Fbx模型的某个关节的位移和旋转值 阅读1801次,点赞0次
- 资源分享 - Geometric Algebra for Computer Science - An Object-Oriented Approach to Geometry (First Edition) 英文高清PDF下载 阅读2192次,点赞0次
评论
169