LoginSignup
0
1

More than 1 year has passed since last update.

PyTorch入門(part-1) - パーセプトロンの可視化

Posted at

この記事では.

今更Deepやり始めたのでメモ程度です。
基本的に他記事をなぞり、自分の理解のために可視化などの回り道をした部分のみをまとめる。

今回はMNISTで単純パーセプトロンで分類する。
そのときにどのように重みが更新されているのかを可視化している。

モデル

単純パーセプトロンなので特段記載する必要なし。

class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28*28, 10) # 画像サイズ:28*28, 分類:0~9
        
    def forward(self, input_data):
        input_data = input_data.view(-1, 28*28)
        input_data = self.layer(input_data)
        return input_data

可視化しているだけなので lr, epoch, batch などの条件は省略する。

学習工程の可視化

重みは以下で layer.weight, layer.bias を確認できる。

model.state_dict()

任意のタイミングで以下のコードを差し込めばOK。これで 0~9 の重みの更新が可視化される。
PyTorchはtensor型でデータを扱うがplot時はnumpyに変換する必要がある。(もしかしたらないかもしれんが・・)

# 重み tensor > numpy
weights = [weight.numpy().reshape(28,28) for weight in model.state_dict()["layer.weight"]]

# subplotで 0 ~ 9の重みを可視化
plt.figure(figsize=(30,6))

for i, weight in enumerate(weights):
    plt.subplot(1, len(weights)+1, i+1)
    plt.imshow(weight)

# これがないとprintなどのタイミングが狂う
plt.show()

最初はランダムで初期値設定しているのでノイズにしか見えない。
スクリーンショット 2022-08-11 13.02.00.png

最後は 0~9 が浮き上がっているように見える。
スクリーンショット 2022-08-11 13.03.35.png

予測の可視化

上記の重みと画像を掛け合わせる。

# テストデータ
test = test_data[0].numpy().reshape(28,28)

# subplotで 0 ~ 9の重みを可視化
plt.figure(figsize=(30,6))

for i, weight in enumerate(weights):
    plt.subplot(1, len(weights)+1, i+1)
    plt.imshow(test*weight, vmin=-0.3, vmax=0.3) # 画像*重み, 最大値・最小値は適当

元の画像が5で重みをそのまま掛け合わせた結果(0~9)
3,5,6,8あたりが似ている。5が一番色付けされている気がする。
スクリーンショット 2022-08-11 13.23.13.png

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1