6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Google Colaboratoryを使ってvgg16モデルをファインチューニングしてみた。

Last updated at Posted at 2019-01-11

画像認識や物体認識を調べてみるといろいろなモデルが出てきます。
cifar10やSSDなどいろいろあります。
今回は1000種類認識できるというVGG16モデルを使って標識を分類することを目的にプログラムを作成しました。
やったこと

・ 学習用、検証用、テスト用のgeneratorを作成する。
・ VGG16の15層以降を学習可能にする
・ VGGの出力をPooling層、全結合層 x 2、ソフトマックス関数に渡す
・ 学習した重みを保存する。
・ 重みを使って画像を予測する
目的

VGG16モデルのimagenetという学習データを使ってファインチューニングを利用して標識を分類する学習データを作成しました。
コード概要

プログラム構成はこちらのサイトを参考にしました。
vgg16の構成やファインチューニングの方法などの詳細説明が詳しくありますので参考にしてください。
GoogleColaboratoryについて

GoogleColaboratoryという無料でGPU環境で解析できるツールがあります。
時間制限などはありますが基本的にGoogleアカウントを持っていれば使用できます。
今回なぜGoogleColaboratoryを使用しているかというと今回の解析はノートPCでは120時間ほどかかる見込みだったGPU環境での解析が必要だったからです。GoogleClaboratoryを使用した場合は3時間ほどで解析終了しました。GPUってすごいですね。
GoogleClaboratoryでのフォルダアップロード

ファインチューニングするためには「test」、「train」、「validation」の3種類のフォルダをGoogleClaboratoryの仮想環境上にアップロードすることが必要です。
ファイルのアップロード方法はいくつかありますが今回はzipファイルをアップロードして解答するという手法を取ります。
〇フォルダをそれぞれzip形式で圧縮
「test.zip」、「train.zip」、「validation.zip」という3つのファイルを作成します。
〇zipファイルをアップロード
下記を実行するとファイルを一つずつアップロードできます。

GoogleClaboratory
from google.colab import files

uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name = fn, length = len(uploaded[fn])))

〇ファイルの確認
「!ls」を実行するとフォルダ構成を確認することができます。

GoogleClaboratory
!ls

<実行結果>
sample_data train.zip validation.zip
「sample_data」フォルダと「train.zip」「validation.zip」というファイルがあるよという意味です。
「sample_data」フォルダは常にGoogleClaboratoryにあるので気にしないでください。
〇zipファイルの解凍

GoogleClaboratory
!ls
!date -R
!unzip -qq train.zip
!unzip -qq validation.zip
!date -R
!ls

<実行結果>
sample_data train.zip validation.zip
Mon, 07 Jan 2019 11:17:39 +0000
replace train/crossroad/1929_s[1]-30x30.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
Mon, 07 Jan 2019 11:19:02 +0000
sample_data train train.zip validation validation.zip

これで仮想環境下にフォルダをアップロードできたのでこちらを利用してファインチューニングを実施していきます。

vgg16のファインチューニング

今回のプログラムは参考にしたこちらのサイトからコピーしてきたものです。

GoogleClaboratory
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.callbacks import CSVLogger

n_categories = 23
batch_size = 32
train_dir = 'train'
validation_dir = 'validation'
file_name = 'vgg16_sign_fine'

base_model = VGG16(weights = 'imagenet',include_top=False,
                 input_tensor = Input(shape=(224,224,3)))

#add new layers instead of FC networks
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024,activation = 'relu')(x)
prediction = Dense(n_categories,activation = 'softmax')(x)
model = Model(inputs = base_model.input,outputs=prediction)

#fix weights before VGG16 14layers
for layer in base_model.layers[:15]:
    layer.trainable = False

model.compile(optimizer = SGD(lr=0.0001,momentum=0.9),
              loss = 'categorical_crossentropy',
              metrics = ['accuracy'])

model.summary()

モデルの構成を確認したらいよいよ解析します。
こちらも参考にしたこちらを参考にして作成しています。

GoogleClaboratory
train_datagen = ImageDataGenerator(
    rescale = 1.0/255,
    shear_range = 0.2,
    zoom_range = 0.2,
    horizontal_flip = True)

validation_datagen = ImageDataGenerator(rescale=1.0/255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size = (224,224),
    batch_size = batch_size,
    class_mode = 'categorical',
    shuffle = True
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size = (224,224),
    batch_size = batch_size,
    class_mode = 'categorical',
    shuffle = True
)
#・steps_per_epochとvalidation_stepsは、1エポックあたりのステップ数。すなわち、全サンプル数=バッチ数*ステップ数。(1426 = 32*step(45))
hist=model.fit_generator(train_generator,
                         epochs = 300,
                         verbose = 1,
                         validation_steps = 50,
                         validation_data = validation_generator,
                         steps_per_epoch = 50,
                         callbacks=[CSVLogger(file_name + '.csv')])

#save weights
model.save(file_name + '.h5')

学習がうまく終われば学習データが仮想環境上にあるのでダウンロードします。
仮想環境上のデータはプログラムをスタートしてから一定時間(10時間くらい?)すると消えてしまうようなので注意して下さい。

〇学習データのダウンロード
下記を実行すると学習データをダウンロードできます。

GoogleClaboratory
from google.colab import files
files.download("vgg16_sign_fine.csv") 
files.download("vgg16_sign_fine.h5") 

今回の学習では80%程の正解率まで上がっていました。
学習モデルの確認

モデルの確認をするためにjsonファイルのモデル「model.json」を読み込んで重み「vgg16_sign_fine」を使って分類します。

GoogleClaboratory
import numpy as np
import time
import os
import keras
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing import image
from keras.layers import Input
from keras.models import model_from_json
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD

batch_size = 32
file_name = 'vgg16_sign_fine'

classes = ['crossroad','crosswalk','decrease_width','falling_rocks','hairpin_turn',
         'intersection_+','intersecton_R','intersection_T','intersection_Y','left_only',
         'no_entry','no_parking','parking','parking_stopping_prohibited','pedestrians_only',
         'road_closed','road_closed_2','road_under_repair','slowdown','speed30',
         'speed50','stop','traffic_circle']

modeljson = open('model.json').read()
model_vgg = model_from_json(modeljson)
model_vgg.load_weights(file_name+'.h5')

filename = 'pic07.jpg'

car_vgg = image.load_img(filename, target_size=(224, 224))

vgg = image.img_to_array(car_vgg)
vgg = np.expand_dims(vgg, axis=0)
vgg = preprocess_input(vgg)

pred_vgg = model_vgg.predict(vgg)[0]
predicted = pred_vgg.argmax()
text_vgg = '[No.' + str(predicted) + '] '  + classes[predicted]
print(text_vgg)

結果はあまりうまくいきませんでした。
標識自体のデータを集めるのが難しくデータ数が少なすぎたことが原因かと思っています。(今回は30枚/カテゴリ)
標識のデータはバリエーションが少ないのでどうやったら多くのデータを集められるのか模索中です。

参考ページ

VGG16を転移学習させて「まどか☆マギカ」のキャラを見分ける
https://qiita.com/God_KonaBanana/items/2cf829172087d2423f58

6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?