このチュートリアルでは、プライバシーを保護しながらニューラルネットワークを学習できる新しい分散学習アルゴリズム、連合学習(Federated Learning)について学びます。
本稿は、筆者が英語で執筆した記事の翻訳になります。
深層学習は様々な分野で大きな成功を収めていますが、深層学習モデルの学習には大量のデータが必要です。そのため、プライバシーを保護しながら深層学習で高いパフォーマンスを得ることは課題となっています。この問題を解決する一つの方法が連合学習です。連合学習では、複数のクライアントがローカルデータセットを共有せずに、協力して単一のグローバルモデルを学習します。
典型的な連合学習の手順は以下の通りです:
1. 中央サーバーがグローバルモデルを初期化する。
2. サーバーが各クライアントにグローバルモデルを配布する。
3. 各クライアントが自身のデータセットで損失関数の勾配をローカルに計算する。
4. 各クライアントが勾配をサーバーに送信する。
5. サーバーが受け取った勾配を何らかの方法(例:平均)で集約し、集約された勾配でグローバルモデルを更新する。
6. 収束するまで2〜5を繰り返す。
集約の方法が重み付き平均の場合、数学的表記は以下の通りです:
$$
w_{t} \leftarrow w_{t - 1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} \nabla \mathcal{l}(w_{t - 1}, X_{c}, Y_{c})
$$
ここで、$w_{t}$はt回目のラウンドにおけるグローバルモデルのパラメータ、$\nabla \mathcal{l}(w_{t - 1}, X_{c}, Y_{c})$はc番目のクライアントのデータセット$(X_{c}, Y_{c})$で計算された勾配、$n_{c}$はc番目のクライアントのデータセットのサンプル数、Nは全サンプル数です。
コード
次に、連合学習の代表的な手法の一つであるFedAVG [1]を実装します。機械学習アルゴリズムのセキュリティとプライバシーリスクを実験するためのOSSであるAIJackを使用します。AIJackはシングルプロセスとMPIの両方をバックエンドとしてサポートしています。
まず、pip
でAIJackをインストールします。
apt install -y libboost-all-dev
pip install -U pip
pip install "pybind11[global]"
pip install aijack
シングルプロセス
以下のモジュールをインポートします。
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms
from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI
今回使用するハイパーパラメータは以下の通りです。
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
client_size = 2
criterion = F.nll_loss
このチュートリアルではMNISTデータセットを使用します。
def prepare_dataloader(num_clients, myid, train=True, path=""):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
if train:
dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
idxs = list(range(len(dataset.data)))
random.shuffle(idxs)
idx = np.array_split(idxs, num_clients, 0)[myid - 1]
dataset.data = dataset.data[idx]
dataset.targets = dataset.targets[idx]
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=training_batch_size
)
return train_loader
else:
dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
return test_loader
AIJackを使用すると、PyTorchモデルを使って連合学習のクライアントとサーバーを簡単に実装できます。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.ln = nn.Linear(28 * 28, 10)
def forward(self, x):
x = self.ln(x.reshape(-1, 28 * 28))
output = F.log_softmax(x, dim=1)
return output
clients = [FedAVGClient(Net().to(device), user_id=c) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]
server = FedAVGServer(clients, Net().to(device))
次にFedAVGAPI
のrun
メソッドを使って学習を実行できます。
api = FedAVGAPI(
server,
clients,
criterion,
local_optimizers,
local_dataloaders,
num_communication=num_rounds,
custom_action=evaluate_gloal_model(test_dataloader),
)
api.run()
実行結果
communication 0, epoch 0: client-1 0.019623182541131972
communication 0, epoch 0: client-2 0.019723439224561056
Test set: Average loss: 0.7824367880821228, Accuracy: 83.71
communication 1, epoch 0: client-1 0.01071754728158315
communication 1, epoch 0: client-2 0.010851142065723737
Test set: Average loss: 0.58545467877388, Accuracy: 86.49
communication 2, epoch 0: client-1 0.008766427374879518
communication 2, epoch 0: client-2 0.00891655088464419
Test set: Average loss: 0.507768925857544, Accuracy: 87.54
communication 3, epoch 0: client-1 0.007839484961827596
communication 3, epoch 0: client-2 0.00799967499623696
Test set: Average loss: 0.46477557654380797, Accuracy: 88.25
communication 4, epoch 0: client-1 0.0072782577464977904
communication 4, epoch 0: client-2 0.007445397683481375
Test set: Average loss: 0.436919868183136, Accuracy: 88.63
MPI
上記のコードは、並列プログラミング環境で実行できるMPI互換のコードに簡単に変換できます。
# mpi_FedAVG.py
import random
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms
from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClientManager, MPIFedAVGServerManager
logger = getLogger(__name__)
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
def fix_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# 省略(prepare_dataloader、Net、evaluate_gloal_modelの定義)
def main():
fix_seed(seed)
comm = MPI.COMM_WORLD
myid = comm.Get_rank()
size = comm.Get_size()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=lr)
mpi_client_manager = MPIFedAVGClientManager()
mpi_server_manager = MPIFedAVGServerManager()
MPIFedAVGClient = mpi_client_manager.attach(FedAVGClient)
MPIFedAVGServer = mpi_server_manager.attach(FedAVGServer)
if myid == 0:
dataloader = prepare_dataloader(size - 1, myid, train=False)
client_ids = list(range(1, size))
server = MPIFedAVGServer(comm, [1, 2], model)
api = MPIFedAVGAPI(
comm,
server,
True,
F.nll_loss,
None,
None,
num_rounds,
1,
custom_action=evaluate_gloal_model(dataloader),
device=device
)
else:
dataloader = prepare_dataloader(size - 1, myid, train=True)
client = MPIFedAVGClient(comm, model, user_id=myid)
api = MPIFedAVGAPI(
comm,
client,
False,
F.nll_loss,
optimizer,
dataloader,
num_rounds,
1,
device=device
)
api.run()
if __name__ == "__main__":
main()
このコードは、標準的なMPIコマンドで実行できます。
!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py
communication 0, epoch 0: client-3 0.019996537216504413
communication 0, epoch 0: client-2 0.02008056694070498
Round: 1, Test set: Average loss: 0.7860309104919434, Accuracy: 82.72
communication 1, epoch 0: client-3 0.010822976715366046
communication 1, epoch 0: client-2 0.010937693453828494
Round: 2, Test set: Average loss: 0.5885528886795044, Accuracy: 86.04
communication 2, epoch 0: client-2 0.008990796900788942
communication 2, epoch 0: client-3 0.008850129560629527
Round: 3, Test set: Average loss: 0.5102099328994751, Accuracy: 87.33
communication 3, epoch 0: client-3 0.00791173183619976
communication 3, epoch 0: client-2 0.008069112183650334
Round: 4, Test set: Average loss: 0.4666414333820343, Accuracy: 88.01
communication 4, epoch 0: client-2 0.007512268128991127
communication 4, epoch 0: client-3 0.007343090359369914
Round: 5, Test set: Average loss: 0.4383064950466156, Accuracy: 88.65
まとめ
このチュートリアルでは、プライバシーを侵害せずに深層学習モデルを安全に学習するための代表的なアプローチである連合学習について学びました。AIJackのドキュメントでは、より多くの例やノートブックを見つけることができます。
しかし実際にはこのアルゴリズムは安全とは限りません。各クライアントがローカルデータセットを共有する必要がないため、このスキームはセキュアに見えるかもしれませんが、次のチュートリアルでは、共有されたローカル勾配が私的情報を漏洩する可能性があることを示します。
参考文献
[1] McMahan, Brendan, et al. "Communication-efficient learning of deep networks from decentralized data." Artificial intelligence and statistics. PMLR, 2017.
[2] Takahashi, Hideaki. "AIJack: Security and Privacy Risk Simulator for Machine Learning." arXiv preprint arXiv:2312.17667 (2023).