1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch参考書 Self-Attention GAN のサンプルコードの追加フラグメント

Last updated at Posted at 2020-06-10

「つくりながら学ぶ! PyTorchによる発展ディープラーニング」の「第5章 GANによる画像生成(DCGAN、Self-Attention GAN)」の「Attention Map」の可視化について、追加フラグメントを作成しました。

参照書籍

つくりながら学ぶ! PyTorchによる発展ディープラーニング
小川雄太郎

https://book.mynavi.jp/ec/products/detail/id=104855
https://github.com/YutaroOgawa/pytorch_advanced
https://qiita.com/sugulu/items/07253d12b1fc72e16aba

Attention Map可視化フラグメント

「5-4_SAGAN.ipynb」の最後にセルを追加して、以下のコードを貼り付けてください。


# Attentiom Mapを出力

print('1段目 生成した画像データ')
print('2段目 Attentin Map1 中央のピクセル→全ピクセル')
print('3段目 Attentin Map1 全ピクセル→中央のピクセル ⇒ 2段目と同じ結果が得られる')
print('4段目 Attentin Map1 右下のピクセル→全ピクセル ⇒ 7と8で差が出やすいピクセル')
print('5段目 Attentin Map1 左上端のピクセル→全ピクセル')
print('6段目 Attentin Map1 自ピクセル→自ピクセル')

row_num=6

# print('fake_images : ' + str(fake_images.size()))
# print('am1 : ' + str(am1.size()))

fig = plt.figure(figsize=(3*5, 3*row_num))
for i in range(0, 5):

  fake_image = fake_images[i][0]
  am = am1[i].view(16, 16, 16, 16)

  # print('fake_image : ' + str(fake_image.size()))
  # print('am : ' + str(am.size()))

  for j in range(0, row_num):
    plt.subplot(row_num, 5, 5*j+i+1)

    if j == 0:

      # 1段目 生成した画像データ
      plt.imshow(fake_image.cpu().detach().numpy(), 'gray')

    elif (j > 0) and (j < 5):

      if j == 1:
        # 2段目 Attentin Map1 中央のピクセル→全ピクセル
        am_tmp = am[7][7]
      elif j == 2:
        # 3段目 Attentin Map1 全ピクセル→中央のピクセル
        am_tmp = am[:][:][7][7]
      elif j == 3:
        # 4段目 Attentin Map1 右下のピクセル→全ピクセル
        am_tmp = am[11][11]
      elif j == 4:
        # 5段目 Attentin Map1 左上端のピクセル→全ピクセル
        am_tmp = am[0][0]

      am_tmp = am_tmp.cpu().detach().numpy()
      # print('i : ' + str(i) + ', j : ' + str(j) + ', max : ' + str(np.max(am_tmp)) + ', min : ' + str(np.min(am_tmp)))
      plt.imshow(am_tmp, 'Reds', vmin=0, vmax=0.05)

    elif j == 5:
      am_tmp = np.ones((16,16), dtype='float')

      # 6段目 Attentin Map1 自ピクセル→自ピクセル
      for k in range(16):
        for l in range(16):
          am_tmp[k][l] = am[k][l][k][l]

      # print('i : ' + str(i) + ', j : ' + str(j) + ', max : ' + str(np.max(am_tmp)) + ', min : ' + str(np.min(am_tmp)))
      plt.imshow(am_tmp, 'Reds', vmin=0, vmax=0.05)

    else:
      raise ValueError

出力結果

image.png
1~3段目
image.png
4~6段目
image.png

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?