連合学習のフレームワークであるFlowerには、VCE(Virtual Client Engine)と呼ばれる仮想クライアントエンジンの機能があり、この機能を活用することで、本来はサーバと複数のクライアント環境を準備する必要のある連合学習のシミュレーションを1台の物理デバイス上で実現することが可能となる。
本資料では、このFlowerVCEの概要、活用の方法について解説する。
FlowerVCEのアーキテクチャー
FlowerVCEは、Pythonの分散並列処理のフレームワークであるRayのフレームワークを使って、複数の仮想クライアントをエミュレートし、多数のクライアントでの連合学習処理を実現する機能である。
FlowerVCEでは、構成に必要なクラスやメソッドを実装することでカスタマイズ可能で、strategyと呼ばれるクラスを使って、様々なFederatedLearning戦略(FedAvg、FedProxなど)を定義して検証することが可能である。
処理に必要なクラス、メソッド、関数の構成、機能
FlowerVCEの処理に必要なクラス、メソッド、関数の構成図を以下に記載する.
以下にクラス、メソッド、関数の機能について説明する.
-
start_simulation[Method]
FederatedLearningのシミュレーションを開始するためのメソッド
以下のような引数を渡してFlowerVCEを起動する- client_fn:クライアントを構築するためのコールバック関数を指定
- num_clients:FederatedLearningに参加するクライアント数
- config:サーバーの設定を指定、fl.server.ServerConfigでnum_rounds(ラウンド数)などを指定
- strategy:FederatedLearningの戦略を指定
- client_resources:クライアントに割り当てるリソース(CPU、GPU)を指定
-
client_fn[Function]
クライアントを構築するための関数
この関数はFederatedLearningのクライアントがサンプルされる時に実行される。主に以下のような処理を実装して構成する- 対応するクライアント用のデータセットを読み込む
- FlowerClientクラスを生成、構築する
-
FlowerClient[Class]
FederatedLearningの各クライアント用のClass定義
以下のような各メソッドでクライアントの訓練、検証用の処理を定義する- _init_:クラスの初期化メソッド
- set_parameters:サーバーから送信されたパラメータを受け取り、モデルに設定する
- get_parameters:ローカルモデルからパラメータを抽出し、サーバに返す
- fit: サーバーから送信されたパラメータを使ってローカルデータセットでモデルを訓練し、その結果をサーバーに返す
- evaluate:サーバーから送信されたパラメータを使ってローカルの検証データセットでモデルを評価し、その結果をサーバーに返す
-
strategy[Class]
FederatedLearningの戦略を定義
以下のような引数を渡してFederatedLearningの戦略やクライアントからのデータ集約処理を定義する- fraction_fit:訓練に利用するクライアントの割合を指定
- fraction_evaluate:評価に利用するクライアントの割合を指定
- evaluate_fn:グローバル評価関数の設定、ラウンド毎に中央集権的なテストデータセットを使って評価する関数を指定
-
evaluate_fn[Function]
各ラウンドの終了時にstrategy戦略によって実行されるコールバック用の評価関数
主に以下のような処理を実装して構成する- モデルパラメータのロード
- 中央集権的なテストデータセットを使った評価処理の実行
- 損失や精度の結果の保存
-
train[Function]
クライアント用の訓練関数--FlowerClient.fitから呼び出される -
test[Function]
クライアント用の評価関--FlowerClient.evaluateから呼び出される -
Model[Class]
FederatedLearningで使うmodelクラス(CNN、GCN等) -
Dataset[data]
FederatedLearningでクライアントが利用するtrainデータセット、valデータセットやFlowerVCEのevaluate_fnが中央集権的な評価用に使うtestデータセット
処理の流れ
FlowerVCEのシミュレーション処理の流れは下記のとおり。
FlowerVCEを使った実装の手順
-
ライブラリのインストールとインポート
- プログラム処理に必要なライブラリのインストールとインポート処理を準備
(Pythonを使った一般的なDeepLearingのプログラムと同様)
- プログラム処理に必要なライブラリのインストールとインポート処理を準備
-
定数、SEED値、Device等の設定
- プログラム処理に必要な定数やSEED値(randam.seed、np.random.seedやtorch関係のseed等)の設定
- GPU(cuda)のアサインメント処理
(Pythonを使った一般的なDeepLearingのプログラムと同様)
-
データセットの準備
- 訓練(train)データ、検証(val)データ、テスト(test)データ
- FederatedLearningで使うクライアント数分にランダム分割したデータセットを準備する
- FlowerVCEでは、Strategyを定義してevaluate_fn(コールバック用の評価関数)を使った中央集権的評価を利用することが出来るので、その場合はevaluate_fnの1つのテスト(test)データとクライアント数分に分割した訓練(train)データ、検証(val)データを準備する
-
モデルの定義
- FederatedLearningに使うModelClassを定義する(PytorchやPyGのCNN、GCN等)
(Pythonを使った一般的なDeepLearingのプログラムと同様)
- FederatedLearningに使うModelClassを定義する(PytorchやPyGのCNN、GCN等)
-
tarin関数、test関数、損失関数、最適化関数の実装
- FederatedLearningに使うtarin関数、test関数、損失関数、最適化関数を作成
(Pythonを使った一般的なDeepLearingのプログラムと同様)
- FederatedLearningに使うtarin関数、test関数、損失関数、最適化関数を作成
-
FlowerClientの定義
FlowerVCEを使って各クライアントが処理する訓練、検証用の以下の処理(メソッド)を持ったFlowerClientClassを定義する- _init_():クラスの初期化メソッド、訓練と検証用のデータローダーを受け取りモデルを初期化する
- set_parameters():サーバーから送信されたパラメータを受け取り、それらを使用してクライアントのローカルモデルのパラメータをモデルに設定する
- get_parameters():ローカルモデルからパラメータを抽出し、サーバに返す
- fit():サーバーから送信されたパラメータを使ってローカルデータセットでモデルを訓練し、その結果をサーバーに返す
- evaluate():サーバーから送信されたパラメータを使ってローカルの検証データセットでモデルを評価し、その結果をサーバーに返す
-
クライアントを構築するための関数(client_fn)の実装
- FederatedLearningのクライアントを作成するためのコールバック関数(client_fn)を作成
- FowerVCEではシミュレータからクライアントID(cid)が渡されるので、そのクライアント用のデータセットパーティションを読み込み、各クライアント用に分割されたtrainデーターローダーとvalデータローダーを引数にしてFlowerClientClassを作成する
-
グローバルモデルの評価関数(evaluate_fn)の実装
- 各ラウンドの終了時にstrategy戦略によって実行されるコールバック用の評価関数(evaluate_fn)を作成
- モデルパラメータをロードし、中央集権的評価を行うためのtestデータローダーを使って評価処理を行い、Loss値や精度(MSE、R2)などを計算する
-
FederatedLearningの戦略を定義
- strategyを定義してFlowerVCEで実行するFederatedLearningの戦略を実装する
- FederatedLearningの戦略(アルゴリズム)としてはFedAvgやFedProxなど様々なものをstrategyとして宣言することができる
- 訓練や評価に利用するクライアントの割合(fraction_fit、fraction_evaluate)などを引数で指定する
- 各ラウンドの終了時にstrategy戦略によって実行されるコールバック用の評価関数(evaluate_fn)もstrategyの引数として指定する
-
FederatedLearningのシミュレーションの開始
start_simulationメソッドを実装し、以下のような引数を渡してFlowerVCEのシミュレーションを実行することができる- client_fn:クライアントを構築するためのコールバック関数を指定
- num_clients:FederatedLearningに参加するクライアント数
- config:サーバーの設定を指定、fl.server.ServerConfigでnum_rounds(ラウンド数)などを指定
- strategy:FederatedLearningの戦略を指定
- client_resources:クライアントに割り当てるリソース(CPU、GPU)を指定
FlowerVCEを使ったサンプルプログラム
FlowerVCEを使ったPytorch版のサンプルプログラムについては、公式サイト内の以下のページで以前は紹介されていました。
しかし、現在(2024/11/06)では、サンプルプログラムへのリンク(「Simulation examples」の「PyTorch Simulation: 100 clients collaboratively train a CNN model on MNIST.」)部分からのGitHubページ参照が「404 - page not found」となっています。
以下に以前に掲載されていたFlowerVCEを使ったPyTorch Simulationの日本語訳解説付きのサンプルプログラムを置いてありますので参照ください。
本プログラムは、GoogleColab環境で動作確認済です。
※だだし、Flower公式サイトから参照できるオリジナルが削除されているので、Flowerのバージョンアップ等によって動作しなくなる可能性もあることをご了承ください。
参考
https://flower.ai/
https://flower.ai/docs/framework/how-to-run-simulations.html
変更履歴
2024/11/06 「FlowerVCEを使ったサンプルプログラム」の記載追加