この記事では.
今更Deepやり始めたのでメモ程度です。
基本的に他記事をなぞり、自分の理解のために可視化などの回り道をした部分のみをまとめる。
今回はMNISTで単純パーセプトロンで分類する。
そのときにどのように重みが更新されているのかを可視化している。
モデル
単純パーセプトロンなので特段記載する必要なし。
class NN(nn.Module):
def __init__(self):
super().__init__()
self.layer = nn.Linear(28*28, 10) # 画像サイズ:28*28, 分類:0~9
def forward(self, input_data):
input_data = input_data.view(-1, 28*28)
input_data = self.layer(input_data)
return input_data
可視化しているだけなので lr, epoch, batch などの条件は省略する。
学習工程の可視化
重みは以下で layer.weight, layer.bias を確認できる。
model.state_dict()
任意のタイミングで以下のコードを差し込めばOK。これで 0~9 の重みの更新が可視化される。
PyTorchはtensor型でデータを扱うがplot時はnumpyに変換する必要がある。(もしかしたらないかもしれんが・・)
# 重み tensor > numpy
weights = [weight.numpy().reshape(28,28) for weight in model.state_dict()["layer.weight"]]
# subplotで 0 ~ 9の重みを可視化
plt.figure(figsize=(30,6))
for i, weight in enumerate(weights):
plt.subplot(1, len(weights)+1, i+1)
plt.imshow(weight)
# これがないとprintなどのタイミングが狂う
plt.show()
最初はランダムで初期値設定しているのでノイズにしか見えない。
予測の可視化
上記の重みと画像を掛け合わせる。
# テストデータ
test = test_data[0].numpy().reshape(28,28)
# subplotで 0 ~ 9の重みを可視化
plt.figure(figsize=(30,6))
for i, weight in enumerate(weights):
plt.subplot(1, len(weights)+1, i+1)
plt.imshow(test*weight, vmin=-0.3, vmax=0.3) # 画像*重み, 最大値・最小値は適当
元の画像が5で重みをそのまま掛け合わせた結果(0~9)
3,5,6,8あたりが似ている。5が一番色付けされている気がする。