
Deep Learning – Python Implementation of the CTC Decoding Algorithms Greedy Search Decode, Beam Search Decode, and Prefix Beam Search Decode

Category: Deep Learning. Published 2022-07-19.

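In speech recognition and OCR (optical character recognition), the last step of inference uses a CTC decoding algorithm to find the most probable sequence in the predicted probability matrix. Common CTC decoding algorithms include Greedy Search Decode, Beam Search Decode, and Prefix Beam Search Decode; Greedy Search Decode and Prefix Beam Search Decode are the most widely used. This article implements all three in Python.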

1 Greedy Search Decode

1.1 Principle

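Greedy Search Decode picks the most probable symbol at each time step, then merges adjacent repeated symbols, and finally removes the blank symbols. The order matters: because repeats are merged before blanks are removed, a blank between two identical symbols keeps them as two separate outputs.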
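The collapse rule (merge repeats, then drop blanks) can be checked on hand-written paths. This is a minimal sketch: `ctc_collapse` is a hypothetical helper name that does not appear in the article's code, and index 0 is assumed to be the blank, as in the code below.

```python
from itertools import groupby

def ctc_collapse(path, blank=0):
    """Map a frame-level CTC path to a label sequence:
    first merge adjacent repeats, then drop blanks."""
    merged = [key for key, _ in groupby(path)]
    return [k for k in merged if k != blank]

# Repeats in consecutive frames merge; a blank between them keeps them apart.
print(ctc_collapse([0, 3, 3, 0, 3, 5, 5, 0]))  # [3, 3, 5]
print(ctc_collapse([3, 3, 3, 5]))              # [3, 5]

# Removing blanks *before* merging gives a different (wrong) result:
wrong = [k for k, _ in groupby([s for s in [3, 0, 3] if s != 0])]
print(ctc_collapse([3, 0, 3]), wrong)          # [3, 3] [3]
```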

1.2 Code

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
  """
  Stable log sum exp.
  """
  if all(a == NEG_INF for a in args):
      return NEG_INF
  a_max = max(args)
  lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
  return a_max + lsp

def ctc_greedy_search_decode(probs,blank = 0):
    index_list = np.argmax(probs, axis=1)

    # Merge adjacent duplicates
    index_list = [key for key,group in groupby(index_list)]

    # Remove blanks
    index_list = [*filter(lambda x: x != blank, index_list)]

    return index_list

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

    # Assume 20 time steps and a vocabulary of 20 symbols
    probs = np.random.rand(time_step,vocab_size)

    # The decoder expects probabilities, so apply softmax first
    probs = softmax(probs)

    res = ctc_greedy_search_decode(probs)

    print('label:{}'.format(res))

Output

label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12]

2 Beam Search Decode

2.1 Principle

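Greedy Search Decode returns only the single most probable path and cannot produce any other good candidate paths. Beam Search Decode keeps the beam_size highest-scoring paths at every time step, so it returns beam_size candidate paths. For details of how beam search works, see my article: Deep Learning – An Intuitive Understanding of the Beam Search Algorithm.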
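On very small inputs, the candidates that beam search looks for can be computed exactly by scoring every one of the V^T paths. The sketch below assumes, as the article's code does, that frames are independent, so a path's log probability is the sum of its per-frame log probabilities. `brute_force_top_paths` is a hypothetical helper, intended only as a reference when testing a beam search on tiny matrices.

```python
import itertools
import numpy as np

def brute_force_top_paths(probs, k=3):
    """Score all V**T frame-level paths and return the k best as (path, log_prob).
    Exponential cost: only usable for tiny T and V, e.g. as a test reference."""
    T, V = probs.shape
    log_probs = np.log(probs)
    scored = []
    for path in itertools.product(range(V), repeat=T):
        # Independent frames: path log probability = sum of per-frame log probs.
        score = float(sum(log_probs[t, s] for t, s in enumerate(path)))
        scored.append((list(path), score))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    logits = rng.random((4, 3))                   # 4 frames, 3 symbols
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    for path, score in brute_force_top_paths(probs, k=3):
        print(path, score)
    # The single best path is the per-frame argmax, i.e. the greedy path.
    print(np.argmax(probs, axis=1).tolist())
```

Because the score is a sum of independent per-frame terms, the top path always equals the greedy path; the other entries are the alternatives that only beam search exposes.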

2.2 Code

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
  """
  Stable log sum exp.
  """
  if all(a == NEG_INF for a in args):
      return NEG_INF
  a_max = max(args)
  lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
  return a_max + lsp


def ctc_beam_search_decode(probs, beam_size=10 ,blank = 0):
    T, V = probs.shape
    log_probs = np.log(probs)

    beam = [([], 0)]
    for t in range(T):
        new_beam = []
        for prefix, score in beam:
            for i in range(V):
                new_prefix = prefix + [i]
                new_score = score + log_probs[t, i]

                new_beam.append((new_prefix, new_score))

        # top beam_size
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    # Merge adjacent duplicates and remove blanks for every surviving path
    res = []
    for label, score in beam:
        # Merge adjacent duplicates
        index_list = [key for key, group in groupby(label)]

        # Remove blanks
        index_list = [*filter(lambda x: x != blank, index_list)]

        res.append((index_list,score))

    return res

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

    # Assume 20 time steps and a vocabulary of 20 symbols
    probs = np.random.rand(time_step,vocab_size)

    # The decoder expects probabilities, so apply softmax first
    probs = softmax(probs)

    res = ctc_beam_search_decode(probs,beam_size = 3,blank=0)
    for label, score in res:
        print('label:{},score={}'.format(label,score))

Output

label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.8869170531208
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 6, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.88710645664252
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 12],score=-51.89227568133298

3 Prefix Beam Search Decode

3.1 Principle

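One problem with the Beam Search Decode above is that, among the beam_size paths it finds, several may turn into the same label sequence once adjacent duplicates and blanks are removed. This greatly reduces the diversity of the results, and it also splits the probability of a single label sequence across several separate entries. Prefix Beam Search Decode instead works on prefixes (collapsed label sequences): a blank leaves the current prefix unchanged, and all paths that produce the same prefix are merged as the search proceeds. To do this, it keeps two scores per prefix, the probability of the paths that end in a blank and of those that end in a non-blank symbol; the split is needed because repeating the last symbol extends the prefix only if a blank separates the two occurrences. This largely eliminates duplicate paths, and prefix beam search is currently one of the most widely used CTC decoding methods. For details, see the paper: First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs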

The paper gives pseudocode for the algorithm (shown as a figure in the original article).
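Why merging matters: the exact probability of a label sequence is the sum over all frame-level paths that collapse to it, which is the quantity prefix beam search accumulates per prefix. The brute-force sketch below computes it directly for tiny inputs; `brute_force_label_probs` is a hypothetical helper, and index 0 is assumed to be the blank.

```python
import itertools
from collections import defaultdict
import numpy as np

def brute_force_label_probs(probs, blank=0):
    """Exact p(label | x): sum the probabilities of all frame-level paths
    that collapse to the same label. Exponential in T; tiny inputs only."""
    T, V = probs.shape
    label_probs = defaultdict(float)
    for path in itertools.product(range(V), repeat=T):
        # Collapse: merge adjacent repeats, then drop blanks.
        merged = [s for s, _ in itertools.groupby(path)]
        label = tuple(s for s in merged if s != blank)
        # Path probability = product of the chosen per-frame probabilities.
        label_probs[label] += float(np.prod(probs[np.arange(T), np.array(path)]))
    return dict(label_probs)

if __name__ == '__main__':
    # 3 frames, 2 symbols (0 = blank), every frame uniform.
    probs = np.full((3, 2), 0.5)
    for label, p in sorted(brute_force_label_probs(probs).items()):
        print(label, p)
    # () 0.125
    # (1,) 0.75
    # (1, 1) 0.125
```

Six of the eight paths (001, 010, 100, 011, 110, 111) collapse to `(1,)`, so that label gets 0.75 in total, while no single path scores more than 0.125. A path-level beam search ranks those six paths separately; prefix beam search adds them up.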

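The code below relies on a trick that some articles do not explain clearly: the logsumexp function. A path probability is a product of many per-step probabilities, so it quickly becomes extremely small and can underflow to 0 in floating point. We therefore store all probabilities in log space. Products are easy there, because

\log(x \cdot y) = \log x + \log y

But how do we compute \log(x+y) in log space? We can write it as

\log(x+y) = \log(\exp(\log x) + \exp(\log y))

The expression \log(\exp(\log x) + \exp(\log y)) has a dedicated name, log-sum-exp:

\log\text{-}\mathrm{sum}\text{-}\exp(u,v) = \log(\exp(u) + \exp(v))

So \log(x+y) can be written as

\log(x+y) = \log\text{-}\mathrm{sum}\text{-}\exp(\log x, \log y)

Evaluating \exp(u) directly would underflow again for very negative u, so we compute \log\text{-}\mathrm{sum}\text{-}\exp(u,v) with the equivalent formula

\log\text{-}\mathrm{sum}\text{-}\exp(u,v) = \max(u,v) + \log(\exp(u - \max(u,v)) + \exp(v - \max(u,v)))

Subtracting the maximum turns the largest term into \exp(0) = 1, so the sum inside the logarithm is at least 1: it cannot underflow to 0, and it cannot overflow for large u and v either. This is why logsumexp in the code below is written the way it is.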
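A small numeric check of the argument above, assuming float64 arithmetic. The `logsumexp` here mirrors the article's helper; the numbers (400 factors of 0.1) are arbitrary illustrations.

```python
import math

NEG_INF = -float("inf")

def logsumexp(*args):
    """Stable log(sum(exp(a) for a in args)), same form as the article's helper."""
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

# 1) A product of many small probabilities underflows in probability space.
p = 0.1 ** 400                  # 1e-400 is below the smallest float64 subnormal
print(p)                        # 0.0

# 2) In log space the same quantity is an ordinary number.
log_p = 400 * math.log(0.1)     # about -921.03

# 3) Adding two such probabilities: log(p + p) = log 2 + log p.
print(logsumexp(log_p, log_p) - log_p)   # about 0.6931, i.e. log(2)

# 4) The naive formula fails: exp(log_p) underflows to 0.0, and log(0) raises.
try:
    math.log(math.exp(log_p) + math.exp(log_p))
except ValueError as err:
    print("naive log-sum-exp failed:", err)

# 5) The max shift also prevents overflow for large inputs.
print(logsumexp(1000.0, 1000.0))  # about 1000.6931; math.exp(1000.0) would overflow
```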

For a more detailed explanation, see the following article:

3.2 Code

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
  """
  Stable log sum exp.
  """
  if all(a == NEG_INF for a in args):
      return NEG_INF
  a_max = max(args)
  lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
  return a_max + lsp

def ctc_prefix_beam_search_decode(prob, beam_size=10, blank=0):
    T, V = prob.shape
    log_prob = np.log(prob)

    beam = [(tuple(), (0, NEG_INF))]  # (log P ending in blank, log P ending in non-blank)
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))

        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_prob[t, i]

                if i == blank:  # propose a blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                else:  # extend with non-blank
                    end_t = prefix[-1] if prefix else None

                    # extend current prefix
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)

                    # keep current prefix
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)

        # top beam_size
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]

    return beam

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

    # Assume 20 time steps and a vocabulary of 20 symbols
    probs = np.random.rand(time_step,vocab_size)

    # The decoder expects probabilities, so apply softmax first
    probs = softmax(probs)

    res = ctc_prefix_beam_search_decode(probs,beam_size = 3,blank=0)
    for beam in res:
        print('label:{},score={}'.format(beam[0],logsumexp(*beam[1])))

Output

label:(12, 7, 9, 19, 2, 15, 12, 11, 3),score=-43.130412256239644
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 12),score=-43.59912015650705
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 11),score=-43.61975284105764

4 Results of the three CTC decoding algorithms on the same probability matrix

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
  """
  Stable log sum exp.
  """
  if all(a == NEG_INF for a in args):
      return NEG_INF
  a_max = max(args)
  lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
  return a_max + lsp

def ctc_greedy_search_decode(probs,blank = 0):
    index_list = np.argmax(probs, axis=1)

    # Merge adjacent duplicates
    index_list = [key for key,group in groupby(index_list)]

    # Remove blanks
    index_list = [*filter(lambda x: x != blank, index_list)]

    return index_list


def ctc_beam_search_decode(probs, beam_size=10 ,blank = 0):
    T, V = probs.shape
    log_probs = np.log(probs)

    beam = [([], 0)]
    for t in range(T):
        new_beam = []
        for prefix, score in beam:
            for i in range(V):
                new_prefix = prefix + [i]
                new_score = score + log_probs[t, i]

                new_beam.append((new_prefix, new_score))

        # top beam_size
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    # Merge adjacent duplicates and remove blanks for every surviving path
    res = []
    for label, score in beam:
        # Merge adjacent duplicates
        index_list = [key for key, group in groupby(label)]

        # Remove blanks
        index_list = [*filter(lambda x: x != blank, index_list)]

        res.append((index_list,score))

    return res

def ctc_prefix_beam_search_decode(prob, beam_size=10, blank=0):
    T, V = prob.shape
    log_prob = np.log(prob)

    beam = [(tuple(), (0, NEG_INF))]  # (log P ending in blank, log P ending in non-blank)
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))

        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_prob[t, i]

                if i == blank:  # propose a blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                else:  # extend with non-blank
                    end_t = prefix[-1] if prefix else None

                    # extend current prefix
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)

                    # keep current prefix
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)

        # top beam_size
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]

    return beam

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

    # Assume 20 time steps and a vocabulary of 20 symbols
    probs = np.random.rand(time_step,vocab_size)

    # The decoder expects probabilities, so apply softmax first
    probs = softmax(probs)

    res = ctc_greedy_search_decode(probs)

    print('ctc_greedy_search_decode:\nlabel:{}\n'.format(res))

    print('ctc_beam_search_decode:')
    res = ctc_beam_search_decode(probs,beam_size = 3,blank=0)
    for label, score in res:
        print('label:{},score={}'.format(label,score))
    print('\n')


    print('ctc_prefix_beam_search_decode')
    res = ctc_prefix_beam_search_decode(probs,beam_size = 3,blank=0)
    for beam in res:
        print('label:{},score={}'.format(beam[0],logsumexp(*beam[1])))
    print('\n')

Output

ctc_greedy_search_decode:
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12]

ctc_beam_search_decode:
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.8869170531208
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 6, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.88710645664252
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 12],score=-51.89227568133298


ctc_prefix_beam_search_decode
label:(12, 7, 9, 19, 2, 15, 12, 11, 3),score=-43.130412256239644
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 12),score=-43.59912015650705
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 11),score=-43.61975284105764


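Author: StubbornHuang

Copyright notice: This is an original article by the site owner. If you repost it, please credit the original link.

Original title: 深度学习 – Python实现CTC Decode解码算法Greedy Search Decode,Beam Search Decode,Prefix Beam Search Decode

Original link: https://www.stubbornhuang.com/2222/

Published: 2022-07-19 08:53:27

Last modified: 2023-06-25 20:53:23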
