深度学习 – 语音识别框架wenet的非流式与流式混合训练机制
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:深度学习 – 语音识别框架wenet的非流式与流式混合训练机制
原文链接:https://www.stubbornhuang.com/2297/
发布于:2022年08月11日 9:20:07
修改于:2022年10月14日 10:17:01

1 wenet的非流式与流式混合训练机制
wenet实现了语音识别非流式与流式混合训练的机制。通过细读源码,其主要是通过动态修改网络的Encoder层(在wenet中主要使用了TransformerEncoder和Conformer)的attention mask来影响Encoder层中Self-Attention的计算结果,从而实现一次训练既可以实现非流式语音识别也可以实现流式语音识别。
在wenet中提出了dynamic chunk的训练机制,在训练过程中,每一个batch都使用dynamic chunk size进行训练,换句话说,每一个batch所使用的chunk size是不同的,是动态变化的。
记当前batch的最大的序列长度为L,而dynamic chunk size在1到该batch的最大序列长度L的范围内随机取值,如果随机值大于\frac{L}{2},则将chunk size的值取为L,如果随机值小于或者等于\frac{L}{2},则chunk size从[1,25]的范围内随机取值。如chunk size取值为batch中最大序列长度L,则表明使用全部上下文注意,这种则作为非流式训练的组成部分;而另一种将chunk size取值为[1,25],则表示使用部分上下文注意,这种取值方式则作为流式训练的组成部分。
而其核心代码主要是wenet/utils/mask.py中的add_optional_chunk_mask
函数,下面将通过这个函数的代码梳理下wenet如何通过该函数构造出不同的attention mask。
1.1 代码梳理
add_optional_chunk_mask源代码
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int, static_chunk_size: int,
num_decoding_left_chunks: int):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
mask (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
add_optional_chunk_mask
函数的函数参数详情如下:
- xs:tensor,已padding后的Tensor,形状为(B,L,D),其中L为最大长度
- masks:tensor,xs对应的padding mask,形状为(B,1,L)
- use_dynamic_chunk:是否使用dynamic chunk
- use_dynamic_left_chunk:是否使用dynamic left chunk
- decoding_chunk_size:dynamic chunk的decode chunk大小,在模型训练时默认为0,使用随机的dynamic chunk;如果decode_chunk_size小于0,使用full chunk;如果decode_chunk_size大于0,使用固定的decode_chunk_size
- static_chunk_size:在static chunk训练和解码的chunk size,当use_dynamic_chunk为true时,不管static_chunk_size设置为什么值都会被忽略
- num_decoding_left_chunks:在解码时需用到多少个left chunk,解码的chunk_size为decode_chunk_size,如果num_decoding_left_chunks小于0,则使用全部的left chunk;如果num_deocding_left_chunks大于等于0,则使用num_deocding_left_chunks个left chunk
提供了3种Encoder层attention mask的生成方式,
- 第一种是不计算attention mask,直接返回原有的padding mask
- 第二种是生成指定的static_chunk_size的attention mask
- 第三种则是dynamic chunk的attention mask
1.1.1 不计算attention mask
相关代码
else:
chunk_masks = masks
代码很简洁,就是直接将输入的padding mask直接赋值给chunk mask进行返回,不进行任何操作。
在此种方式下就是原生的TransformerEncoder或者ConformerEncoder的编码方式,所有的数据都参与编码过程。
1.1.2 static_chunk_size的attention mask
相关代码
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
在static_chunk_size的情况下,不需要向dynamic模式那样动态改变chunk size,所以相对来说比较简单,在参数num_left_chunks
和static_chunk_size
确定的情况下,直接使用subsequent_chunk_mask
函数求解出chunk mask,然后与输入的padding mask做&
操作得到最后的mask。subsequent_chunk_mask
函数的用法可以参考我另一篇博文深度学习 – 语音识别框架wenet源码wenet/utils/mask.py中的mask机制,该函数的主要作用是使用设定的chunk_size
创建形状为(size,size)的mask Tensor。
1.1.3 dynamic_chunk_size的attention mask
相关代码
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
终于来到正题dynamic chunk的代码了。
上述代码中decoding_chunk_size不等于0主要是用于解码;而训练时,decoding_chunk_size等于0,训练的关键代码为
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
上述代码中,max_len
为当前batch的序列最大长度,首先通过
chunk_size = torch.randint(1, max_len, (1, )).item()
在[1,max_len]的范围内随机取值,然后将随机取值进行判断
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
如果随机取值大于序列长度的一半,则将chunk_size
设置为序列的最大长度值;如果如果随机取值小于或者等于序列长度的一半,则取chunk_size = chunk_size % 25 + 1
,即在[1,25]的范围内随机取值。如果chunk_size
设置为序列的最大长度值则为非流式训练,如果chunk_size
在[1,25]的范围内随机取值则为流式训练。
同时,为了减少模型在解码时不会因为chunk size设置的值所带来识别效果大幅变差的情况,使用了dynamic left chunk,相关的代码如下,
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
上述代码的核心思想是,如果我们使用了dynamic left chunk,则先根据chunk_size
与max_len
算出最大的左侧chunk数量max_left_chunk
,然后从[0, max_left_chunk]的范围内随机取值作为num_left_chunks
。
等上述所有过程处理完毕,则使用subsequent_chunk_mask
函数求解出chunk mask,然后与输入的padding mask做&
操作得到最后的mask。
2 总结
静下心来细读wenet的源码,还是收获颇多,终于解答了自己“为什么wenet既可以实现流式语音识别又可以实现非流式语音识别?”的疑问,它的这种设计对需要进行流式识别的领域有诸多启发,还是感谢wenet团队的无私奉献。
当前分类随机文章推荐
- 深度学习 - 基础的Greedy Search和Beam Search算法的Python实现 阅读686次,点赞0次
- Transformer - 理解Transformer必看系列之,1 Self-Attention自注意力机制与多头注意力原理 阅读521次,点赞0次
- 深度学习 - 为什么要初始化网络模型权重? 阅读343次,点赞0次
- 深度学习 - 归纳轻量级神经网络(长期更新) 阅读80次,点赞0次
- 深度学习 - 通俗理解Beam Search Algorithm算法 阅读640次,点赞0次
- 深度学习 - 语音识别框架中wenet最大动态chunk大小为什么取值为25? 阅读751次,点赞0次
- 深度学习 - 图像标准化与归一化方法 阅读479次,点赞0次
- 深度学习 - Transformer详解 阅读649次,点赞0次
- 深度学习 - CTC算法原理详解 阅读545次,点赞0次
- 深度学习 - 我的深度学习项目代码文件组织结构 阅读1069次,点赞3次
全站随机文章推荐
- C++ - 使用正则判断字符串是否全是中文 阅读1052次,点赞0次
- 资源分享 - Graphics Gems II 英文高清PDF下载 阅读1994次,点赞0次
- 深度学习 - 为什么要初始化网络模型权重? 阅读343次,点赞0次
- 工具软件 - 解决从Onenote复制文字到QQ变成图片的问题,2023年最新解决方案 阅读84次,点赞0次
- FFmpeg - FFmpeg历史版本下载和函数弃用列表 阅读1705次,点赞0次
- OpenGL画四个三角形组成四面体,并进行旋转 阅读3331次,点赞0次
- C++11 - 使用std::codecvt进行字符编码转换需要注意的时间效率问题 阅读1683次,点赞1次
- 资源分享 - Hands-On C++ Game Animation Programming 英文PDF下载 阅读543次,点赞0次
- 资源分享 - Data Structures and Algorithms for Game Developers 英文高清PDF下载 阅读1522次,点赞0次
- OpenCV - cv::VideoWriter::fourcc可支持的视频编码格式 阅读2280次,点赞0次
评论
167