大模型困惑度计算

困惑度(Perplexity)

困惑度(Perplexity, PPL) 是衡量语言模型质量的常用指标。它表示模型对下一词的预测不确定性,数值越低,表示模型的预测越准确。困惑度越高,表示模型对语言的理解越“困惑”。

困惑度的计算公式

给定一个词序列 w1,w2,,wNw_1, w_2, \dots, w_N,困惑度的计算公式为:

Perplexity=exp(1Ni=1NlogP(wiw1,,wi1))\text{Perplexity} = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i \mid w_1, \dots, w_{i-1}) \right)

其中:

  • NN表示句子中的词数。
  • P(wiw1,,wi1)P(w_i \mid w_1, \dots, w_{i-1})是模型在给定上下文的条件下预测下一个词的概率。

例子:计算困惑度

假设我们有一个句子:
"The cat sat on the mat."

语言模型对每个词的预测概率如下:

P("The")=0.4P(\text{"The"}) = 0.4
P("cat""The")=0.3P(\text{"cat"} \mid \text{"The"}) = 0.3
P("sat""The cat")=0.2P(\text{"sat"} \mid \text{"The cat"}) = 0.2
P("on""The cat sat")=0.5P(\text{"on"} \mid \text{"The cat sat"}) = 0.5
P("the""The cat sat on")=0.7P(\text{"the"} \mid \text{"The cat sat on"}) = 0.7
P("mat""The cat sat on the")=0.6P(\text{"mat"} \mid \text{"The cat sat on the"}) = 0.6

第一步:计算对数概率

log(0.4)=0.9163\log(0.4) = -0.9163

log(0.3)=1.204\log(0.3) = -1.204

log(0.2)=1.6094\log(0.2) = -1.6094

log(0.5)=0.6931\log(0.5) = -0.6931

log(0.7)=0.3567\log(0.7) = -0.3567

log(0.6)=0.5108\log(0.6) = -0.5108

第二步:计算总和并取平均值

将这些对数概率相加:

0.9163+1.204+1.6094+0.6931+0.3567+0.5108=5.2903-0.9163 + -1.204 + -1.6094 + -0.6931 + -0.3567 + -0.5108 = -5.2903

平均值为:

5.29036=0.8817\frac{-5.2903}{6} = -0.8817

第三步:计算困惑度

最后,取平均值的指数:

Perplexity=exp(0.8817)2.414\text{Perplexity} = \exp(0.8817) \approx 2.414

结果分析

困惑度为 2.414,表示该模型在预测下一个词时,平均有 2.4 个可能的词可以选择。数值越小,说明模型对词的预测越有把握。

另一个模型的对比

假设另一个模型给出的预测概率如下:
P("The")=0.6P(\text{"The"}) = 0.6
P("cat""The")=0.5P(\text{"cat"} \mid \text{"The"}) = 0.5
P("sat""The cat")=0.4P(\text{"sat"} \mid \text{"The cat"}) = 0.4
P("on""The cat sat")=0.8P(\text{"on"} \mid \text{"The cat sat"}) = 0.8
P("the""The cat sat on")=0.9P(\text{"the"} \mid \text{"The cat sat on"}) = 0.9
P("mat""The cat sat on the")=0.9P(\text{"mat"} \mid \text{"The cat sat on the"}) = 0.9

我们重新计算困惑度:

第一步:计算对数概率

log(0.6)=0.5108\log(0.6) = -0.5108

log(0.5)=0.6931\log(0.5) = -0.6931

log(0.4)=0.9163\log(0.4) = -0.9163

log(0.8)=0.2231\log(0.8) = -0.2231

log(0.9)=0.1054\log(0.9) = -0.1054

log(0.95)=0.0513\log(0.95) = -0.0513

第二步:计算总和并取平均值

相加:

0.5108+0.6931+0.9163+0.2231+0.1054+0.0513=2.499-0.5108 + -0.6931 + -0.9163 + -0.2231 + -0.1054 + -0.0513 = -2.499

平均值为:

2.4996=0.4165\frac{-2.499}{6} = -0.4165

第三步:计算困惑度

Perplexity=exp(0.4165)1.516\text{Perplexity} = \exp(0.4165) \approx 1.516

结果对比

新模型的困惑度为 1.516,比之前的模型 2.414 更低,说明新模型对词序列的预测更加准确。


通过这个例子,可以看到困惑度越低,说明模型对语言的理解和预测越好。在语言模型的评估中,困惑度是衡量其预测能力的重要指标。