束搜索

上一节介绍了如何训练输入输出均为不定长序列的编码器—解码器,这一节我们介绍如何使用编码器—解码器来预测不定长的序列。

上一节里已经提到,在准备训练数据集时,我们通常会在样本的输入序列和输出序列后面分别附上一个特殊符号“<eos>”表示序列的终止。我们在接下来的讨论中也将沿用上一节的数学符号。为了便于讨论,假设解码器的输出是一段文本序列。设输出文本词典 \(\mathcal{Y}\)(包含特殊符号“<eos>”)的大小为 \(\left|\mathcal{Y}\right|\),输出序列的最大长度为 \(T'\)。所有可能的输出序列一共有 \(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\) 种。这些输出序列中所有特殊符号“<eos>”后面的子序列将被舍弃。

穷举搜索

我们在上一节描述解码器时提到,输出序列基于输入序列的条件概率是 \(\prod_{t'=1}^{T'} \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c})\)。为了搜索该条件概率最大的输出序列,一种方法是穷举所有可能输出序列的条件概率,并输出条件概率最大的序列。我们将该序列称为最优序列,并将这种搜索方法称为穷举搜索(exhaustive search)。

虽然穷举搜索可以得到最优序列,但它的计算开销 \(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\) 很容易过大。例如,当 \(|\mathcal{Y}|=10000\)\(T'=10\) 时,我们将评估 \(10000^{10} = 10^{40}\) 个序列:这几乎不可能完成。

贪婪搜索

我们还可以使用贪婪搜索(greedy search)。也就是说,对于输出序列任一时间步 \(t'\),从 \(|\mathcal{Y}|\) 个词中搜索出输出词

\[y_{t'} = \text{argmax}_{y_{t'} \in \mathcal{Y}} \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}),\]

且一旦搜索出“<eos>”符号即完成输出序列。贪婪搜索的计算开销是 \(\mathcal{O}(\left|\mathcal{Y}\right|T')\)。它比起穷举搜索的计算开销显著下降。例如,当 \(|\mathcal{Y}|=10000\)\(T'=10\) 时,我们只需评估 \(10000\times10=1\times10^5\) 个序列。

下面我们来看一个例子。假设输出词典里面有“A”、“B”、“C”和“<eos>”这四个词。图 10.3 中每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成条件概率最大的词。因此,图 10.3 中将生成序列“ABC<eos>”。该输出序列的条件概率是 \(0.5\times0.4\times0.4\times0.6 = 0.048\)

每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成条件概率最大的词。

每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成条件概率最大的词。

正如绝大部分贪婪算法不能保证最优解一样,贪婪搜索也无法保证找出条件概率最大的最优序列。图 10.4 演示了这样的一个例子。与图 10.3 中不同,图 10.4 在时间步 2 中选取了条件概率第二大的“C”。由于时间步 3 所基于的时间步 1 和 2 的输出子序列由图 10.3 中的“AB”变为了图 10.4 中的“AC”,图 10.4 中时间步 3 生成各个词的条件概率发生了变化。我们选取条件概率最大的“B”。此时时间步 4 所基于的前三个时间步的输出子序列为“ACB”,与图 10.3 中的“ABC”不同。因此图 10.4 中时间步 4 生成各个词的条件概率也与图 10.3 中的不同。我们发现,此时的输出序列“ACB<eos>”的条件概率是 \(0.5\times0.3\times0.6\times0.6=0.054\),大于贪婪搜索得到的输出序列的条件概率。因此,贪婪搜索得到的输出序列“ABC<eos>”并非最优序列。

每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在时间步2选取条件概率第二大的“C”。

每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在时间步2选取条件概率第二大的“C”。

束搜索

束搜索(beam search)是比贪婪搜索更加广义的搜索算法。它有一个束宽(beam size)超参数。我们将它设为 \(k\)。在时间步 1 时,选取当前时间步生成条件概率最大的 \(k\) 个词,分别组成 \(k\) 个候选输出序列的首词。在之后的每个时间步,基于上个时间步的 \(k\) 个候选输出序列,从 \(k\left|\mathcal{Y}\right|\) 个可能的输出序列中选取生成条件概率最大的 \(k\) 个,作为该时间步的候选输出序列。 最终,我们在各个时间步的候选输出序列中筛选出包含特殊符号“<eos>”的序列,并将它们中所有特殊符号“<eos>”后面的子序列舍弃,得到最终候选输出序列。在这些最终候选输出序列中,取以下分数最高的序列作为输出序列:

\[\frac{1}{L^\alpha} \log \mathbb{P}(y_1, \ldots, y_{L}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}),\]

