是啥?
在深度学习中,“交叉熵”是一个用来衡量两个概率分布之间差异的函数,通常用于分类任务。
比喻:猜糖果的颜色
想象你有一袋糖果,里面有红色、绿色和蓝色三种颜色。你猜这袋糖果的颜色分布是红色 50%,绿色 30%,蓝色 20%。但实际上,这袋糖果的真实分布是红色 60%,绿色 20%,蓝色 20%。
交叉熵的作用就是告诉你,你的猜测(预测)和糖果袋的真实分布有多接近。如果你的猜测和真实情况差别很大,交叉熵的值会很大;如果你猜得很接近,交叉熵的值就会很小。
数学解释
交叉熵的数学公式可以写成:
\[ H(p, q) = - \sum p(x) \log(q(x)) \]其中:
- \( p(x) \) 是真实的概率分布(比如糖果袋的真实颜色分布)。
- \( q(x) \) 是你预测的概率分布(比如你猜测的颜色分布)。
- 这个公式的意思是:对于每种可能的糖果颜色,真实的概率 \( p(x) \) 和你预测的概率 \( q(x) \) 之间的差异,乘上 \( \log(q(x)) \),然后取负号,最后对所有颜色的可能性求和。
直观解释:
- 如果你猜得很对:比如你猜红色糖果有 60%,而真实也是 60%,那么交叉熵的值就会很小,表示你预测得很准确。
- 如果你猜得不对:比如你猜红色糖果只有 20%,而真实是 60%,那么交叉熵的值就会很大,表示你的预测与真实情况差距很大。
交叉熵的计算
例子:
假设我们有一个简单的分类问题,只有两个类别,比如“猫”和“狗”。
真实的标签是“猫”,所以真实概率 \( p(x) \) 是:
- 猫:1.0
- 狗:0.0
模型预测的概率 \( q(x) \) 是:
- 猫:0.8
- 狗:0.2
逐步计算:
根据公式,交叉熵计算会涉及每个类别。我们带入这两个类别的概率:
\[ H(p, q) = - \left( p(\text{猫}) \cdot \log(q(\text{猫})) + p(\text{狗}) \cdot \log(q(\text{狗})) \right) \]-
计算“猫”部分:
- 真实概率 \( p(\text{猫}) = 1.0 \)
- 预测概率 \( q(\text{猫}) = 0.8 \)
- 计算这一部分:\( 1.0 \cdot \log(0.8) = \log(0.8) \approx -0.2231 \)
-
计算“狗”部分:
- 真实概率 \( p(\text{狗}) = 0.0 \)
- 预测概率 \( q(\text{狗}) = 0.2 \)
- 由于 \( p(\text{狗}) = 0 \),这一项的贡献为 0:\( 0.0 \cdot \log(0.2) = 0 \)
-
总和: 将两部分相加后取负号:
结果解释:
交叉熵的值为 0.2231,说明模型的预测和真实标签还是比较接近的,但并不是完全正确(如果预测是 1.0,那交叉熵会是 0)。
在深度学习中的应用
在深度学习的分类任务中,模型输出的结果是每个类别的概率分布,比如一张图有猫的概率是 0.7,有狗的概率是 0.2,有鸟的概率是 0.1。交叉熵用来衡量模型的预测分布和真实标签之间的差异,帮助模型调整参数,以便预测得越来越准确。
通过最小化交叉熵,我们能够让模型的预测分布尽可能接近真实的类别分布,这就是为什么交叉熵经常作为损失函数使用的原因。