1.はじめに
・連合学習の流れは下記の通りです。
①:中央からグローバルモデルを各ローカル端末に配布
②:各ローカル端末が持ってるデータを使ってトレーニングし、モデルを更新
③:各ローカル端末で更新されたモデルを中央に集約(アグリゲート)し、新しいグローバルモデルを作る
④:①〜③を繰り返す
・③のローカルからモデルを中央に集約しグローバルモデルを作る(アグリゲーション)際に、「FedAvg」(各モデルの平均を取得する)というアルゴリズムが使われています。
・しかし、このFedAvgは完璧ではなく、精度の悪いグローバルモデルを作ってしまうことがあります。
グローバルモデルの精度が悪くなる原因
・明らかに他のモデルとは違うタイプのモデル(外れ値みたいな)が紛れている
・それぞれのローカル側のデータでモデルを更新する際、それぞれのデータの偏りから適切な平均が出ない。
・これらを解決するために、以下のアルゴリズムが使われていたりします。
2. Krum
Krumはf人のアタッカー(モデルの集約を意図的に妨害するやつ)を許容できる集約ルール
・各ローカル端末から集められた、コスト関数の「勾配推定値」$(V_1,・・・,V_n)$があります。
・任意の推定値$V$と、それに最も近い「n-f-2個の推定値」のユーグリッド距離を合計し、$scores(i)$を取得します。
・$scores(i)$を最小にするの時の推定値 $V_{i_{*}}$がKrum関数より得られます。
\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}
\begin{align}
scores(i)&=\sum_{i=1}^{n-f-2}||V-V_i||^2 \\
&=\sum_{i=1}^{n-f-2}\sqrt{(V-V_i)^2}\\
\\
Krum&=V_{i_{*}}
\\
\end{align}
\\
\left\{
\begin{array}{ll}
\ V:任意の勾配推定値\\
V_i:ローカルモデルから集められた勾配推定値\\
V_{i_{*}}: Krum関数により得られた推定値\\
f:アタッカー(異常モデル)\\
n:クライアントの数\\
i_*: scores(i)を最小にするi \\
\end{array}
\right.
3. Geometric Median
集められたローカルモデルのパラメータの「中央値」を使ってグローバルモデルを更新する。
・FedAvgではモデルの平均を出してグローバルモデルを作っていましたが、geometric medianでは中央値を出してグローバルモデルを作ります。
・アタッカーによる攻撃や一般的な異常なモデル(外れ値)の影響を抑えます。
・中央値の計算は以下の通りです。
\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}
\begin{align}
Geometric Median&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}||x_i-y||_2 \\
&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}\sqrt{(x_i-y)^2}\\
\end{align}
\\
\left\{
\begin{array}{ll}
\ y:グローバルモデル\\
x_i:ローカルモデル\\
\end{array}
\right.
4. Coordinate-Wise Median
集められたローカルモデルのパラメータの「中央値」を使ってグローバルモデルを更新する。
・Geometric Medianと同じで、平均の代わりに中央値を出してグローバルモデルを作ります。
・アタッカーによる攻撃や一般的な異常なモデル(外れ値)の影響を抑えます。
・中央値の計算がGeometric Medianとは違い、2点間の距離をマンハッタン距離($||x_i-y||_1$)で求めています。
\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}
\begin{align}
CoordinateWise Median&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}||x_i-y||_1 \\
&=\argmin_{y\in\mathbb{R}^{n}}\sum_{i=1}^{m}(x_i-y)
\end{align}\\
\left\{
\begin{array}{ll}
\ y:グローバルモデル\\
x_i:ローカルモデル\\
\end{array}
\right.
5. FedProx
ローカルでのトレーニングの損失関数に近位項を追加して、ローカルモデルの変更を制限する。
・近位項とは、ローカルで得られた重みとグローバルモデルの重みの違いを表す項($||w-w_g||^2$ の部分)です。
・異なるローカル端末での計算処理速度の不均一性によるパフォーマンスの低下を防ぐため、ペナルティを課します。
・計算処理速度:ローカル端末では、マシンの性能の違いから計算処理の速度の違いが生じているため、アグリゲーションのパフォーマンスが低下してしまいます。
\tilde{L(w)}=L(w)+α||w-w_g||^2 \\
\left\{
\begin{array}{ll}
\tilde{L(w)}:FedProxによって制約された損失関数\\
L(w):ローカルの損失関数\\
w:ローカルで得られた重み\\
w_g:グローバルモデルの重み\\
α:近位項を制御するハイパーパラメータ\\
\end{array}
\right.
6. FedCurv
フィッシャー情報行列を使い、モデル内の各パラメータの重要性を評価し、重要度に応じてより大きなペナルティを課す。
・Non-iidデータによる異なるローカル端末でのデータの不均一性によるパフォーマンスの低下を阻止するため、制約をします。
・データの不均一性:各クライアントではデータに隔たりがある(あるクライアントでは犬の画像しか持っていなくて、もう一方では猫の画像しか持っていない)場合、集約されて作られたグローバルモデルのパフォーマンスが低い。FedProxを使うことで、そのパフォーマンスを維持したどっちにも対応した(犬も猫も予測できる)グローバルモデルが作られる。
\tilde{L(w)_s}=L(w)_s+α\sum_{j∈S-s}(w-w_j)^Tdiag(\hat{I_j})(w-\hat{w_j}) \\
\left\{
\begin{array}{ll}
\tilde{L(w)_s}:FedCurvによって制約された損失関数\\
L(w)_s:ローカルの損失関数\\
S:トレーニングに参加しているノード\\
s:現在のノード\\
\hat{w_j}:ノードjのトレーニングで得られた重み\\
\hat{I_j}:\hat{w_j}に対応するフィッシャー情報行列\\
w:ローカルで得られた重み\\
α:ペナルティ項を制御するハイパーパラメータ\\
\end{array}
\right.
7. Trimmed mean(トリム平均)
トリム平均を使って各モデルの平均値を出し、グローバルモデルを作成する
・トリム平均とは、データを昇順に並べ替え、上位と下位〇〇%を除外して、残ったデータで平均値を取る手法です。
・各ローカルから集められたモデルを並べ替えし、トリム平均を取り、グローバル平均を作成します。
参考