是啥?

在深度学习中,“交叉熵”是一个用来衡量两个概率分布之间差异的函数,通常用于分类任务。

比喻:猜糖果的颜色

想象你有一袋糖果,里面有红色、绿色和蓝色三种颜色。你猜这袋糖果的颜色分布是红色 50%,绿色 30%,蓝色 20%。但实际上,这袋糖果的真实分布是红色 60%,绿色 20%,蓝色 20%。

交叉熵的作用就是告诉你,你的猜测(预测)和糖果袋的真实分布有多接近。如果你的猜测和真实情况差别很大,交叉熵的值会很大;如果你猜得很接近,交叉熵的值就会很小。

数学解释

交叉熵的数学公式可以写成:

\[ H(p, q) = - \sum p(x) \log(q(x)) \]

其中:

  • \( p(x) \) 是真实的概率分布(比如糖果袋的真实颜色分布)。
  • \( q(x) \) 是你预测的概率分布(比如你猜测的颜色分布)。
  • 这个公式的意思是:对于每种可能的糖果颜色,真实的概率 \( p(x) \) 和你预测的概率 \( q(x) \) 之间的差异,乘上 \( \log(q(x)) \),然后取负号,最后对所有颜色的可能性求和。

直观解释:

  1. 如果你猜得很对:比如你猜红色糖果有 60%,而真实也是 60%,那么交叉熵的值就会很小,表示你预测得很准确。
  2. 如果你猜得不对:比如你猜红色糖果只有 20%,而真实是 60%,那么交叉熵的值就会很大,表示你的预测与真实情况差距很大。

交叉熵的计算

例子:

假设我们有一个简单的分类问题,只有两个类别,比如“猫”和“狗”。
真实的标签是“猫”,所以真实概率 \( p(x) \) 是:

  • 猫:1.0
  • 狗:0.0

模型预测的概率 \( q(x) \) 是:

  • 猫:0.8
  • 狗:0.2

逐步计算:

根据公式,交叉熵计算会涉及每个类别。我们带入这两个类别的概率:

\[ H(p, q) = - \left( p(\text{猫}) \cdot \log(q(\text{猫})) + p(\text{狗}) \cdot \log(q(\text{狗})) \right) \]
  1. 计算“猫”部分

    • 真实概率 \( p(\text{猫}) = 1.0 \)
    • 预测概率 \( q(\text{猫}) = 0.8 \)
    • 计算这一部分:\( 1.0 \cdot \log(0.8) = \log(0.8) \approx -0.2231 \)
  2. 计算“狗”部分

    • 真实概率 \( p(\text{狗}) = 0.0 \)
    • 预测概率 \( q(\text{狗}) = 0.2 \)
    • 由于 \( p(\text{狗}) = 0 \),这一项的贡献为 0:\( 0.0 \cdot \log(0.2) = 0 \)
  3. 总和: 将两部分相加后取负号:

\[ H(p, q) = - \left( -0.2231 + 0 \right) = 0.2231 \]

结果解释:

交叉熵的值为 0.2231,说明模型的预测和真实标签还是比较接近的,但并不是完全正确(如果预测是 1.0,那交叉熵会是 0)。

在深度学习中的应用

在深度学习的分类任务中,模型输出的结果是每个类别的概率分布,比如一张图有猫的概率是 0.7,有狗的概率是 0.2,有鸟的概率是 0.1。交叉熵用来衡量模型的预测分布和真实标签之间的差异,帮助模型调整参数,以便预测得越来越准确。

通过最小化交叉熵,我们能够让模型的预测分布尽可能接近真实的类别分布,这就是为什么交叉熵经常作为损失函数使用的原因。