全体の構成
###①データの下準備
https://qiita.com/hara_tatsu/items/a90173d33cb381648f72
###②画像データの拡張(水増し)
https://qiita.com/hara_tatsu/items/86ddf3c00a374e9ae796
###③転移学習
https://qiita.com/hara_tatsu/items/bc93fb61b7ccbc639eed
環境情報
Python 3.6.5
tensorflow 2.3.1
【SIGANTE】画像ラベリング(10種類)について
画像データに対して、10種類のラベルの1つを割り当てるモデルを作成します。
学習データサンプル数:5000
以下リンク
https://signate.jp/competitions/133
画像データの拡張(水増し)
ここまでの処理は以下を参照(①データの下準備)
https://qiita.com/hara_tatsu/items/a90173d33cb381648f72
ライブラリーのインポート
python.py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# 画像処理
from PIL import Image
# tensorflow
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
画像データの拡張(水増し)
python.py
# 画像の加工内容
image_gen = ImageDataGenerator(
rotation_range=45, # 45°回転
horizontal_flip = True # 水平移動
)
#元となるデータ数を増やす
X_gen = image_gen.flow(X_train, Y_train, batch_size=32)
# 画像データを格納
X = []
# ラベルデータを格納
Y = []
for _ in range(10):
#データ生成
X_gen_new, y_gen_new = X_gen.__next__()
#データを結合
X_train = np.concatenate([X_train, X_gen_new])
Y_train = np.concatenate([Y_train, y_gen_new])
X.append(X_train)
Y.append(Y_train)
#データサイズを表示
print(X_train.shape, Y_train.shape)
# データ数が2770まで増えた(batch_sizeを増やせばもっと増やせる)
(2482, 96, 96, 3) (2482, 10)
(2514, 96, 96, 3) (2514, 10)
(2546, 96, 96, 3) (2546, 10)
(2578, 96, 96, 3) (2578, 10)
(2610, 96, 96, 3) (2610, 10)
(2642, 96, 96, 3) (2642, 10)
(2674, 96, 96, 3) (2674, 10)
(2706, 96, 96, 3) (2706, 10)
(2738, 96, 96, 3) (2738, 10)
(2770, 96, 96, 3) (2770, 10)
教師データの決定
python.py
X_train = X[-1]
Y_train = Y[-1]
print(X_train.shape, Y_train.shape)
# 教師データとして採用
(2770, 96, 96, 3) (2770, 10)
画像データの確認
python.py
# 1行5列のグリッド形式で画像をプロットする関数。
def plotImages(images_arr):
fig, axes = plt.subplots(1, 5, figsize=(20,20))
axes = axes.flatten()
for img, ax in zip( images_arr, axes):
ax.imshow(img)
ax.axis('off')
plt.tight_layout()
plt.show()
# 教師データをまとめる
train_data_gen = image_gen.flow(X_train, Y_train, batch_size = 32, shuffle = False)
augmented_images = [train_data_gen[0][0][0] for i in range(5)]
# 表示
plotImages(augmented_images)
こんな感じ!!
python.py
# 検証データも教師データと同じ形へ変更
valid_gen = ImageDataGenerator()
valid_data_gen = valid_gen.flow(X_valid, Y_valid, batch_size = 32)