2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ファインチューニングを使ったCNNの実装

Last updated at Posted at 2021-03-19

#はじめに
本記事では,ファインチューニングを使ったCNNのコード(train)を掲載し,簡単な解説をします.
仮想環境の構築やライブラリのインストール方法に関しては割愛しますので,ご了承ください.

#本記事を書いた経緯
人工知能がTVや各種メディアに取り上げられることからも,AIがより身近なものになってきています.
大学の研究でAI(画像認識)をすることになる学生もますます増えていることかと思います.しかしながら,周りに頼れる人や先輩がおらず,苦労する大学生や大学院生が発生しているのではないかと考え,本記事を書きました.
そういった方々(もちろんそうでない方々も)の助けになれば幸いです.

#ファインチューニングについて
ファインチューニングとは,すでに学習され,重みが与えられている既存のモデルをベースにして新たなモデルを構築することで,学習に使用するデータ数が少なくても,適切な学習が行えるというものです.CNNの場合,通常,学習に使用するデータ数は数千,数万のデータが必要であるといわれています.一方,ファインチューニングをすることで学習に使用するデータ数を減らすことができ,数百個のデータで学習を行っている研究もあります.

#開発環境

  • Windows10
  • Python 3.7.7
  • TensorFlow 1.14.0
  • keras 2.3.1
  • numpy 1.16.4
  • matplotlib 3.1.3
  • hdf5 1.8.20

Anacondaでroot環境をコピーして,tensorflow,kerasあたりを入れて作りました.
root環境だとnumpyのverが1.17とかなので,バージョンをダウンさせてます.
一応TensorFlowとKeras両方を入れています.
※仮想環境の構築は以下の記事が参考になるかと
https://qiita.com/ozaki_physics/items/13466d6d1954a0afeb3b

#CNNのコード

from PIL import Image
import os, glob
import numpy as np
import random, math

#画像が保存されているルートディレクトリのパス
#root_dir = "パス"
root_dir = "./image"
#分類する種別ごとにフォルダーを分ける。モデルの最終出力もここで設定したフォルダーの数に合わせる
categories = ["dog","rabbit","marmot"]

# 画像データ用配列
X = []
# ラベルデータ用配列
Y = []

#画像データごとにadd_sample()を呼び出し、X,Yの配列を返す関数
def make_sample(files):
    global X, Y
    X = []
    Y = []
    for cat, fname in files:
        add_sample(cat, fname)
    return np.array(X), np.array(Y)

#渡された画像データを読み込んでXに格納し、また、
#画像データに対応するcategoriesのidxをY格納する関数
def add_sample(cat, fname):
    img = Image.open(fname)
    img = img.convert("RGB")
    img = img.resize((224, 224))
    data = np.asarray(img)
    X.append(data)
    Y.append(cat)

#全データ格納用配列
allfiles = []

#カテゴリ配列の各値と、それに対応するidxを認識し、全データをallfilesにまとめる
for idx, cat in enumerate(categories):
    image_dir = root_dir + "/" + cat
    files = glob.glob(image_dir + "/*.jpg")
    for f in files:
        allfiles.append((idx, f))

#シャッフル後、学習データと検証データに分ける
random.shuffle(allfiles)
th = math.floor(len(allfiles) * 0.9)
train = allfiles[0:th]
test  = allfiles[th:]
X_train, y_train = make_sample(train)
X_test, y_test = make_sample(test)
xy = (X_train, X_test, y_train, y_test)

np.save("image_data", xy)

#モデルの構築
from keras import layers, models
from keras.applications import VGG16

