LoginSignup
8
7

More than 1 year has passed since last update.

連合学習 (Federated Learning) とSTADLEについて

Last updated at Posted at 2022-05-28

はじめに

この記事は、最近巷で耳にするようになった連合学習(Federated Learning)を実際に手元のPCで行い、AI学習の先端技術に触れてみようと思います。
皆さんは連合学習というキーワードをご存じでしょうか。何かしらで耳にされてこの記事を読んでいただいていると思いますが、サラッとおさらいを述べて、本題(ハンズオン)に入りたいと思います。

連合学習(Federated Learning)について

AIモデルの構築において従来のようにデータを1か所に集約することなく、分散された複数の学習環境によって並列的に学習を行う手法です。
各学習環境にあるデータを用いて各々が学習を行い、それぞれで生成されるニューラルネットワークのノードを繋ぐウェイト値を集約して統合することで、それぞれの学習環境のデータ量が少量であっても全ての学習環境で用いられたデータを使用して学習を行う場合と同等のモデルを構築することができます。

image.png

これにより、並列的に学習を進められるので単一の学習環境で行う学習に比べ効率的に学習が行えるうえに、各学習環境間で データを共有することなく 統合モデルを構築出来るので、企業間をまたいだ連合学習に参加しても秘匿情報を他社に共有することなく、数社のデータを使用した業界共通のグローバルモデルを構築することが可能です。

連合学習の大まかな手順

[STEP01]ベースとなるモデルのアップロード

ベースモデルのアップロードとは、連合学習を行うモデルの構造(アーキテクチャ)をアグリゲーターに登録することを意味します

[STEP02]学習コードの実行

各学習環境で学習コードを実行することによって生成されるウェイトの値を集約します。

動作環境の前提

連合学習におけるモデルの集約(アグリゲーション)を行う部分はSTADLEプラットフォームを使用します。
stadle.png

したがって、この部分は特別なセットアップは不要で、下記のURLにアクセスしてサインアップすればOKです。

上記のサイトでモデルのアグリゲーターを起動して、手元のPCで実行するMLアプリケーションから接続をします。
Pythonで記述されたMLアプリケーションを実行しますので、Pythonの実行環境を用意してください。

今回はMNISTのサンプルモデル(pytorch版)を使用してデモを行います

コードの構成

mnist_example/
  ├── config/
  │   └── config_agent.json ← 学習コードの設定ファイル
  ├── models/
  │   └── samplenet.py ← モデルアーキテクチャ(ベースモデル定義)
  ├── mnist_admin_agent.py ← ベースモデルのアップロード
  └── mnist_ml_agent.py ← 学習コード

ローカル環境(MLエージェント)のセットアップ

[STEP01]仮想環境の準備及びアクティベート

// conda環境構築
conda create -n ENVCLIENT python=3.8

// 作成したconda環境へ入る
conda activate ENVCLIENT

[STEP02]必要モジュールインストール

pip install --upgrade pip

// stadle-clientダウンロード
pip install stadle-client

// ML用モジュールダウンロード
pip install torch torchvision

[STEP03] stadle.aiでプロジェクト作成

ここまでで、学習環境に必要なコードが揃いましたので、
次にstadle.aiでプロジェクトを作成します。

プロジェクトの作成

image.png

image.png

アグリゲータの起動

アグリゲータのプラスマークを押してアグリゲータの起動ボタンをしてください。リフレッシュするとAggregator countが1に変わります

image.png

DashBoard画面でIPアドレスとポート番号を取得、configの情報を書き換える!

image.png

image.png

mnist_example/config/config_agent.json
{
  "model_path": "./data/agent",
  "local_model_file_name": "lms.binaryfile",
  "semi_global_model_file_name": "sgms.binaryfile",
  "state_file_name": "state",
  "aggr_ip": "<ここにIPアドレスを書き込む>",
  "reg_port": "<ここにPORT番号を書き込む>",
  "init_weights_flag": 1,
  "token": "stadle12345",
  "simulation": "False",
  "exch_socket": "0000",
  "agent_name": "default_agent"
}

[STEP04]モデルアーキテクチャアップロード

python mnist_admin_agent.py --agent_name <ここに任意の名前を入れる>

stadle.ai のダッシュボードを更新すると、以下のようにアップロードしたモデルアーキテクチャ名が登録され、Agents Connected が1になります。

image.png

[STEP05]学習プロセスの実行

次に複数のターミナルで学習プロセスを実行します。
この例では、手元のPCだけで3つの学習環境を実行しますが、もちろん複数の異なるPCを使用した学習環境でも可能です。

【接続エージェント識別名指定引数】
--agent_name エージェント名
【使用データ選択引数】
--classes 指定クラス(カンマ区切り)
【選択データ使用比率引数】
--sel_prob 比率 0.0~1.0
【非選択データ使用比率引数】
--def_prob 比率 0.0~1.0

先ほどのモデルアーキテクチャをアップロードしたターミナルからは以下のコマンドで agent01 という名前でプロセスを実行します。

python mnist_ml_agent.py --agent_name agent01 --classes 1,2,3 --sel_prob 1.0 --def_prob 0.05

他のターミナルを開いて、同様にconda環境 "ENVCLIENT" に入り、mnist_example/デレクトリに移動
以下のコマンドで agent02 という名前でプロセスを実行します。

python mnist_ml_agent.py --agent_name agent02 --classes 4,5,6 --sel_prob 1.0 --def_prob 0.05

同様にして3つめのターミナルを用意して以下のコマンドで agent03 という名前でプロセスを実行

python mnist_ml_agent.py --agent_name agent03 --classes 7,8,9,0 --sel_prob 1.0 --def_prob 0.05

上記コマンドの引数は、--classes が学習に使用するラベルを指定します。
上記の例では、agent01 が手書き文字の1/2/3を選択し、agent02 が4/5/6を選択し、agent03 が7/8/9/0を選択しています。
--sel_prob は選択した数字を学習で用いる比率で、1.0は選択した数字のデータを全て使用することを意味します。
--def_prob は選択されていない他の数字を学習に用いる比率で、0.05は5%を意味します。
すなわち、agent01 は1/2/3のラベルがついたデータを全て使用し、4以降の他の数字データについては5%しか使用せずに学習を実行しています。

ここまでで、conda環境ENVCLIENT内で3つのMLプロセスが実行されます。

image.png

この状態でダッシュボードの表示は以下のようになり、Agents Connectedが4になっているはずです。

image.png

[STEP06]学習過程のモニタリング

左のサイドバーから“Performance Tracking`を選択するとモデルの各評価指標について学習過程のモニタリングが可能です。

image.png

今回用意したMLアプリケーションのコードでは、各学習環境で2epochの学習が進むたびに stadle.ai に各学習環境のローカルモデルが集約されます。
それらを元に stadle.ai でセミグローバルモデルが生成されて各学習環境へ戻され、そのセミグローバルモデルをベースに、また各学習環境のデータを用いた学習が行われるというサイクルになります。

[STEP07]単一エージェントでの学習結果と比較

上記の連合学習の結果とシングルエージェントによる学習結果を比較してみましょう。
シングルエージェントは、先述の agent01 と同じ条件で、1/2/3のデータをフルで使用して4以降のデータに関しては5%程度とした学習を行いました。
シングルエージェントで且つ極端に偏りのあるデータでは十分に学習を進める事ができませんが、複数の学習環境にあるデータで補い合う事で正常に学習を進める事が出来ています。

image.png

以上stadle.aiと連合学習を用いたデモでした。

TieSet inc.ではこのようなユースケースを作成していただける方を募集して抱いております。

8
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
7