LoginSignup
1
1

More than 1 year has passed since last update.

最大値プーリングをPythonで実装

Posted at

最大値のプーリングをPythonで実装する方法
メモ書き
参考になれば幸いです。

#ライブラリー呼び出し
import numpy as np

#下記データで最大値のプーリングをPythonで実装するには
#行列(データ)作成
array = np.arange(16).reshape(4, 4)
print("arrayは:", array)
#カーネル(フィルター)サイズ
kernel_size = 2
#ストライド(移動)サイズ
stride = 1

#最大値でプーリングを実施する関数
def max_pooling(array, kernel_size, stride):
  a_h,a_w = array.shape
  print("a_h(行列横方向サイズ) : ", a_h)
  print("a_w(行列縦方向サイズ) : ", a_w)
  k_h = kernel_size
  k_w = kernel_size
  print("k_h(カーネルサイズ) : ", k_h)
  row = int((a_w - k_w)/stride +1)
  print("row(OUTPUT行列の横方向サイズ) : ", row)
  column = int((a_h-k_h)/stride + 1)
  print("column(OUTPUT行列の縦方向サイズ) : ", column)
  output = np.zeros((column,row))
  print("output行列(0行列)の作成:", output)
  for i in range(column):
    for j in range(row):
      temp = np.zeros((k_h,k_w))
      for m in range(k_h):
        for n in range(k_w):
          temp[m,n] = array[i*stride + m, j*stride + n]
      output[i,j] = temp.max()
      print("tempMAX(プーリングの最大値) : ", temp.max())
      print("temp.max()の値をi行目に入れる。 i : ", i)
      print("temp.max()の値をj列目に入れる。 j : ", j)
  print("プーリング後の最終出力は")
  return output
  

#実際に出力
max_pooling (array, kernel_size, stride)

#以下出力結果
a_h(行列横方向サイズ) :  4
a_w行列縦方向サイズ :  4
k_hカーネルサイズ :  2
row(OUTPUT行列の横方向サイズ) :  3
columnOUTPUT行列の縦方向サイズ :  3
output行列0行列の作成 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
tempMAX(プーリングの最大値) :  5.0
temp.max()の値をi行目に入れる i :  0
temp.max()の値をj列目に入れる j :  0
tempMAX(プーリングの最大値) :  6.0
temp.max()の値をi行目に入れる i :  0
temp.max()の値をj列目に入れる j :  1
tempMAX(プーリングの最大値) :  7.0
temp.max()の値をi行目に入れる i :  0
temp.max()の値をj列目に入れる j :  2
tempMAX(プーリングの最大値) :  9.0
temp.max()の値をi行目に入れる i :  1
temp.max()の値をj列目に入れる j :  0
tempMAX(プーリングの最大値) :  10.0
temp.max()の値をi行目に入れる i :  1
temp.max()の値をj列目に入れる j :  1
tempMAX(プーリングの最大値) :  11.0
temp.max()の値をi行目に入れる i :  1
temp.max()の値をj列目に入れる j :  2
tempMAX(プーリングの最大値) :  13.0
temp.max()の値をi行目に入れる i :  2
temp.max()の値をj列目に入れる j :  0
tempMAX(プーリングの最大値) :  14.0
temp.max()の値をi行目に入れる i :  2
temp.max()の値をj列目に入れる j :  1
tempMAX(プーリングの最大値) :  15.0
temp.max()の値をi行目に入れる i :  2
temp.max()の値をj列目に入れる j :  2
プーリング後の最終出力は
array([[ 5.,  6.,  7.],
       [ 9., 10., 11.],
       [13., 14., 15.]])

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1