はじめに
敵対的サンプル攻撃の一種であるFGSMをCIFAR10データセットを用いて実装しました。PyTorchの公式チュートリアルではMNISTを用いて実装していますが、今回はCIFAR10を用いて実装しました。実装にあたりチュートリアルから変更した部分など解説します。
公式チュートリアルのリンクはこちら。本記事と合わせてお読みください。
実装
チュートリアルと重複する部分や本質的ではないコードは省略します。
データセット、データローダの設定
mu=0.5
sigma=0.5
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mu, mu, mu), (sigma, sigma, sigma)),
])
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=100,
shuffle=True,
num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=1,
shuffle=False,
num_workers=2
)
平均、標準偏差が0.5になるように正規化しています。このとき、画像のピクセル値の変域は[-1,1]になります。
モデルの設定
from torchvision import models
def Net():
model_ft=models.resnet50(pretrained=True)
model_ft.fc=nn.Linear(model_ft.fc.in_features, 10)
return model_ft
net=Net()
チュートリアルではLeNetを使用していますが、正答率が60%ぐらいまでしか上がらなかったのでresnet50を用いました。訓練済みモデルを使用することで効率よく収束することを狙っています。
余談ですが、ftは「fine tuning」の略だそうです。fcは出力層(最終層)をあらわしており、inputサイズは変えずにoutputサイズを10に変更しています。もともとtorchvisionが提供しているresnet50モデルはImageNetデータセットを用いた1000種類の画像分類タスクに対応しており、今回用いるCIFAR10は10種類の画像分類タスクのため、出力層のoutputサイズを10に変更しています。
訓練結果
# 出力(losses:訓練データの損失関数値, corrects:テストデータの正答率)
losses: [0.8784692868590355, 0.45671922412514687, 0.30172924281656743, 0.21383790130168198, 0.15750555766746402]
corrects: [0.7889, 0.8189, 0.831, 0.8221, 0.8403]
5エポック訓練しました。正答率84%。
FGSM攻撃
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
# Collect the element-wise sign of the data gradient
sign_data_grad = data_grad.sign()
# Create the perturbed image by adjusting each pixel of the input image
perturbed_image = image + epsilon*sign_data_grad
# Adding clipping to maintain [-1,1] range
perturbed_image = torch.clamp(perturbed_image, -1, 1)
# Return the perturbed image
return perturbed_image
クリッピング範囲を[-1,1]に変更しました。
攻撃結果
テストデータに対してFGSM攻撃を行ったときの、$\epsilon=0,0.05,0.10$に関するテストデータの正答率の推移を示します。
$\epsilon=0\rightarrow0.05$で急激に下がっています。MNISTでは見られなかった結果です。
可視化
# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(16,3*len(epsilons)))
#plt.figure(facecolor="azure", edgecolor="coral", linewidth=2)
for i in range(len(epsilons)):
for j in range(5):
cnt += 1
plt.subplot(len(epsilons),5,cnt)
plt.xticks([], [])
plt.yticks([], [])
if j == 0:
plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
orig,adv,ex = examples[i][j]
plt.title("{} -> {}".format(orig, adv))
ex=mu+sigma*ex
plt.imshow(ex.transpose(1,2,0))
plt.tight_layout()
plt.show()
ex=mu+sigma*ex 部分で画像のピクセル値の範囲を[-1,1]から[0,1]に直しています。最初、私はこの操作を忘れており、可視化時に強制的に[0,1]にクリッピングされてしまいました。
可視化結果は以下の通り。
MNISTよりはちょっと荒っぽさが目立っていますね。
終わりに
今回の取り組みでPyTorchチュートリアルのFGSMソースコードに対する理解がかなり深まった。これを機に他のデータセットや攻撃手法の実装も試したい。