1.Rax:JAXで使えるランキングシステム用ライブラリ(1/2)まとめ
・ランク付けは、様々な領域にまたがる中核的な問題で機械学習技術であるLTRが良く使用される
・比較的新しい機械学習フレームワークであるJAXで動作するLTRライブラリは存在しない
・JAXで関連システムと柔軟に連携な可能なLTR用ライブラリであるRaxを新たに開発して公開
2.Raxとは?
以下、ai.googleblog.comより「Rax: Composable Learning-to-Rank Using JAX」の意訳です。元記事は2022年8月11日、Rolf JagermanさんとHonglei Zhuangさんによる投稿です。
アイキャッチ画像はオーストリアのRax山脈でクレジットはPhoto by Lukas Baumann on Unsplash
ランキングは、検索エンジン、推薦システム、質問回答など、様々な領域にまたがる中核的な問題です。そのため、研究者はしばしば、一連の教師あり機械学習技術であるLTR(learning-to-rank)を利用しています。LTRは(一度に1つの項目ではなく)項目のリスト全体の有用性を最適化します。
最近注目されているのは、LTRとディープラーニングを組み合わせることです。既存のライブラリ、特にTF-Rankingは、研究者や実務家が仕事でLTRを使用するために必要なツールを提供しています。しかし、既存のLTRライブラリの中で、自動微分、GPU/TPUデバイスへのJITコンパイルなどを構成する関数変換の拡張可能なシステムを提供する新しい機械学習フレームワークであるJAXでネイティブに動作するものはありません。
本日、JAXエコシステムにおけるLTR用ライブラリであるRaxを紹介出来る事を嬉しく思います。Raxは、数十年にわたるLTR研究をJAXエコシステムにもたらし、JAXを様々なランキング問題に適用し、ランキング技術をJAX上に構築された最近の深層学習の進歩(例えば、T5X)と組み合わせることが可能になります。
Raxは、最先端のランキング損失、多くの標準的なランキング指標、ランキング指標の最適化を可能にする関数変換のセットを提供します。これらの機能はすべて、JAXのユーザーにとって見慣れた、使いやすい、文書化されたAPIで提供されています。技術的な詳細については、私達の論文をご覧ください。
Raxを用いたLearning-to-Rank
RaxはLTR問題を解くために設計されています。この目的のために、Raxは他の機械学習問題で一般的な個々のデータポイントのバッチではなく、リストのバッチに対して操作する損失関数と指標関数を提供します。このようなリストの例としては、検索エンジンの検索語から得られる複数の潜在的な検索結果があります。
下図は、Raxのツールを使って、ランキング・タスクでニューラルネットワークを学習する方法を示しています。この例では、緑色の項目(B, F)は非常に関連性が高く、黄色の項目(C, E)はやや関連性が高く、赤色の項目(A, D)は関連性が低いことを示しています。
ニューラルネットワークを用いて各項目の関連性スコアを予測し、これらの項目をこのスコアでソートしてランキングを作成します。Raxランキング損失は、ニューラルネットワークを最適化するためにスコアのリスト全体を組み入れ、項目の全体的なランキング品質を向上させます。
確率的勾配降下を数回繰り返すと、ニューラルネットワークは、結果として得られるランキングが最適となるように、項目にスコアを付けることを学習します。つまり、関連性のある項目はリストの最上部に、関連性のない項目は最下部に配置されます。
Raxを使って、ランキングタスクのためのニューラルネットワークを最適化
緑色の項目(B, F)は非常に関連性が高く、黄色の項目(C, E)はやや関連性が高く、赤色の項目(A, D)は関連性がありません。
近似的に指標を最適化
ランキングの質は、一般にNDCG(Normalized Discounted Cumulative Gain)などのランキング指標を用いて評価されます。LTRの重要な目的は、ニューラルネットワークを最適化し、ランキング指標で高得点を得られるようにすることです。
しかし、NDCGのようなランキング指標は不連続かつ平坦であることが多いため、確率的勾配降下法を直接適用することができないという課題があります。Raxは、最新の近似技術により、ランキング指標の微分可能な代替物を生成し、勾配降下法による最適化を可能にしています。下図は、rax.approx_t12nというRax独自の関数変換を用いて、NDCGメトリックスを近似的かつ微分可能な形に変換した例です。
Raxの近似技術を使って、NDCGランキングメトリックを微分可能で最適化可能なランキング損失(approx_t12nとgumbel_t12n)に変換
3.Rax:JAXで使えるランキングシステム用ライブラリ(1/2)関連リンク
1)ai.googleblog.com
Rax: Composable Learning-to-Rank Using JAX
2)research.google
Rax: Composable Learning-to-Rank using JAX
3)github.com
google / rax