1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【SIGANTE】画像ラベリング(10種類) 画像データの拡張(水増し)

Last updated at Posted at 2020-12-14

全体の構成

###①データの下準備
https://qiita.com/hara_tatsu/items/a90173d33cb381648f72
###②画像データの拡張(水増し)
https://qiita.com/hara_tatsu/items/86ddf3c00a374e9ae796
###③転移学習
https://qiita.com/hara_tatsu/items/bc93fb61b7ccbc639eed

環境情報

Python 3.6.5
tensorflow 2.3.1

【SIGANTE】画像ラベリング(10種類)について

画像データに対して、10種類のラベルの1つを割り当てるモデルを作成します。

学習データサンプル数:5000

以下リンク
https://signate.jp/competitions/133

画像データの拡張(水増し)

ここまでの処理は以下を参照(①データの下準備)
https://qiita.com/hara_tatsu/items/a90173d33cb381648f72

ライブラリーのインポート

python.py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# 画像処理
from PIL import Image

# tensorflow
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

画像データの拡張(水増し)

python.py
# 画像の加工内容
image_gen = ImageDataGenerator(
              rotation_range=45, # 45°回転
              horizontal_flip = True # 水平移動
             )

#元となるデータ数を増やす
X_gen = image_gen.flow(X_train, Y_train, batch_size=32)

# 画像データを格納
X = []
# ラベルデータを格納
Y = []

for _ in range(10):
    #データ生成
    X_gen_new, y_gen_new = X_gen.__next__()
    #データを結合
    X_train = np.concatenate([X_train, X_gen_new])
    Y_train = np.concatenate([Y_train, y_gen_new])
    X.append(X_train)
    Y.append(Y_train)
    #データサイズを表示
    print(X_train.shape, Y_train.shape)
# データ数が2770まで増えた(batch_sizeを増やせばもっと増やせる)
(2482, 96, 96, 3) (2482, 10)
(2514, 96, 96, 3) (2514, 10)
(2546, 96, 96, 3) (2546, 10)
(2578, 96, 96, 3) (2578, 10)
(2610, 96, 96, 3) (2610, 10)
(2642, 96, 96, 3) (2642, 10)
(2674, 96, 96, 3) (2674, 10)
(2706, 96, 96, 3) (2706, 10)
(2738, 96, 96, 3) (2738, 10)
(2770, 96, 96, 3) (2770, 10)

教師データの決定

python.py
X_train = X[-1]
Y_train = Y[-1]
print(X_train.shape, Y_train.shape)
# 教師データとして採用
(2770, 96, 96, 3) (2770, 10)

画像データの確認

python.py
# 1行5列のグリッド形式で画像をプロットする関数。
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# 教師データをまとめる
train_data_gen = image_gen.flow(X_train, Y_train, batch_size = 32, shuffle = False)
augmented_images = [train_data_gen[0][0][0] for i in range(5)]
# 表示
plotImages(augmented_images)

こんな感じ!!

スクリーンショット 2020-12-14 19.40.35.png

python.py
# 検証データも教師データと同じ形へ変更
valid_gen = ImageDataGenerator()
valid_data_gen = valid_gen.flow(X_valid, Y_valid, batch_size = 32)

おわりに

次回は転移学習!!

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?