0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CNNのプーリング処理を理解する。(備忘録)

Posted at

CNNってなんやねん。
プーリングってなんやねん。畳み込みってなんやねん。
よし、勉強しよ。

今回やること

CNNの際の処理に、画像の特徴量を抜きとる処理に「畳み込み」と「プーリング」がある。
今回はim2colアルゴリズムを使ってプーリングについて学習したので、備忘録として残しておく

CNNとは

簡単にいうと画像認識タスクに用いられる、深層学習アルゴリズム。
写真から特徴量を抜き取って、画像のものがなんなのか判別するタスク。
主に「畳み込み層」「プーリング層」「全結合層」「出力層」からなる。

プーリングとは

画像のピクセルにフィルターをかけていき、そのフィルターの範囲の最大値を取得して、特徴量を生成する。(元画像から特徴を抜き出して、画像を圧縮するイメージ)

4✖️4の画像に2✖️2のフィルターを右端から1マスずつずらしながらローラー作戦のように、対象4マスの中の最大値を取得していく。
(今回は最大値を取得するが、平均値とか、いろんな取得方法があるらしい)

[[ 1, 2, 3, 4           
   5, 6, 7, 8
   9,10,11,12
  13,14,15,16]]

↓

[[ 6,  7,  8
  10, 11, 12
  14, 15, 16]]

コードで実装

必要なライブラリとデータをインポート

python
import numpy as np
import matplotlib.pyplot as plt 
from sklearn import datasets

データを読み込み

データを読み込んで、今回試しに写真一枚だけプーリングを行ってみる。

python
digits = datasets.load_digits() # データをロード
print(digits.data.shape) # 読み込んだデータの形を表示

image = digits.data[2].reshape(8, 8) # 2枚目の写真を整形して、表示 
plt.imshow(image, cmap="gray")
plt.show() 

im2colアルゴリズムを定義

python
def im2col(img, filter_h, filter_w, output_h, output_w, stride):  # 入力画像、プーリングの高さ、幅、出力画像の高さ、幅、ストライド

    cols = np.zeros((filter_h*filter_w, output_h*output_w)) # outputの配列のサイズ
    for h in range(output_h):
        h_limit = stride*h + filter_h  # h:プーリングの上端、h_limit:プーリングの下端
        for w in range(output_w):
            w_limit = stride*w + filter_w  # w:プーリングの左端、w_limit:プーリングの右端
            cols[:, h*output_w+w] = img[stride*h:h_lim, stride*w:w_lim].reshape(-1)

    return cols

maxプーリング

定義したim2colを使って画像を行列計算をしやすいように変形した後、maxプーリング行います。
その後画像を表示する。

python
img_h, img_w = image.shape  # 入力画像の高さ、幅
pool = 2  # プーリング領域のサイズ

output_h = img_h//pool  # 出力画像の高さ
output_w = img_w//pool  # 出力画像の幅

cols = im2col(image, pool, pool, output_h, output_w, pool)
image_output = np.max(cols, axis=0)  # Maxプーリング
image_output = image_output.reshape(output_h, output_w) # プーリングした後、画像の形に変形

plt.imshow(image_output, cmap="gray") # 画像を表示
plt.show() 

終わりに

今回はCNNのプーリングをコートで実装してみました。
ライブラリを用いれば、簡単に実装できることをコードで書くと、深く理解できた気がします。
今後も学習楽しみます。

ps.
今日からカーソルキーを使わずに、Ctrl+f,b,p,nの練習しながら記事書いてみました。
慣れなくて苦労してます。。。ww

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?