FedJAX:連合学習のシミュレーションをJAXで容易に実行(1/2)

プライバシー

1.FedJAX:連合学習のシミュレーションをJAXで容易に実行(1/2)まとめ

・連合学習を使用するとデータをスマホからクラウドに送信せずともモデルトレーニングが可能
・連合学習を研究する際の使いやすさを重視したJAXベースのFedJAXをオープンソースで公開
・FedJAXには文字認識、文字予測、単語予測を行うためのデータセットが付属している

2.FedJAXとは?

以下、ai.googleblog.comより「FedJAX: Federated Learning Simulation with JAX」の意訳です。元記事の投稿は2021年10月4日、Jae Hun RoさんとAnanda Theertha Sureshさんによる投稿です。

Federated Learningは連合学習の訳が定着してきたように感じたので連合学習としました。「連合(Federated)」と聞くとクライアント側が自発的に連携するイメージがありますが、現在の仕様はサーバー側から駆動するので「統合(Integrated)」の方がしっくりくる感がありますが、発展途上の技術なので最終的には真の連合型になっていくのでしょうか?

Federatedな学習を行っていそうなアイキャッチ画像のクレジットはPhoto by Yannis H on Unsplash

連合学習(Federated learning)とは、トレーニングデータを分散させたまま、多くのクライアント(つまり、モバイルデバイスまたは組織内の全端末など)が中央サーバーの指揮の下でモデルを共同でトレーニングする機械学習セッティングの事です。

たとえば、連合学習を使用すると、モバイルデバイス内のユーザデータをクラウドに送信せずとも仮想キーボード用の言語モデルをトレーニングできます。

連合学習アルゴリズムは、最初にサーバーでモデルを初期化し、トレーニング時に毎回3つの主要ステップを完了することでこれを実現します。

(1)サーバーが全クライアントから一部を選択し、選択したクライアントにモデルを送信します。
(2)これらの選択されたクライアントは、クライアント内に保持するローカルデータでモデルをトレーニングします。
(3)トレーニング後、クライアントは更新されたモデルをサーバーに送信し、サーバーはそれらを集約します。


4つのクライアントを使用した際の連合学習アルゴリズムの動作例

近年、プライバシーやセキュリティ意識が高まっているため、連合学習は特に活発な研究分野になっています。アイデアをコードに簡単に変換し、すばやく反復し、既存のベースラインを比較および再現できることは、このような急成長している分野にとって重要です。

これを踏まえて、研究目的での使いやすさを重視した連合学習シミュレーション用のJAXベースのオープンソースライブラリであるFedJAXを紹介できることを嬉しく思います。

FedJAXは、フェデレーションアルゴリズム、事前にパッケージ化されたデータセット、モデルとアルゴリズムを実装するためのシンプルな土台、および高速なシミュレーション速度により、研究者がフェデレーションアルゴリズムの開発と評価をより迅速かつ簡単に行えるようにすることを目指しています。

本投稿では、FedJAXのライブラリ構造と内容について説明します。 TPUでは、FedJAXを使用して、MNISTの拡張であるEMNISTデータセットでフェデレーション平均(federated averaging)を使用してモデルを数分でトレーニングし、StackOverflowデータセットを標準のハイパーパラメーターで約1時間でトレーニングできることを示します。

ライブラリの構造

使いやすさを念頭に置いて、FedJAXは新しい概念を取り込む事は限定的にしています。

FedJAXで記述されたコードは、学術論文で新しいアルゴリズムを記述するために使用される擬似コードに似ているため、簡単に使い始めることができます。さらに、FedJAXは連合学習の構成要素を提供しますが、ユーザーは、トレーニング全体を適度に高速に保ちながら、NumPyとJAXのみを使用してこれらを最も基本的な実装に置き換えることができます。

同梱されているデータセットとモデル

連合学習の研究が進んでいる現在、画像認識、言語モデリングなど、一般的に使用されるさまざまなデータセットとモデルがあります。これらの多くのデータセットとモデルをFedJAXですぐに使用できるため、前処理されたデータセットとモデルを一から用意する必要はありません。

これにより、異なるフェデレーションアルゴリズム間の有効な比較が促進されるだけでなく、新しいアルゴリズムの開発も加速されます。

現在、FedJAXには以下のデータセットとサンプルモデルがパッケージ化されています。

・EMNIST-62
文字認識タスク用
・Shakespeare
次に出現する文字の予測タスク
・Stack Overflow
次に出現する単語の予測タスク

これらの標準設定に加えて、FedJAXは、ライブラリの他の部分で使用できる新しいデータセットとモデルを作成するためのツールを提供します。

最後に、FedJAXには、分散データ環境で共有モデルをトレーニングするためのフェデレーション平均およびその他のフェデレーションアルゴリズムの標準実装が付属しています。

例えば、適応型フェデレーションオプティマイザー(adaptive federated optimizers)、非依存型フェデレーション平均(agnostic federated averaging)、Mimeなどその他のフェデレーションアルゴリズムの標準実装が付属しています。

3.連合学習のシミュレーションをJAXで容易に実行(1/2)関連リンク

1)ai.googleblog.com
FedJAX: Federated Learning Simulation with JAX

2)arxiv.org
FedJAX: Federated learning simulation with JAX

3)fedjax.readthedocs.io
FedJAX documentation

4)github.com
google / fedjax

タイトルとURLをコピーしました