其中 \(L\) 为最终候选序列长度,\(\alpha\) 一般可选为 0.75。分母上的 \(L^\alpha\) 是为了惩罚较长序列在以上分数中较多的对数相加项。分析可得,束搜索的计算开销为 \(\mathcal{O}(k\left|\mathcal{Y}\right|T')\)。这介于穷举搜索和贪婪搜索的计算开销之间。

束搜索的过程。束宽为2,输出序列最大长度为3。候选输出序列有\ :math:`A`\ 、\ :math:`C`\ 、\ :math:`AB`\ 、\ :math:`CE`\ 、\ :math:`ABD`\ 和\ :math:`CED`\ 。

束搜索的过程。束宽为2,输出序列最大长度为3。候选输出序列有\(A\)\(C\)\(AB\)\(CE\)\(ABD\)\(CED\)

图 10.5 通过一个例子演示了束搜索的过程。假设输出序列的词典中只包含五个元素:\(\mathcal{Y} = \{A, B, C, D, E\}\),且其中一个为特殊符号“<eos>”。设束搜索的束宽等于 2,输出序列最大长度为 3。在输出序列的时间步 1 时,假设条件概率 \(\mathbb{P}(y_1 \mid \boldsymbol{c})\) 最大的两个词为 \(A\)\(C\)。我们在时间步 2 时将对所有的 \(y_2 \in \mathcal{Y}\) 都分别计算 \(\mathbb{P}(y_2 \mid A, \boldsymbol{c})\)\(\mathbb{P}(y_2 \mid C, \boldsymbol{c})\),并从计算出的 10 个条件概率中取最大的两个:假设为 \(\mathbb{P}(B \mid A, \boldsymbol{c})\)\(\mathbb{P}(E \mid C, \boldsymbol{c})\)。那么,我们在时间步 3 时将对所有的 \(y_3 \in \mathcal{Y}\) 都分别计算 \(\mathbb{P}(y_3 \mid A, B, \boldsymbol{c})\)\(\mathbb{P}(y_3 \mid C, E, \boldsymbol{c})\),并从计算出的 10 个条件概率中取最大的两个:假设为 \(\mathbb{P}(D \mid A, B, \boldsymbol{c})\)\(\mathbb{P}(D \mid C, E, \boldsymbol{c})\)。接下来,我们可以在 6 个候选输出序列:\(A\)\(C\)\(AB\)\(CE\)\(ABD\)\(CED\) 中筛选出包含特殊符号“<eos>”的序列,并将它们中所有特殊符号“<eos>”后面的子序列舍弃,得到最终候选输出序列。我们可以在最终候选输出序列中取分数最高的序列作为输出序列。

贪婪搜索可看作是束宽为 1 的束搜索。束搜索通过更灵活的束宽 \(k\) 来权衡计算开销和搜索质量。

小结

  • 预测不定长序列的方法包括穷举搜索、贪婪搜索和束搜索。
  • 束搜索通过更灵活的束宽来权衡计算开销和搜索质量。

练习

  • 穷举搜索可否看作是特殊束宽的束搜索?为什么?
  • “循环神经网络”一节中,我们使用语言模型创作歌词。它的输出属于哪种搜索?你能改进它吗?

扫码直达讨论区