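Table of contents
- Introduction
- Environment
- Preparation
- Image collection
- Building, analyzing, and improving the model
- Results and discussion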
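Introduction
While on parental leave, I wanted to pick up a new skill, so I enrolled in Aidemy's AI app development course.
Studying while looking after a baby turned out to be much harder than I expected.
This article summarizes what I learned.
Theme: using image recognition to tell vegetables that parakeets can eat from those they cannot
I chose this theme because I am completely taken with the cockatiel that recently joined our family, and I think about it all day long.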
The actual application
Here is the application I built.
Environment
・Google Colaboratory
・Visual Studio Code
Preparation
I used Kaggle's Fruits and Vegetables Image Recognition Dataset as the base and added the images I needed.
Downloaded dataset
Next, mount Google Drive in Google Colaboratory (hereafter "Colab").
from google.colab import drive
drive.mount("./content")
Colab can now use the images stored on Google Drive.
Image collection
Let's start collecting images.
Vegetables that can be fed
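- Komatsuna (Japanese mustard spinach)
- Pea shoots (toumyo)
- Mizuna
- Bok choy (chingensai)
- Parsley
- Butter lettuce (saradana)
- Shungiku (garland chrysanthemum)
- Turnip greens
- Daikon radish greens
- Pumpkin (kabocha)
- Carrot
- Green pepper (piiman)
- Zucchini
- Bell pepper (paprika)
Vegetables that must not be fed
- Avocado
- Onion
- Japanese long onion (naganegi)
- Garlic chives (nira)
- Garlic
- Moroheiya (jute mallow)
- Ginger
- Potato
- Spinach (hourensou)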
I aimed to collect 100 images of each vegetable.
To collect the images efficiently, install icrawler in a Colab cell.
!pip install icrawler
Import icrawler's crawler for Google Image Search and run it. The search keywords were in Japanese, with up to 500 images per vegetable.
from icrawler.builtin import GoogleImageCrawler
google_crawler = GoogleImageCrawler(storage={"root_dir": 'crawled_komatsuna'})
search_keywords = "小松菜 フリー" # search keyword ("komatsuna" + "free")
num_images = 500 # maximum number of images to fetch
google_crawler.crawl(keyword=search_keywords, max_num=num_images)
However, the crawler only collected about 30 images per vegetable, so I gathered the rest by hand.
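Because the crawler left each class well short of the 100-image target, it helps to count what is actually in each folder before topping it up by hand. The sketch below does that. The helper names `count_images` and `missing_images` are hypothetical, and so is the list of file extensions treated as images.

```python
import os

def count_images(root_dir, exts=(".jpg", ".jpeg", ".png")):
    """Return {class_folder: number of image files} for every subfolder of root_dir."""
    counts = {}
    for name in sorted(os.listdir(root_dir)):
        folder = os.path.join(root_dir, name)
        if os.path.isdir(folder):  # ignore stray files next to the class folders
            counts[name] = sum(1 for f in os.listdir(folder)
                               if f.lower().endswith(exts))
    return counts

def missing_images(counts, target=100):
    """Return how many more images each class needs to reach `target`."""
    return {name: target - n for name, n in counts.items() if n < target}
```

In Colab, for example, `missing_images(count_images("/content/content/MyDrive/vegi_data/train"))` would list only the vegetables that still need manual collection.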
Building, analyzing, and improving the model
Import the required modules.
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras import optimizers
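Next, preprocess the images.
For each vegetable folder, get the list of files and prepare a list to hold the image data.
Each image is converted from OpenCV's BGR channel order to RGB, resized to 50×50, and appended to that list.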
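`cv2.imread` returns pixels in BGR order, which is why each image is converted to RGB. The NumPy sketch below shows that the conversion is just a reversal of the channel axis. It uses a tiny hand-made image, so no OpenCV or Drive access is needed.

```python
import numpy as np

# A toy 1 x 2 image in OpenCV's BGR channel order:
# pixel 0 is pure blue (B=255), pixel 1 is pure red (R=255).
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the channel axis gives RGB order, the same result as
# cv2.split + cv2.merge([r, g, b]) or cv2.cvtColor(img, cv2.COLOR_BGR2RGB).
# ascontiguousarray copies the reversed view into a normal C-ordered array.
rgb = np.ascontiguousarray(bgr[:, :, ::-1])

print(rgb[0, 0].tolist())  # [0, 0, 255]: the blue pixel, now stored as (R, G, B)
print(rgb[0, 1].tolist())  # [255, 0, 0]: the red pixel
```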
# Class folder names; the position in this list is the label (0-22)
vegi_names = ["komatsuna", "toumyo", "mizuna", "chingensai", "paseri", "saradana",
              "shungiku", "kabu", "daikon", "pumpkin", "carrot", "piiman", "zucchini",
              "paprika", "abocado", "onion", "naganegi", "nira", "garlic", "moroheiya",
              "ginger", "potato", "hourensou"]
data_dir = "/content/content/MyDrive/vegi_data/"

def load_images(split):
    # Read every image in data_dir/<split>/<vegetable>/ and return (images, labels)
    images, labels = [], []
    for label, name in enumerate(vegi_names):
        folder = os.path.join(data_dir, split, name)
        # Get the file list of this vegetable's folder
        for file_name in os.listdir(folder):
            img = cv2.imread(os.path.join(folder, file_name))
            if img is None:  # skip files that OpenCV cannot read as images
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
            img = cv2.resize(img, (50, 50))
            images.append(img)
            labels.append(label)
    return np.array(images), np.array(labels)

# X: training images, y: training labels, XX: test images, yy: test labels
X, y = load_images("train")
XX, yy = load_images("test")
# Shuffle the training images and labels with the same permutation
rand_index = np.random.permutation(np.arange(len(X)))
X = X[rand_index]
y = y[rand_index]
# Prepare the training and test data
X_train = X
y_train = y
X_test = XX
y_test = yy
# Convert the labels to one-hot vectors
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
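The shuffle and one-hot steps above can be checked on toy data. Indexing both arrays with one permutation keeps every image paired with its label. `to_categorical` then turns label `k` into row `k` of an identity matrix. This sketch makes two substitutions: it uses a seeded `np.random.default_rng(0)` instead of the unseeded `np.random.permutation`, and a hypothetical NumPy `one_hot` helper in place of `to_categorical`.

```python
import numpy as np

# Toy stand-ins for the real data: each "image" stores 10 * its label,
# so it is easy to check that images and labels stay paired.
y = np.array([0, 1, 2, 0, 1, 2])
X = y.reshape(-1, 1) * 10

# Shuffle: one permutation, applied to both arrays.
rand_index = np.random.default_rng(0).permutation(len(X))
X, y = X[rand_index], y[rand_index]

def one_hot(labels, num_classes):
    """NumPy equivalent of to_categorical(labels, num_classes)."""
    return np.eye(num_classes, dtype="float32")[labels]

y_onehot = one_hot(y, num_classes=3)
print(X[:, 0] // 10, y)  # the two rows are identical after shuffling
```

When `num_classes` is omitted, `to_categorical` infers it as `max(label) + 1`. Passing it explicitly, as in `to_categorical(y_train, num_classes=23)`, guarantees that the training and test labels end up with the same width. Also note that `model.fit` reshuffles the training data every epoch by default (`shuffle=True`).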
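Analysis
I use VGG16 for transfer learning.
Starting from a model that has already been trained, I aim to shorten training time and reach high accuracy with relatively little data.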
# Input tensor for the model: 50x50 RGB images (Input is imported from tensorflow.keras.layers above)
input_tensor = Input(shape=(50, 50, 3))
# Use VGG16 (ImageNet weights, without its fully connected top) as the base for transfer learning
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)
# Define the top model: a ReLU hidden layer, dropout, and a softmax output over 23 classes
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(rate=0.5))
top_model.add(Dense(23, activation='softmax'))
# Connect VGG16 and the top model
model = Model(inputs=vgg16.input, outputs=top_model(vgg16.output))
# Freeze the VGG16 weights (the first 19 layers: the input layer plus the 18 VGG16 layers)
for layer in model.layers[:19]:
    layer.trainable = False
# Compile the model: SGD with learning rate 1e-4 and momentum 0.9
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(learning_rate=1e-4, momentum=0.9),
              metrics=['accuracy'])
model.summary()
history = model.fit(X_train, y_train, batch_size=128, epochs=30, validation_data=(X_test, y_test))
# Evaluate the model
score = model.evaluate(X_test, y_test, batch_size=128, verbose=0)
print('validation loss:{0[0]}\nvalidation accuracy:{0[1]}'.format(score))
# Plot training and validation accuracy per epoch
plt.plot(history.history["accuracy"], label="accuracy", ls="-", marker="o")
plt.plot(history.history["val_accuracy"], label="val_accuracy", ls="-", marker="x")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(loc="best")
plt.show()
# Function that takes one RGB image and returns the predicted vegetable name
# Class names in the same order as the training labels (0-22)
vegi_names = ["komatsuna", "toumyo", "mizuna", "chingensai", "paseri", "saradana",
              "shungiku", "kabu", "daikon", "pumpkin", "carrot", "piiman", "zucchini",
              "paprika", "abocado", "onion", "naganegi", "nira", "garlic", "moroheiya",
              "ginger", "potato", "hourensou"]

def pred_vegi(img):
    img = cv2.resize(img, (50, 50))
    # The training images were used on the raw 0-255 scale, so do not divide by 255 here
    img = img.astype('float32')
    pred = np.argmax(model.predict(np.array([img])))
    return vegi_names[pred]
# Evaluate accuracy on the test data
scores = model.evaluate(X_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
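# Pass a few test photos to pred_vegi and print the predicted vegetable
carrot_dir = "/content/content/MyDrive/vegi_data/test/carrot/"
carrot_files = os.listdir(carrot_dir)
for file_name in carrot_files[:5]:
    img = cv2.imread(carrot_dir + file_name)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB, as in training
    plt.imshow(img)
    plt.show()
    print(pred_vegi(img))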
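The app ultimately needs to answer one question: can this vegetable be fed to a parakeet? `pred_vegi` only returns a class name. The sketch below maps that name to a yes/no answer, following the two vegetable lists at the start of this article. The names `FEEDABLE`, `NOT_FEEDABLE`, and `can_feed` are hypothetical. The reduced 7-class model would use a subset of these names.

```python
# Hypothetical lookup tables built from the article's two lists:
# labels 0-13 can be fed, labels 14-22 must not be fed.
FEEDABLE = {
    "komatsuna", "toumyo", "mizuna", "chingensai", "paseri", "saradana",
    "shungiku", "kabu", "daikon", "pumpkin", "carrot", "piiman", "zucchini",
    "paprika",
}
NOT_FEEDABLE = {
    "abocado", "onion", "naganegi", "nira", "garlic", "moroheiya",
    "ginger", "potato", "hourensou",
}

def can_feed(vegetable):
    """Return True/False for a known class name; raise for anything else."""
    if vegetable in FEEDABLE:
        return True
    if vegetable in NOT_FEEDABLE:
        return False
    raise ValueError(f"unknown vegetable: {vegetable}")
```

In the notebook, this would be called as `can_feed(pred_vegi(img))`. Raising an error on unknown names means a typo in a class name fails loudly instead of silently answering "no".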
Let's check the results.
8/8 [==============================] - 5s 614ms/step - loss: 2.9587 - accuracy: 0.1870
Test loss: 2.9587271213531494
Test accuracy: 0.186956524848938
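This is nowhere near good enough.
Since many of the vegetables look alike, I suspected they might be hard to tell apart, so I reduced the number of classes.
The vegetables that can be fed are narrowed down to four: komatsuna, pea shoots, carrot, and green pepper.
The vegetables that cannot be fed are narrowed down to three: avocado, onion, and garlic chives.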
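One way to check the suspicion that similar-looking vegetables are confused with each other is a confusion matrix, which counts how often true class `i` is predicted as class `j`. The sketch below builds one with NumPy and lists the most frequent mix-ups. The helper names are hypothetical. scikit-learn's `sklearn.metrics.confusion_matrix` computes the same matrix.

```python
import numpy as np

def build_confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] = number of samples whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def most_confused(cm, top=3):
    """Return up to `top` (true, predicted, count) off-diagonal pairs, largest first."""
    off = cm.copy()
    np.fill_diagonal(off, 0)  # ignore correct predictions
    flat = np.argsort(off, axis=None)[::-1][:top]
    pairs = [np.unravel_index(i, off.shape) for i in flat]
    return [(int(t), int(p), int(off[t, p])) for t, p in pairs if off[t, p] > 0]
```

With the trained model, the inputs would be `y_true = np.argmax(y_test, axis=1)` and `y_pred = np.argmax(model.predict(X_test), axis=1)`. Pairs such as komatsuna and spinach showing up near the top would support the hypothesis.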
In addition, 100 images per vegetable felt too few, so I augment the training images.
Augmenting the training images
# Data augmentation
def scratch_image(img, flip=True, thr=True, filt=True, resize=True, erode=True):
    # Flags selecting which augmentation methods to apply
    methods = [flip, thr, filt, resize, erode]
    filter1 = np.ones((3, 3))
    images = [img]
    # Candidates: left-right flip, threshold (TOZERO), Gaussian blur, down/up-scaling (mosaic), erosion
    scratch = np.array([
        lambda x: cv2.flip(x, 1),
        lambda x: cv2.threshold(x, 100, 255, cv2.THRESH_TOZERO)[1],
        lambda x: cv2.GaussianBlur(x, (5, 5), 0),
        lambda x: cv2.resize(cv2.resize(x, (x.shape[1] // 5, x.shape[0] // 5)), (x.shape[1], x.shape[0])),
        lambda x: cv2.erode(x, filter1)
    ])
    # Keep the current images and add a transformed copy of each (the count doubles)
    doubling_images = lambda f, imag: (imag + [f(i) for i in imag])
    for func in scratch[methods]:
        images = doubling_images(func, images)
    return images

scratch_train_images = []
scratch_train_labels = []
for im, label in zip(X_train, y_train):
    # Flip and threshold only: 2 methods -> 2**2 = 4 images per original
    tmp = scratch_image(im, flip=True, thr=True, filt=False, resize=False, erode=False)
    scratch_train_images += tmp
    scratch_train_labels += [label] * len(tmp)
X_train = np.array(scratch_train_images)
y_train = np.array(scratch_train_labels)
Adding left-right flipped and thresholded versions of each image quadrupled the training data.
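`scratch_image` applies each enabled method with a doubling rule: it keeps every image so far and appends a transformed copy of each. Two enabled methods therefore give 2 × 2 = 4 images, which is why four labels are added per original. Enabling all five methods would give 32. The sketch below reproduces this on a toy image. It assumes plain NumPy operations match the two OpenCV calls used: `x[:, ::-1]` for `cv2.flip(x, 1)`, and `np.where(x > 100, x, 0)` for `THRESH_TOZERO` with threshold 100.

```python
import numpy as np

# NumPy stand-ins for the two OpenCV calls enabled in the article (assumption: same math).
def flip_lr(x):
    """Left-right flip of an H x W x C image, like cv2.flip(x, 1)."""
    return x[:, ::-1]

def thresh_tozero(x, thresh=100):
    """Like cv2.threshold(x, thresh, 255, cv2.THRESH_TOZERO)[1]: keep values > thresh, zero the rest."""
    return np.where(x > thresh, x, 0).astype(x.dtype)

def doubling(func, images):
    """Keep the current images and append func(image) for each one, doubling the count."""
    return images + [func(im) for im in images]

# A toy 1 x 2 RGB image
img = np.array([[[50, 150, 250], [120, 90, 200]]], dtype=np.uint8)

images = [img]
for func in (flip_lr, thresh_tozero):  # same order as in scratch_image
    images = doubling(func, images)

print(len(images))  # 4: original, flipped, thresholded, flipped + thresholded
```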
After a few trials, I settled on batch_size=128 and epochs=30.
Results and discussion
3/3 [==============================] - 2s 473ms/step - loss: 1.5115 - accuracy: 0.7286
Test loss: 1.511489987373352
Test accuracy: 0.7285714149475098
I initially aimed for 23 classes at 80% accuracy or higher, but 7 classes (72.9% test accuracy) was the most I could manage.
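Comparing the two runs against random guessing helps put these numbers in perspective. With `n` roughly balanced classes, uniform guessing scores about `1/n`. The sketch divides each reported test accuracy by that baseline. The balanced-classes assumption is ours, based on collecting about 100 images per vegetable, and `times_chance` is a hypothetical helper.

```python
def times_chance(accuracy, num_classes):
    """How many times better than uniform random guessing (1 / num_classes)."""
    return accuracy / (1 / num_classes)

# Test accuracies from the two runs reported in this article
for num_classes, acc in [(23, 0.186956524848938), (7, 0.7285714149475098)]:
    print(f"{num_classes} classes: accuracy {acc:.3f}, "
          f"chance {1 / num_classes:.3f}, {times_chance(acc, num_classes):.1f}x chance")
```

Under this assumption, the 23-class model is about 4.3 times chance and the 7-class model about 5.1 times. Part of the jump from 18.7% to 72.9% likely comes from the easier task itself.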
To reach higher accuracy, the following could be tried:
・Collect even more images
・Use larger input images (e.g. 150×150)
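The larger-image idea can be quantified. In VGG16's convolutional base, the 3×3 convolutions use "same" padding and do not change the spatial size. Each of the five 2×2 max-pooling layers (stride 2) halves it, rounding down. The sketch traces that arithmetic for 50×50 and 150×150 inputs. The function name is hypothetical.

```python
def vgg16_feature_shape(size, num_pools=5, channels=512):
    """Output shape of VGG16's convolutional base (include_top=False) for a size x size input.

    Only the five 2x2, stride-2 max-pooling layers shrink the image;
    each one floors size / 2.
    """
    for _ in range(num_pools):
        size //= 2
    return size, size, channels

for s in (50, 150):
    h, w, c = vgg16_feature_shape(s)
    print(f"{s}x{s} input -> {h}x{w}x{c} feature map -> {h * w * c} values after Flatten")
```

At 50×50, the features collapse to a single 1×1×512 position, so the top model sees no spatial layout at all. At 150×150, a 4×4×512 map is left (8192 values after `Flatten`). Keras's VGG16 also requires inputs of at least 32×32, so 50×50 is close to the lower limit.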
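I started the course with almost no programming experience, and by the end I was able to build an app.
Unlike studying on my own, being able to ask questions whenever I got stuck made a huge difference.
I plan to keep studying programming.
I would also like to try building another app.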