knowwell-livewellの日記

knowwell-livewellの日記

好きなこととかもろもろ書きます

Cross-Entropy LossとBinary Cross-Entropy Lossの式と違いについて(クラス分類)

今回はCross-Entropy(クロスエントロピー)とBinary Cross-Entropy (バイナリクロスエントロピー)について書きます。クロスエントロピーは、分類タスクでよく用いられる損失関数ですね。一般的にクロスエントロピーの式といえば、下図第二辺ですが、特に2クラス分類の場合は第三辺のように表すこともできます。第二辺と第三辺の意味についてクラス分類の具体例を交えながら考えていきます。

f:id:knowwell-livewell:20220219214832p:plain
クロスエントロピー誤差( 交差エントロピー - Wikipedia より)
1. 多クラス分類の場合(例として3クラス分類)

例として、画像を入力して、それが犬、猫、ウサギのどれかを分類する深層学習モデルを考えます。この場合、犬、猫、ウサギをそれぞれ、(1,0,0)、(0,1,0)、(0,0,1)のようにone-hot表現で表し、モデル出力のユニット(ニューロン)数を3にして、モデルを学習させることが多いです。モデル学習時、モデル出力はソフトマックス関数により、確率に変換します。例えば、モデル出力がソフトマックス関数により(0.25,0.67,0.08)となれば、犬の確率が0.25(25%)、猫の確率が0.67(67%)、ウサギの確率が0.08(8%)であることを意味します。猫の画像を入力して、モデル出力の確率が(0.25,0.67,0.08)だった場合、(モデルを学習させるために必要な)クロスエントロピー誤差は以下のように計算できます。

f:id:knowwell-livewell:20220219225233p:plain
クロスエントロピーの計算

最初の図における pは正解の確率分布で、ここでは猫の画像を入力したことにしているので(0,1,0)です。 qはモデル出力の確率分布で、(0.25,0.67,0.08)になります。正解クラスのモデル出力(確率)の対数の負を計算していますね。

2. 2クラス分類の場合

続いて、2クラス分類の場合を考えます。1. 多クラス分類の場合 からウサギをなくしましょう(つまり、犬と猫の2クラス分類)。先ほどと同様に、犬、猫をそれぞれone-hot表現により、(1,0)、(0,1)と表し、モデル出力のユニット数を2として学習させることにします。例えば、犬の画像を入力して、モデル出力の確率が(0.74,0.26)となったとすると、クロスエントロピー誤差は1. 多クラス分類の場合 と同じように以下で計算することが出来ます。

f:id:knowwell-livewell:20220219231536p:plain
クロスエントロピーの計算

ところで、2クラス分類の場合、片方のクラスに対する確率が分かると、もう片方の確率が自動的に分かる(1から引けばよい)ので、出力層のユニット数は1つでもクラス分類することができます。犬を0、猫を1と表現(labelエンコーディング)し、モデル出力のユニット数を1にして学習させることにします。このとき、モデル出力をシグモイド関数により0~1の値に変換し、その値を1のクラス(ここでは猫)の確率と考えます。例えば、モデル出力の確率が0.26となれば、猫の確率が0.26、犬の確率が0.74であることを意味します(そのため、0.5以上なら猫と、0.5より小さいなら犬と予測したのだと判断できます)。例えば、猫の画像を入力して、モデル出力の確率が0.6になったとすると、クロスエントロピー誤差は最初の図の第三辺を用いて、以下のように計算できます。

f:id:knowwell-livewell:20220219234907p:plain
クロスエントロピーの計算

最初の図における yは正解ラベルで、ここでは猫の画像を入力したことにしているので1です。 \hat{y}はモデル出力の確率で、0.6です。続いて、one-hot表現で計算した例と同じ例を考えてみます。つまり、犬の画像を入力して、モデル出力が0.26となった場合、クロスエントロピー誤差は以下のように計算できます。

f:id:knowwell-livewell:20220220000451p:plain
クロスエントロピーの計算

同じ結果が得られましたね。
まとめると、正解クラスが0の場合、最初の図の第三辺の一項目がなくなり、モデル出力を1から引いたもの(つまり、クラス0の確率)の対数の負を計算することになります。正解クラスが1の場合には、二項目がなくなり、モデル出力(クラスの確率)の対数の負を計算することになります。

3. おわりに

2. 2クラス分類の場合 で見たように、正解をone-hot表現で表してクロスエントロピー誤差を計算する場合(最初の図の第二辺)でも、正解を0と1で表して計算する場合(最初の図の第三辺)でも、結局は正解クラスのモデル出力(確率)の対数の負を計算していて、同じ値になることが確認できましたね。


(間違っていたらコメントでお教えいただけますと助かります。)