前回の記事で開設したように、連合学習はクライアントがプライベートなデータセットを隠すことを可能にしますが、多くの論文[1, 2, 3]では、悪意のあるサーバーが、クライアントからアップロードされた勾配 $\nabla \mathcal{l}(w_{t - 1}, X, Y)$ を用いて、クライアントの手元にあるプライベートな情報である訓練サンプルを復元できることを示しています。
具体的には、サーバーはすでにグローバルモデルのパラメータ $w_{t - 1}$ を知っているため、以下の最適化によってプライベートな訓練サンプル $(X, Y)$ を推定することができます。
$$
X' \leftarrow X' - \lambda \nabla_{X'} D
$$
$$
Y' \leftarrow Y' - \lambda \nabla_{Y'} D
$$
ここで、$\lambda$は学習率、$D$ は以下のように計算される損失関数です:
$$
D = || \nabla \mathcal{l}(w_{t - 1}, X, Y) - \nabla \mathcal{l}(w_{t - 1}, X', Y') ||_{2}
$$
言い換えれば、この攻撃は、クライアントから受け取った勾配に十分近い勾配を生成するように偽のデータを最適化することで、プライベートな訓練データを再構築しようとします。
本稿は、筆者が英語で執筆した記事の翻訳になります。
コード
このような攻撃は近年活発に研究されており、様々な距離メトリクス、正則化項、最適化手法が提案されています。前回の記事で使ったAIJackは、その中でもよく使われる損失関数や最適化手法を網羅して言います。
まず、必要なライブラリをインポートします。
import cv2
import copy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from numpy import e
from matplotlib import pyplot as plt
import torch.optim as optim
from tqdm.notebook import tqdm
from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer
from aijack.attack.inversion import GradientInversionAttackServerManager
from torch.utils.data import DataLoader, TensorDataset
from aijack.utils import NumpyDataset
子のチュートリアルではモデルとしてシンプルなLeNet、データセットにはMNISTを使用します。
class LeNet(nn.Module):
def __init__(self, channel=3, hideen=768, num_classes=10):
super(LeNet, self).__init__()
act = nn.Sigmoid
self.body = nn.Sequential(
nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
nn.BatchNorm2d(12),
act(),
nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
nn.BatchNorm2d(12),
act(),
nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
nn.BatchNorm2d(12),
act(),
)
self.fc = nn.Sequential(nn.Linear(hideen, num_classes))
def forward(self, x):
out = self.body(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def prepare_dataloader(path="MNIST/.", batch_size=64, shuffle=True):
at_t_dataset_train = torchvision.datasets.MNIST(
root=path, train=True, download=True
)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
dataset = NumpyDataset(
at_t_dataset_train.train_data.numpy(),
at_t_dataset_train.train_labels.numpy(),
transform=transform,
)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
)
return dataloader
ハイパーパラメータは以下の通りです:
torch.manual_seed(7777)
shape_img = (28, 28)
num_classes = 10
channel = 1
hidden = 588
criterion = nn.CrossEntropyLoss()
num_seeds = 5
今回の攻撃の目標は、以下のデータを復元することです。
device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
dataloader = prepare_dataloader()
for data in dataloader:
xs, ys = data[0], data[1]
break
x = xs[:1]
y = ys[:1]
fig = plt.figure(figsize=(1, 1))
plt.axis("off")
plt.imshow(x.detach().numpy()[0][0], cmap="gray")
plt.show()
前回と同様に、AIJackを使用して連合学習を簡単に実装できます。大きな違いは、FedAVGServer
クラスをGradientInversionAttackServerManager
でラップすることで、悪意のあるサーバーによる勾配ベースのモデル反転攻撃を実行できることです。このマネージャークラスは、サーバーが各通信でアップロードされた勾配から、自動でプライベートデータを推定します。異なるランダムシードで5回攻撃を行います。
manager = GradientInversionAttackServerManager(
(1, 28, 28),
num_trial_per_communication=num_seeds,
log_interval=0,
num_iteration=100,
distancename="l2",
device=device,
gradinvattack_kwargs={"lr": 1.0},
)
DLGFedAVGServer = manager.attach(FedAVGServer)
client = FedAVGClient(
LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
lr=1.0,
device=device,
)
server = DLGFedAVGServer(
[client],
LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
lr=1.0,
device=device,
)
local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]
api = FedAVGAPI(
server,
[client],
criterion,
local_optimizers,
local_dataloaders,
num_communication=1,
local_epoch=1,
use_gradients=True,
device=device,
)
api.run()
結果を確認すると、すべてのランダムシードで元のプライベート画像を正常に復元できていることが分かります。
fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
ax.axis("off")
plt.tight_layout()
plt.show()
まとめ
このチュートリアルでは、サーバーが受け取った勾配からプライベートな訓練データを盗むことができるため、連合学習が安全ではないことを学びました。連合学習に対するモデル反転攻撃の例は、AIJackのドキュメントでより詳細に確認することができます。
この攻撃の回避方法の一つとして、次のチュートリアルでは、各クライアントがアップロードする前に手元の勾配を暗号化する連合学習を紹介します。
参考文献
[1] Zhu, Ligeng, Zhijian Liu, and Song Han. "Deep leakage from gradients." Advances in neural information processing systems 32 (2019).
[2] Zhao, Bo, Konda Reddy Mopuri, and Hakan Bilen. "idlg: Improved deep leakage from gradients." arXiv preprint arXiv:2001.02610 (2020).
Yin, Hongxu, et al. "See through gradients: Image batch recovery via gradinversion." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
[3] Takahashi, Hideaki. "AIJack: Security and Privacy Risk Simulator for Machine Learning." arXiv preprint arXiv:2312.17667 (2023).