#転移元のモデルとしてVGG16を使用
conv_base = VGG16(weights='imagenet',
                     include_top=False,
                     input_shape=(224, 224, 3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
#model.add(layers.Dense(3,activation="sigmoid")) #分類先の種類分設定,分類数に応じて数字変更
model.add(layers.Dense(3, activation='softmax'))

#モデル構成の確認
model.summary()

#モデルのコンパイル
from keras import optimizers
#model.compile(loss="binary_crossentropy",バイナリは二値分類用らしい
model.compile(loss="categorical_crossentropy",
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=["acc"])

#データの準備
from keras.utils import np_utils
import numpy as np

categories = ["dog","rabbit","marmot"]
nb_classes = len(categories)

#X_train, X_test, y_train, y_test = np.load("保存した学習データ・テストデータのパス")
X_train, X_test, y_train, y_test = np.load("image_data.npy",allow_pickle=True)

#データの正規化
X_train = X_train.astype("float") / 255
X_test  = X_test.astype("float")  / 255

#kerasで扱えるようにcategoriesをベクトルに変換
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test  = np_utils.to_categorical(y_test, nb_classes)

#モデルの学習

model = model.fit(X_train,
                  y_train,
                  epochs=50,
                  batch_size=16,
                  validation_data=(X_test,y_test))

#学習結果を表示

import matplotlib.pyplot as plt

acc = model.history['acc']
val_acc = model.history['val_acc']
loss = model.history['loss']
val_loss = model.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
#plt.savefig('精度を示すグラフのファイル名')
plt.savefig('cnn_accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
#plt.savefig('損失値を示すグラフのファイル名')
plt.savefig('cnn_loss')

#モデルの保存

json_string = model.model.to_json()

open('animal_model.json', 'w').write(json_string)

#重みの保存
hdf5_file = "animal_weight.hdf5"
model.model.save_weights(hdf5_file)

print('Finish!')

#コードの解説
ちょっと長くなりますが,上から順に簡単な説明をしていきたいと思います.
コードがわかっている方は読み飛ばしてしまって全く問題ないです!

from PIL import Image
import os, glob
import numpy as np
import random, math

#画像が保存されているルートディレクトリのパス
#root_dir = "パス"
root_dir = "./image"
#分類する種別ごとにフォルダーを分ける。モデルの最終出力もここで設定したフォルダーの数に合わせる
categories = ["dog","rabbit","marmot"]

importでは,本コードにて使用するライブラリーの指定をしています.できるだけimportはまとめた方が見やすいです(このコードではまとめてません).
root_dirは画像を読み込む場所を指定しています.categoriesは分類するものを指定してます.ここではdogとrabbitとmarmotの画像を分類するようにしてます.

# 画像データ用配列
X = []
# ラベルデータ用配列
Y = []

#画像データごとにadd_sample()を呼び出し、X,Yの配列を返す関数
def make_sample(files):
    global X, Y
    X = []
    Y = []
    for cat, fname in files:
        add_sample(cat, fname)
    return np.array(X), np.array(Y)

#渡された画像データを読み込んでXに格納し、また、
#画像データに対応するcategoriesのidxをY格納する関数
def add_sample(cat, fname):
    img = Image.open(fname)
    img = img.convert("RGB")
    img = img.resize((224, 224))
    data = np.asarray(img)
    X.append(data)
    Y.append(cat)

#全データ格納用配列
allfiles = []

#カテゴリ配列の各値と、それに対応するidxを認識し、全データをallfilesにまとめる
for idx, cat in enumerate(categories):
    image_dir = root_dir + "/" + cat
    files = glob.glob(image_dir + "/*.jpg")
    for f in files:
        allfiles.append((idx, f))

#シャッフル後、学習データと検証データに分ける
random.shuffle(allfiles)
th = math.floor(len(allfiles) * 0.9)
train = allfiles[0:th]
test  = allfiles[th:]
X_train, y_train = make_sample(train)
X_test, y_test = make_sample(test)
xy = (X_train, X_test, y_train, y_test)

np.save("image_data", xy)

ここの処理については,簡単に言うと,画像データをXに格納し,教師値をYに格納,そのあと学習データと検証データに分けてます.
教師値の与え方は各フォルダーごとに与えています.dogは[1,0,0],rabbitは[0,1,0],marmotは[0,0,1]といった具合です.同じフォルダーに入れた画像には,同じ教師値が与えられます.フォルダーの場所等に関しては後述の”使い方”で説明します.
学習データと検証データの比は9:1です.コード内の数値"0.9"を変えることで,好きな比にできます.
その後,image_dataという名前で保存しています.このファイルは後で読み込みます.

#モデルの構築
from keras import layers, models
from keras.applications import VGG16

#転移元のモデルとしてVGG16を使用
conv_base = VGG16(weights='imagenet',
                     include_top=False,
                     input_shape=(224, 224, 3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
#model.add(layers.Dense(3,activation="sigmoid")) #分類先の種類分設定,分類数に応じて数字変更
model.add(layers.Dense(3, activation='softmax'))

#モデル構成の確認
model.summary()

#モデルのコンパイル
from keras import optimizers
#model.compile(loss="binary_crossentropy",バイナリは二値分類用らしい
model.compile(loss="categorical_crossentropy",
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=["acc"])

続いてモデルの構築+コンパイルです.
今回はImageNetで学習をしたVGG16を既存モデルとして使用してます.
活性化関数がsoftmaxの層は出力層です.出力する数は分類の個数と同じにする必要があります.
今回はdog, rabbit, marmotの3分類なので3です.

モデルのコンパイルでは損失関数などを指定します.今回は3分類なので損失関数はcategolical…です.
最適化関数は適当です.

#データの準備
from keras.utils import np_utils
import numpy as np

categories = ["dog","rabbit","marmot"]
nb_classes = len(categories)

#X_train, X_test, y_train, y_test = np.load("保存した学習データ・テストデータのパス")
X_train, X_test, y_train, y_test = np.load("image_data.npy",allow_pickle=True)

#データの正規化
X_train = X_train.astype("float") / 255
X_test  = X_test.astype("float")  / 255

#kerasで扱えるようにcategoriesをベクトルに変換
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test  = np_utils.to_categorical(y_test, nb_classes)

#モデルの学習

model = model.fit(X_train,
                  y_train,
                  epochs=50,
                  batch_size=16,
                  validation_data=(X_test,y_test))

データの準備では,さっき保存したimage_dataを読み込みます.
その後model.fitでモデルの学習を行います.エポック数やバッチサイズは変更可能です.
学習しているときは,その経過が表示されます(〇〇%まで学習中みたいな).

#学習結果を表示

import matplotlib.pyplot as plt

acc = model.history['acc']
val_acc = model.history['val_acc']
loss = model.history['loss']
val_loss = model.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
#plt.savefig('精度を示すグラフのファイル名')
plt.savefig('cnn_accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
#plt.savefig('損失値を示すグラフのファイル名')
plt.savefig('cnn_loss')

#モデルの保存

json_string = model.model.to_json()

open('animal_model.json', 'w').write(json_string)

#重みの保存
hdf5_file = "animal_weight.hdf5"
model.model.save_weights(hdf5_file)

print('Finish!')

最後に正解率の変動がcnn_accuracyというファイル名,損失関数の値の変動がcnn_lossというファイル名で保存され,表示されます.
不要なら消してしまっても構いません.
学習が完了したモデルがanimal_model.json,重みがanimal_weight.hdf5というファイル名で保存されます.
こちらはpredictする際に読み込むことになるかと思います.

#使い方
下の図のようにファイルを配置すればオーケーです.
mae - コピー.png

本記事に示したコードを入れたファイルと同じフォルダーに"image"という名前のフォルダーを置き,その中にdog, rabbit, marmotのフォルダーを入れてさらにその中に画像を入れるといった具合です.

学習を終えると以下の図のようになります.

ato - コピー.png

コードのファイルがあるフォルダー内に,image_dataやれweightやれいろいろなファイルが保存されます.
ファイルがこういう配置になっているのは,ファイルの保存・読み込む場所のパスをほとんど指定してないからです(学習に使う画像のみファイルを指定してます).理由は私がパスの書き方について"./"ぐらいしかわかってないからです.
普通はこんなにごちゃごちゃにはならないです!

#おわりに
CNNでファインチューニングするコードについて,書かせていただきました.
少しでも参考になったのなら幸いです.
気が向いたらpredictのコードも書きたいと思います.
といっても,predictはモデルと重みと画像を読み込むだけですが…

#参考にしたもの
画像認識で「綾鷹を選ばせる」AIを作る
https://qiita.com/tomo_20180402/items/e8c55bdca648f4877188
データの読み込みなどコードのほとんどの部分において,この記事を参考にしました.

Keras documentation
https://keras.io/ja/
kerasの公式(?)サイトです.大変お世話になっております.

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?