knowwell-livewellの日記

knowwell-livewellの日記

好きなこととかもろもろ書きます

セマンティックセグメンテーションで利用されるloss関数(損失関数)について①

今回はセマンティックセグメンテーション(領域推定)のタスクでよく利用されるloss関数をまとめます(あまり利用されていないものも含めます)。

1. セマンティックセグメンテーションとは?

セマンティックセグメンテーションは画像中の対象物体の領域を推定するタスクで、畳み込みニューラルネットワーク(CNN)を活用することで精度良く推定できるようになったタスクの一つです。
このタスクにおける代表的なCNNモデルは、FCN、SegNet、U-Net、PSPNet、DeepLab familyなどがあります。一般的なセマンティックセグメンテーションタスクにおいて、現状のデファクトスタンダード(まずはこれ試そうモデル)は?と言われれば、2018年に発表されたモデルですが、DeepLab V3+なんじゃないかなと思っています(あくまで個人的な考えです)。
このCNNを学習させるためには、モデル出力と正解画像との差(誤差)を計算する必要があるのですが、ここで登場するのがloss関数です。個人的によく思うのですが、このloss関数の選択はどのくらい収束速度、精度に影響を及ぼすのでしょうか?(このブログでは実際に評価までしませんが..)

f:id:knowwell-livewell:20220130162908p:plain
2クラスのセマンティックセグメンテーション

2. セマンティックセグメンテーションで利用されるloss関数

f:id:knowwell-livewell:20220130125757p:plain
https://github.com/JunMa11/SegLossより引用

セマンティックセグメンテーションのloss関数に特化したサーベイ論文として、2020年の「A survey of loss functions for semantic segmentation*1」があります。ほかにも、セグメンテーションのloss関数をリストしているgithubのページ*2もあります。これらを参考にしながら、以降ではCross Entropy Loss、Focal Loss、Dice Loss、Tversky Loss、Boundary Loss、HD Lossについて紹介していきます(本記事はCross Entropy LossとFocal Loss)。

①Cross Entropy Loss

クラス分類でお馴染みのLoss関数です。セマンティックセグメンテーションもピクセル単位でクラス分類するタスクですので、こちらのLoss関数は一般的によく使用されていると思います。Cross Entropy Lossは2つの確率分布(教師データの分布とモデル出力の分布)の差を表しているので、図中においてDistribution-based Lossにカテゴリー分けされています。数式は以下になります。

f:id:knowwell-livewell:20220130133020p:plain:w300
Cross Entropy Loss

p(x)が正解ラベル、q(x)が推測結果です。これは正解ラベルをone-hotベクトルとすると、正解ラベルが1の推測結果(確率)の対数の負を計算すればよいということになります。もっと言うと、正解クラスの推測結果(確率)の対数の負を計算すればよいということですね。(「正解ラベルが1の推測結果だけ計算すればよい」と考えてしまうと、「セマンティックセグメンテーションの正解画像の1の値の領域だけ計算すればいいのか」みたいな間違った考えをしてしまいかねないので気を付けましょう。)ちなみにPytorchのtorch.nn.CrossEntropyLossはTargetとしてone-hotベクトルではなく、正解クラスのインデックスを与えるので、注意が必要です。そのため、one-hot表現で正解ラベルが(0,0,1)の場合は2を入れないといけません。これは結構ややこしいポイントだと思います。以下の記事がtorch.nn.CrossEntropyLossの使い方として(公式リファレンスと併せて見ると)分かりやすいと思います。([PyTorch]CrossEntropyLossを数式入りでちょっと理解する - Qiita
話がかなり逸れてしまいました。とにかくセマンティックセグメンテーションにおけるCross Entropy Lossをまとめると、以下の図のようになります。

f:id:knowwell-livewell:20220130173725p:plain
Cross Entropy Lossの計算方法(Pytorchに準拠)
②Focal Loss

「不均衡データに対しても学習がうまくいくように」という意図で設計されたLoss関数です。セマンティックセグメンテーションにおける不均衡データとは、例えば、画像に占める対象物体の面積がとても小さいようなデータです。実際には、Cross Entropyへの重み付けにより、正解クラスの推測結果(確率)が高い、簡単なピクセルのLossへの寄与を小さくすることで、推測結果(確率)が低い、難しいピクセルにフォーカスするようにします。図中ではCross Entropyから「Down-weight easy examples」となっていますよね。数式は以下になります。

f:id:knowwell-livewell:20220130144633p:plain
Focal Loss

ここで、 p_tは正解クラスの推測結果(確率)です。 (1-p_t)^\gammaのおかげで、推測結果(確率)が大きい場合はその値が小さくなることが分かります。ちなみに、 \gammaが0の場合は通常のCross Entropy Lossになります。提案論文中では \gamma=2と設定しているようです。以下のサイトにPytorch実装があります。
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/focal_loss.py



続きます...
(間違っていたらコメントでお教えいただけますと助かります。)