1.Rax:JAXで使えるランキングシステム用ライブラリ(2/2)まとめ
・RaxはNDCGを近似してノイズを加える事で微分と局所最適化問題を解決した
・他のJax関連ツールとスムーズに連携できるようにRaxは設計されている
・JAXのT5XでRaxをどのように使用できるかを示すサンプルも用意されている
2.Raxの性能
以下、ai.googleblog.comより「Rax: Composable Learning-to-Rank Using JAX」の意訳です。元記事は2022年8月11日、Rolf JagermanさんとHonglei Zhuangさんによる投稿です。
アイキャッチ画像のクレジットはPhoto by Lukas Baumann on Unsplash
まず、NDCG(Normalized Discounted Cumulative Gain)指標(緑色)は平坦かつ不連続であり、確率的勾配降下法を用いて最適化することが困難であることに着目してください。
この指標にrax.approx_t12n変換を適用すると、近似的な指標であるApproxNDCGが得られ、これは明確に定義された勾配なので微分可能になります(赤色)。
しかし、多くの局所最適点(局所的には最適だが、大域的に見ると最適でない点)を持つ可能性があり、学習プロセスが行き詰まることがあります。損失がこのような局所最適になると、確率的勾配降下のような学習手順では、ニューラルネットワークをさらに改善することが難しくなります。
この問題を解決するために、rax.gumbel_t12n変換を使用して、近似NDCGのgumbel版を取得することができます。
このgumbel版はランキングスコアにノイズを含めるため、損失がゼロでないコストを発生させる可能性のあるさまざまなランキングをサンプリングすることになります。(青字)
この確率的な処理により、損失は局所最適点から逃れられるかもしれず、そのため、ランキング指標でニューラルネットワークを学習する際に、より良い選択となることが多いです。Raxは設計上、近似変換とgumbel変換を、ライブラリが提供するすべての指標で自由に使用することができます。実際、独自の指標を実装し、それを変換してgumbel近似版を得ることも可能であり、余分な労力をかけずに最適化を行うことができます。
JAXエコシステムにおけるランキング
Raxは、JAXエコシステムとうまく統合できるように設計されており、他のJAXベースのライブラリとの相互運用性を優先しています。例えば、JAXを使用する研究者の一般的なワークフローは、データセットをロードするためにTensorFlow Datasets、ニューラルネットワークを構築するためにFlax、そしてネットワークのパラメータを最適化するためにOptaxを使用することです。
これらのライブラリはそれぞれ他のライブラリとの相性が良く、これらのツールの組み合わせがJAXでの作業を柔軟かつ強力なものにしています。ランキングシステムの研究者や実務家にとって、JAXのエコシステムにはこれまでLTR(learning-to-rank)の機能が欠けており、Raxはランキング損失と指標のコレクションを提供することでこのギャップを埋めているのです。
私たちは、Raxをjax.jitやjax.gradなどの標準的なJAX変換や、Flax や Optax などの様々なライブラリでネイティブに機能するように慎重に構築しています。つまり、ユーザーは、お気に入りのJAXとRaxのツールを一緒に自由に使うことができるのです。
T5を使ったランキング
T5のような巨大言語モデルは自然言語タスクで素晴らしい性能を発揮しています。しかし、検索や質問回答のようなランキングタスクでランキング損失をどのように活用し、性能を向上させるかはまだ十分に検討されていなません。
Raxを使えば、このポテンシャルを十分に引き出すことができます。RaxはJAXファーストで書かれているため、他のJAXライブラリとの統合が容易です。T5XはJAXエコシステムにおけるT5の実装であるため、Raxはシームレスに動作させることができます。
このため、RaxをT5Xでどのように使用できるかを示すサンプルを用意しました。ランキング損失と指標を組み込むことで、ランキング問題に対してT5を微調整することが可能になり、ランキング損失でT5を強化することで、大幅な性能向上が得られることを示す結果が得られています。
例えば、MS-MARCO QNA v2.1ベンチマークでは、T5-Baseモデルを、pointwise sigmoidクロスエントロピーの代わりにRaxlistwise softmaxクロスエントロピー損失を使って微調整した結果、NDCG+1.2%、MRR+1.7%を達成することが出来ました。
ランキング損失(青字はsoftmax)と非ランキング損失(赤字はポイントワイズシグモイド)を用いてMS-MARCO QNA v2.1でT5-Baseモデルの微調整した場合の比較
まとめ
全体として、RaxはJAXライブラリの成長するエコシステムに新たに加わったものです。Raxは完全にオープンソースで、github.com/google/raxで誰でも利用することができます。より技術的な詳細は、私たちの論文にも記載されています。
(1)FlaxとOptaxを用いたニューラルネットワークの最適化
(2)異なる近似指標最適化技術との比較
(3)RaxとT5Xの統合方法
など、githubリポジトリに含まれるサンプルをぜひご覧ください。
謝辞
このプロジェクトは、Google社内の多くの協力者によって実現されました。Xuanhui Wang, Zhen Qin, Le Yan, Rama Kumar Pasumarthi, Michael Bendersky, Marc Najork, Fernando Diaz, Ryan Doherty, Afroz Mohiuddin, そして Samer Hassan。
3.Rax:JAXで使えるランキングシステム用ライブラリ(2/2)関連リンク
1)ai.googleblog.com
Rax: Composable Learning-to-Rank Using JAX
2)research.google
Rax: Composable Learning-to-Rank using JAX
3)github.com
google / rax