1 The original paper

Section 3.2.2, Dynamic Chunk Training, of the wenet paper Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition describes the dynamic chunk size as follows:

Motivated by the idea of unified E2E model, we further propose a dynamic chunk training. We can use dynamic chunk size for different batches in training, the dynamic chunk size range is a uniform distribution from 1 to max utterance length, namely the attention varies from left context attention to full context attention, and the model captures different information on various chunk size, and learns how to do accurate prediction when different limited right context provided. We call the chunks which sizes are from 1 to 25 as streaming chunk for streaming model and size which is max utterance length as none streaming chunk for none streaming model.

In other words, each training batch samples its own chunk size uniformly between 1 and the maximum utterance length, so the attention ranges from purely left-context (streaming) attention to full-context attention; the model thus sees different amounts of right context and learns to predict accurately under each of them. Chunks of size 1 to 25 serve as the streaming chunks for the streaming model, while a chunk covering the whole utterance serves as the non-streaming chunk for the non-streaming model.
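Before diving into the code, a minimal sketch of my own (not wenet code) may help picture what a chunk-restricted attention mask looks like; wenet's actual implementation is the subsequent_chunk_mask function in utils/mask.py, which appears in the next section. Each frame attends to everything up to the end of its own chunk, so a small chunk size gives streaming-style attention, while a chunk size equal to the utterance length recovers full-context attention.

import torch

def chunk_attention_mask(size: int, chunk_size: int) -> torch.Tensor:
    # Frame i may attend to all frames up to the end of the chunk it belongs to.
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        ending = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, :ending] = True
    return mask

print(chunk_attention_mask(6, 2).int())  # chunk_size=2: limited right context (streaming)
print(chunk_attention_mask(6, 6).int())  # chunk_size=size: all ones, full-context attention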

2 Source code

The idea described in the paper is reflected in the add_optional_chunk_mask function in wenet's utils/mask.py:

def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        mask (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
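As a side note, here is a rough sketch of how this function is typically driven during dynamic-chunk training (assuming the wenet package is importable; the tensor shapes are made up for illustration). make_pad_mask is the padding-mask helper from the same utils/mask.py, and decoding_chunk_size=0 together with use_dynamic_chunk=True selects the random dynamic chunk branch discussed next.

import torch
from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask

xs = torch.randn(2, 50, 256)                  # (B, L, D): a padded batch after subsampling
lengths = torch.tensor([50, 32])              # true lengths of the two utterances
masks = ~make_pad_mask(lengths).unsqueeze(1)  # (B, 1, L), True on real frames
chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=True,       # dynamic chunk training
    use_dynamic_left_chunk=False,
    decoding_chunk_size=0,        # 0: training, sample a random dynamic chunk
    static_chunk_size=0,
    num_decoding_left_chunks=-1)
print(chunk_masks.shape)          # torch.Size([2, 50, 50])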

Within add_optional_chunk_mask, the lines

# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
    chunk_size = max_len
else:
    chunk_size = chunk_size % 25 + 1
    if use_dynamic_left_chunk:
        max_left_chunks = (max_len - 1) // chunk_size
        num_left_chunks = torch.randint(0, max_left_chunks,
                                        (1, )).item()

implement the dynamic chunk idea. The comment at the top of this block is worth a closer look:

# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.

The comment means the following:

The chunk size either takes a value in [1, 25] or covers the full context (max_len); because 4x subsampling is used and at most 1 s (100 frames) of delay is allowed, the maximum chunk size is 100 / 4 = 25 frames.
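A quick simulation (mine, not part of wenet) makes the effect of this sampling visible: roughly half of the batches end up with full-context attention, and the other half get a streaming chunk size folded into [1, 25].

import torch

max_len = 1000                      # assume a 1000-frame utterance
full, streaming = 0, []
for _ in range(10000):
    chunk_size = torch.randint(1, max_len, (1,)).item()
    if chunk_size > max_len // 2:
        full += 1                   # full context, i.e. a non-streaming batch
    else:
        streaming.append(chunk_size % 25 + 1)

print(full / 10000)                    # about 0.5
print(min(streaming), max(streaming))  # streaming chunk sizes always fall in [1, 25]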

The 4x subsampling is easy to understand: in transformer/encoder.py, the forward function of the BaseEncoder class passes the speech features through self.embed before they enter the encoder blocks, and self.embed uses Conv2dSubsampling4, which convolves the input and shrinks its time dimension by a factor of 4, hence the division by 4.
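The factor of 4 comes from two stride-2 convolutions. The toy module below (a sketch only, not wenet's actual Conv2dSubsampling4, which also adds a linear projection and positional encoding) shows the time axis shrinking by roughly 4x.

import torch
import torch.nn as nn

subsample = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2),   # time 100 -> 49
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, stride=2),   # time 49 -> 24
    nn.ReLU(),
)
x = torch.randn(1, 1, 100, 80)   # (batch, channel, 100 frames, 80 mel bins)
y = subsample(x)
print(y.shape[2])                # 24, roughly 100 / 4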

But how should we understand the comment's claim that 1 s of speech corresponds to 100 feature frames? That is explained below.

3 Speech feature extraction in wenet

In wenet, the following code in the Dataset function of dataset/dataset.py handles speech feature extraction:

assert feats_type in ['fbank', 'mfcc']
if feats_type == 'fbank':
    fbank_conf = conf.get('fbank_conf', {})
    dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
elif feats_type == 'mfcc':
    mfcc_conf = conf.get('mfcc_conf', {})
    dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)

wenet provides two kinds of speech features, fbank and MFCC. In the example configuration files shipped with wenet, both are extracted with a 25 ms frame length and a 10 ms frame shift. This explains why the comment in Section 2 says that 1 s of speech yields 100 feature frames: 1 s is 1000 ms, and with a 10 ms shift and a 25 ms window the number of frames extracted per second is 1000 / 10 = 100.

Loosely speaking, the 25 ms frame length with a 10 ms shift can be viewed as a 1-D convolution over the waveform (the features are of course not computed by a convolution, so do not take the analogy literally) with a kernel of 25 and a stride of 10, which makes the frame count easy to reason about.
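The quick check below (assuming torchaudio is available; wenet's processor.compute_fbank wraps the same Kaldi-compatible fbank) confirms the arithmetic: one second of 16 kHz audio with a 25 ms window and a 10 ms shift yields about 100 frames (98 exactly, because the trailing partial window is dropped), and the sliding-window formula, i.e. the "kernel 25, stride 10" view, gives the same number.

import torch
import torchaudio.compliance.kaldi as kaldi

sample_rate = 16000
waveform = torch.randn(1, sample_rate)     # 1 second of audio
feats = kaldi.fbank(waveform,
                    num_mel_bins=80,
                    frame_length=25,       # 25 ms window ("kernel")
                    frame_shift=10,        # 10 ms hop ("stride")
                    sample_frequency=sample_rate)
print(feats.shape[0])                      # 98 frames, about 1000 ms / 10 ms

# The same count from the sliding-window view:
win, hop = int(0.025 * sample_rate), int(0.010 * sample_rate)
print((sample_rate - win) // hop + 1)      # 98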

For more details on speech feature extraction, see:

4 Summary

I have been reading the wenet source code for a while now. Since I had no prior exposure to speech recognition, some parts required studying quite a bit of background material. wenet really does use many tricks in its design and implementation, so it takes patience to work through the underlying principles and implementation details; every part I manage to understand makes things noticeably clearer. The design of this framework is very clever and well worth learning from.