DeepCTRL:ニューラルネットワークにルールを教えて制御する試み(3/3)

AI

1.DeepCTRL:ニューラルネットワークにルールを教えて制御する試み(3/3)まとめ

・元データが必ずしもルールに従うとは限らないのでルールの効果は元データに依存
・DeepCTRLは再トレーニングせずにデータに合わせてルールの強さを変更可能
・DeepCTRLは既知の原理を用いて信頼性を向上しルール強度を用いて領域適応が可能

2.DeepCTRLの検証

以下、ai.googleblog.comより「Controlling Neural Networks with Rule Representations」の意訳です。元記事は2022年1月28日、Sungyong SeoさんとSercan O. Arikさんによる投稿です。

アイキャッチ画像のクレジットはPhoto by Chris Chow on Unsplash

医療分野における分布のズレへの対応

ルールの強さは、データセットによって異なる場合があります。例えば、疾病予測では、心血管疾患と高血圧の相関は、若い患者よりも高齢の患者の方が強いと言われています。このように、タスクは共通であっても、データセット間でデータの分布やルールの有効性が異なる場合、DeepCTRLはαを制御することで、分布のずれ(distribution shifts)に対応することができます。

以下では、心血管疾患のデータセットを用いて、心血管疾患の有無を予測するタスクに着目して調査を行いました。収縮期血圧の高さが心血管疾患と強く関連することが知られていることから、「収縮期血圧が高ければリスクが高い」というルールを考えます。

これに基づいて、患者を2つのグループに分けます。
(1)高血圧だが病気ではない、または低血圧だが病気である、という異常群
(2)高血圧かつ病気である、または低血圧だが病気ではない、という通常群

以下に、元データが必ずしもルールに従うとは限らないこと、したがって、ルールを組み込む事による効果は元データに依存する可能性があることを示します。

分類精度を示すテストクロスエントロピー(小さいほど良い値)とルールの強さの関係を、元データと、通常/異常の比率を変化させた/ターゲットデータセットについて、以下に可視化します。α → 1に近づくにつれて単調に誤差が増加している原因は、元データを正確に反映出来ていないルールをより厳密に適用するようになるためです。


通常/異常の比率が0.30の元データセットについて、クロスエントロピーとルールの強度の関係をテストしています

そこで、ターゲット1、2、3と呼ぶ3つの領域別データセットを用いて、学習したモデルをターゲット領域に転移する際の誤差を低減します。ターゲット1では、患者の大半が通常群であるため、αを大きくするとルールベース表現の比重が大きくなり、結果として誤差が単調に減少していくことがわかります。


上記と同じ条件ですが、ターゲット1データセットの通常/異常の比率は0.77です

ターゲット2データセットとターゲット3データセットで通常患者の割合が減少した場合、最適なαは0と1の間の中間値となります。


上記と同じ条件ですが、ターゲット2データセットの通常/異常の比率は0.50です

上記と同じ条件ですが、ターゲット3データセットの通常/異常の比率は0.40です

結論

ルールから学習する事は、解釈可能で堅牢かつ信頼性の高いディープニューラルネットワーク(DNN)を構築するために極めて重要です。私たちは、データから学習するDNNにルールを組み込むための新しい方法論であるDeepCTRLを提案しました。

DeepCTRLは、再トレーニングを行うことなく、推論時にルールの強さを制御することが可能です。任意のルールを意味のある特徴表現に統合するために、摂動に基づく新しいルールエンコーディング手法を提案します。また、DeepCTRLの3つのユースケースとして、既知の原理を用いた信頼性の向上、ルール候補の検討、ルール強度を用いた領域適応(domain adaptation)を示しました。

謝辞

Jinsung Yoon, Xiang Zhang, Kihyuk Sohn そして Tomas Pfisterの貢献に大きく感謝します。

3.DeepCTRL:ニューラルネットワークにルールを教えて制御する試み(3/3)関連リンク

1)ai.googleblog.com
Controlling Neural Networks with Rule Representations

2)arxiv.org
Controlling Neural Networks with Rule Representations

タイトルとURLをコピーしました