CIW:ノイズの多いラベルを使ってディープニューラルネットを訓練する新手法(2/3)

AI

1.CIW:ノイズの多いラベルを使ってディープニューラルネットを訓練する新手法(2/3)まとめ

・標準的なモデルは各サンプルに一律の重みを割り当てるためノイズに過剰適合してしまう
・CIWで学習したモデルはノイズの影響を回避して良好な判定境界へ収束することが可能
・サンプルだけでなくクラスラベルに対して重要度の重みを割り当てる拡張を行った

2.CICWとは?

以下、ai.googleblog.comより「Constrained Reweighting for Training Deep Neural Nets with Noisy Labels」の意訳です。元記事は2022年2月28日、Abhishek KumarさんとEhsan Amidさんによる投稿です。

アイキャッチ画像のクレジットはPhoto by Elyas Pasban on Unsplash

2次元のデータセットを使って判定境界を図解

本手法の動作を説明する例として、Two Moonsデータセットのノイズ版を考えてみます。Two Moonsデータセットは、2種の半月クラスからランダムにサンプリングされたデータから構成されています。

ラベルの30%を破損させ、その上で多層パーセプトロンネットワークを訓練し、二値分類を行います。モデルの学習には、標準的な二値クロスエントロピー損失と、モメンタムオプティマイザを用いたSGDを使用すします。下図(左)では、データポイントと、2つのクラスを分離する判定境界を点線で可視化しています。上半月に赤、下半月に緑で示した点は、ノイズ入りデータポイントを示しています。

2値クロスエントロピー損失で学習した比較対象モデルは、各ミニバッチ内のサンプルに一律の重みを割り当てるため、最終的にノイズの多いサンプルに過剰適合してしまい、判定境界がうまくいきません(下図中段)CIW法は、各ミニバッチのインスタンスに対応する損失値に基づいて重みを付け直します(下図右)。

判定境界の正しい側に位置するクリーンなサンプルには大きな重みを割り当て、より高い損失値を発生させるノイズの多いサンプルの影響を減衰させます。ノイズの多いサンプルの重みを小さくすることで、モデルの過剰適合を防ぎ、CIWで学習したモデルがラベルノイズの影響を回避して良好な判定境界へ収束することを可能にします。


Two Moonsデータセットにおける比較対象手法と私達が提案したCIW手法の学習の進行に伴う決定境界の説明
左:望ましい決定境界を持つノイズの多いデータセット
中央:クロスエントロピロスを用いた標準的な学習における決定境界
右:CIW法による学習
(中)と(右)の点の大きさは、ミニバッチでこれらの例に割り当てられた重要度の重みに比例しています。

制約条件付きクラス再重み付け(Constrained Class reWeighting)

サンプルの再重み付けは、損失の大きいサンプルに低い重みを割り当てます。私達はこの直観をさらに拡張し、すべての取りうるクラスラベルに対して重要度の重みを割り当てます。

標準的な学習では、クラスの重みにワンホットラベルベクトルを用い、ラベル付けされたクラスには1を、それ以外のクラスには0という重みを割り当てます。

しかし、誤ラベルが付与されている可能性があるサンプルに対しては、真のラベルである可能性を持つクラスに対して0以外の重みを割り当てることが合理的です。このクラス重みは、制約付き最適化問題群の解として得られます。

ここで、ワンホット分布を持つラベルからのクラス重みの偏差は、選択した逸脱(divergence of choice)によって測定され、ハイパーパラメータによって制御されます。

ここでも、いくつかの逸脱度について、クラス重みの簡単な公式を得ることができます。これをConstrained Instance and Class reWeighting(CICW)と呼ぶことにします。

また、この最適化問題の解は、逸脱を全変動距離とすれば、静的ラベルブートストラッピング(ラベルスムージングとも呼ばれます)に基づく初期に提案された手法をリカバーすることができます。これにより、静的ラベルブートストラッピングという一般的な手法に理論的な見通しを与えることができます。

3.CIW:ノイズの多いラベルを使ってディープニューラルネットを訓練する新手法(2/3)関連リンク

1)ai.googleblog.com
Constrained Reweighting for Training Deep Neural Nets with Noisy Labels

2)arxiv.org
Constrained Instance and Class Reweighting for Robust Learning under Label Noise

3)github.com
google-research/ciw_label_noise/

タイトルとURLをコピーしました