LoginSignup
2
0

More than 1 year has passed since last update.

連合学習(Federated Learning)に使われているメソッドについて

Last updated at Posted at 2022-07-25

1.はじめに

・連合学習の流れは下記の通りです。

①:中央からグローバルモデルを各ローカル端末に配布
②:各ローカル端末が持ってるデータを使ってトレーニングし、モデルを更新
③:各ローカル端末で更新されたモデルを中央に集約(アグリゲート)し、新しいグローバルモデルを作る
④:①〜③を繰り返す

・③のローカルからモデルを中央に集約しグローバルモデルを作る(アグリゲーション)際に、「FedAvg」(各モデルの平均を取得する)というアルゴリズムが使われています。
・しかし、このFedAvgは完璧ではなく、精度の悪いグローバルモデルを作ってしまうことがあります。

グローバルモデルの精度が悪くなる原因
・明らかに他のモデルとは違うタイプのモデル(外れ値みたいな)が紛れている
・それぞれのローカル側のデータでモデルを更新する際、それぞれのデータの偏りから適切な平均が出ない。

・これらを解決するために、以下のアルゴリズムが使われていたりします。

2. Krum

Krumはf人のアタッカー(モデルの集約を意図的に妨害するやつ)を許容できる集約ルール

・各ローカル端末から集められた、コスト関数の「勾配推定値」$(V_1,・・・,V_n)$があります。
・任意の推定値$V$と、それに最も近い「n-f-2個の推定値」のユーグリッド距離を合計し、$scores(i)$を取得します。
・$scores(i)$を最小にするの時の推定値 $V_{i_{*}}$がKrum関数より得られます。

\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}
\begin{align}
scores(i)&=\sum_{i=1}^{n-f-2}||V-V_i||^2 \\ 
&=\sum_{i=1}^{n-f-2}\sqrt{(V-V_i)^2}\\
 \\
Krum&=V_{i_{*}} 
\\
\end{align}
\\
\left\{
\begin{array}{ll}
\ V:任意の勾配推定値\\
V_i:ローカルモデルから集められた勾配推定値\\
V_{i_{*}}: Krum関数により得られた推定値\\
f:アタッカー(異常モデル)\\
n:クライアントの数\\ 
i_*: scores(i)を最小にするi \\

\end{array}
\right.

3. Geometric Median

集められたローカルモデルのパラメータの「中央値」を使ってグローバルモデルを更新する。

・FedAvgではモデルの平均を出してグローバルモデルを作っていましたが、geometric medianでは中央値を出してグローバルモデルを作ります。
・アタッカーによる攻撃や一般的な異常なモデル(外れ値)の影響を抑えます。
・中央値の計算は以下の通りです。

\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}
\begin{align}
Geometric Median&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}||x_i-y||_2 \\
&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}\sqrt{(x_i-y)^2}\\
\end{align}
\\
\left\{
\begin{array}{ll}
\ y:グローバルモデル\\
x_i:ローカルモデル\\
\end{array}
\right.

4. Coordinate-Wise Median

集められたローカルモデルのパラメータの「中央値」を使ってグローバルモデルを更新する。

・Geometric Medianと同じで、平均の代わりに中央値を出してグローバルモデルを作ります。
・アタッカーによる攻撃や一般的な異常なモデル(外れ値)の影響を抑えます。
・中央値の計算がGeometric Medianとは違い、2点間の距離をマンハッタン距離($||x_i-y||_1$)で求めています。

\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}
\begin{align}
CoordinateWise Median&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}||x_i-y||_1 \\
&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}(x_i-y)
\end{align}\\

\left\{
\begin{array}{ll}
\ y:グローバルモデル\\
x_i:ローカルモデル\\
\end{array}
\right.

5. FedProx

ローカルでのトレーニングの損失関数に近位項を追加して、ローカルモデルの変更を制限する。

近位項とは、ローカルで得られた重みとグローバルモデルの重みの違いを表す項($||w-w_g||^2$ の部分)です。
・異なるローカル端末での計算処理速度の不均一性によるパフォーマンスの低下を防ぐため、ペナルティを課します。
計算処理速度:ローカル端末では、マシンの性能の違いから計算処理の速度の違いが生じているため、アグリゲーションのパフォーマンスが低下してしまいます。

\tilde{L(w)}=L(w)+α||w-w_g||^2 \\

\left\{
\begin{array}{ll}
\tilde{L(w)}:FedProxによって制約された損失関数\\
L(w):ローカルの損失関数\\
w:ローカルで得られた重み\\
w_g:グローバルモデルの重み\\
α:近位項を制御するハイパーパラメータ\\
\end{array}
\right.

6. FedCurv

フィッシャー情報行列を使い、モデル内の各パラメータの重要性を評価し、重要度に応じてより大きなペナルティを課す。

・Non-iidデータによる異なるローカル端末でのデータの不均一性によるパフォーマンスの低下を阻止するため、制約をします。
データの不均一性:各クライアントではデータに隔たりがある(あるクライアントでは犬の画像しか持っていなくて、もう一方では猫の画像しか持っていない)場合、集約されて作られたグローバルモデルのパフォーマンスが低い。FedProxを使うことで、そのパフォーマンスを維持したどっちにも対応した(犬も猫も予測できる)グローバルモデルが作られる。

\tilde{L(w)_s}=L(w)_s+α\sum_{j∈S-s}(w-w_j)^Tdiag(\hat{I_j})(w-\hat{w_j}) \\

\left\{
\begin{array}{ll}
\tilde{L(w)_s}:FedCurvによって制約された損失関数\\
L(w)_s:ローカルの損失関数\\
S:トレーニングに参加しているノード\\
s:現在のノード\\
\hat{w_j}:ノードjのトレーニングで得られた重み\\
\hat{I_j}:\hat{w_j}に対応するフィッシャー情報行列\\
w:ローカルで得られた重み\\
α:ペナルティ項を制御するハイパーパラメータ\\
\end{array}
\right.

7. Trimmed mean(トリム平均)

トリム平均を使って各モデルの平均値を出し、グローバルモデルを作成する

トリム平均とは、データを昇順に並べ替え、上位と下位〇〇%を除外して、残ったデータで平均値を取る手法です。
・各ローカルから集められたモデルを並べ替えし、トリム平均を取り、グローバル平均を作成します。

参考

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0