Help us understand the problem. What is going on with this article?

顔画像から年齢を予測

More than 1 year has passed since last update.

概要

Keras学習済みモデルのXceptionをUTKFaceデータセットでFine-tuningさせ、年齢回帰モデルを構築する

UTKFace
https://susanqq.github.io/UTKFace/
- 20,000以上の顔画像データセット
- 性別、年齢、人種のラベリング

実行環境
Google Colaboratory(GPU)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = {'png', 'retina'}
import os, zipfile, io, re
from PIL import Image
from sklearn.model_selection import train_test_split
from keras.applications.xception import Xception
from keras.models import Model, load_model
from keras.layers.core import Dense
from keras.layers.pooling import GlobalAveragePooling2D
from keras.optimizers import Adam, RMSprop, SGD
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator

データ取得

画像は100にリサイズ
ZIPファイルからデータセットを取得し、配列に変換
年齢は各画像のファイル名から取得
取得したデータセットはtrain、valid、testに分割

# 画像入力サイズ
image_size=100
%%time
# ZIP読み込み
z = zipfile.ZipFile('../dataset/UTKFace.zip')
# 画像ファイルパスのみ取得
imgfiles = [ x for x in z.namelist() if re.search(r"^UTKFace.*jpg$", x)]

X=[]
Y=[]

for imgfile in imgfiles:
    # ZIPから画像読み込み
    image = Image.open(io.BytesIO(z.read(imgfile)))
    # RGB変換
    image = image.convert('RGB')
    # リサイズ
    image = image.resize((image_size, image_size))
    # 画像から配列に変換
    data = np.asarray(image)
    file = os.path.basename(imgfile)
    file_split = [i for i in file.split('_')]
    X.append(data)
    Y.append(int(file_split[0]))

z.close()

X = np.array(X)
Y = np.array(Y)

del z, imgfiles

print(X.shape, Y.shape)

(23708, 100, 100, 3) (23708,)
CPU times: user 24.3 s, sys: 1.33 s, total: 25.6 s
Wall time: 25.7 s

# trainデータとtestデータに分割
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    random_state = 0,
    test_size = 0.2
)
del X,Y
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape) 

(18966, 100, 100, 3) (18966,) (4742, 100, 100, 3) (4742,)

# データ型の変換&正規化
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

# trainデータからvalidデータを分割
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train,
    y_train,
    random_state = 0,
    test_size = 0.2
)
print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape) 

(15172, 100, 100, 3) (15172,) (3794, 100, 100, 3) (3794,)

モデル構築

Xception読み込み
Keras学習済みモデルXceptionを読み込む
その際、ネットワーク出力層側にある全結合層を除去

base_model = Xception(
    include_top = False,
    weights = "imagenet",
    input_shape = None
)

全結合層の新規構築
今回は分類ではなく、回帰のため、最後に予測値を1つ出力
そのため出力層のユニット数を1つにし、活性化関数のsoftmax(確率値に変換)も使用しない

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1)(x)

Data Augmentation

datagen = ImageDataGenerator(
    featurewise_center = False,
    samplewise_center = False,
    featurewise_std_normalization = False,
    samplewise_std_normalization = False,
    zca_whitening = False,
    rotation_range = 0,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True,
    vertical_flip = False
)

Callback

# EarlyStopping
early_stopping = EarlyStopping(
    monitor = 'val_loss',
    patience = 10,
    verbose = 1
)

# ModelCheckpoint
weights_dir = './weights/'
if os.path.exists(weights_dir) == False:os.mkdir(weights_dir)
model_checkpoint = ModelCheckpoint(
    weights_dir + "val_loss{val_loss:.3f}.hdf5",
    monitor = 'val_loss',
    verbose = 1,
    save_best_only = True,
    save_weights_only = True,
    period = 3
)

# reduce learning rate
reduce_lr = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.1,
    patience = 3,
    verbose = 1
)

# log for TensorBoard
logging = TensorBoard(log_dir = "log/")

RMSE(Root Mean Squared Error)
今回使用する損失関数であるRMSEを定義

# RMSE
from keras import backend as K
def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true), axis = -1)) 

モデル学習

XceptionをFine-tuning

# ネットワーク定義
model = Model(inputs = base_model.input, outputs = predictions)

#108層までfreeze
for layer in model.layers[:108]:
    layer.trainable = False

    # Batch Normalizationのfreeze解除
    if layer.name.startswith('batch_normalization'):
        layer.trainable = True
    if layer.name.endswith('bn'):
        layer.trainable = True

#109層以降、学習させる
for layer in model.layers[108:]:
    layer.trainable = True

# layer.trainableの設定後にcompile
model.compile(
    optimizer = Adam(),
    loss = root_mean_squared_error,
)
%%time
hist = model.fit_generator(
    datagen.flow(X_train, y_train, batch_size = 32),
    steps_per_epoch = X_train.shape[0] // 32,
    epochs = 50,
    validation_data = (X_valid, y_valid),
    callbacks = [early_stopping, reduce_lr],
    shuffle = True,
    verbose = 1
)
Epoch 1/50
474/474 [==============================] - 120s 253ms/step - loss: 9.8930 - val_loss: 12.4886
Epoch 2/50
474/474 [==============================] - 110s 233ms/step - loss: 7.6839 - val_loss: 8.0459
Epoch 3/50
474/474 [==============================] - 110s 232ms/step - loss: 7.0648 - val_loss: 6.6014

〜省略〜

Epoch 00034: ReduceLROnPlateau reducing learning rate to 1.0000000656873453e-06.
Epoch 35/50
474/474 [==============================] - 111s 235ms/step - loss: 3.2090 - val_loss: 5.1658
Epoch 36/50
474/474 [==============================] - 111s 234ms/step - loss: 3.1937 - val_loss: 5.1755
Epoch 37/50
474/474 [==============================] - 111s 234ms/step - loss: 3.2478 - val_loss: 5.1742

Epoch 00037: ReduceLROnPlateau reducing learning rate to 1.0000001111620805e-07.
Epoch 38/50
474/474 [==============================] - 111s 234ms/step - loss: 3.2061 - val_loss: 5.1679
Epoch 00038: early stopping
CPU times: user 1h 27min 31s, sys: 16min 3s, total: 1h 43min 34s
Wall time: 1h 10min 15s

学習曲線プロット

plt.figure(figsize=(18,6))

# loss
plt.subplot(1, 2, 1)
plt.plot(hist.history["loss"], label="loss", marker="o")
plt.plot(hist.history["val_loss"], label="val_loss", marker="o")
#plt.yticks(np.arange())
#plt.xticks(np.arange())
plt.ylabel("loss")
plt.xlabel("epoch")
plt.title("")
plt.legend(loc="best")
plt.grid(color='gray', alpha=0.2)

plt.show()

モデル評価

score = model.evaluate(X_test, y_test, verbose=1)
print("evaluate loss: {}".format(score))

4742/4742 [==============================] - 13s 3ms/step
evaluate loss: 5.380563393668855

年齢誤差の平均が、およそ5.4歳ということかな

モデル保存

model_dir = './model/'
if os.path.exists(model_dir) == False : os.mkdir(model_dir)

model.save(model_dir + 'model.hdf5')

# optimizerのない軽量モデルを保存(学習や評価不可だが、予測は可能)
model.save(model_dir + 'model-opt.hdf5', include_optimizer = False)

モデル予測

testデータ30件の画像と正解値&予測値を出力

# testデータ30件の予測値
preds=model.predict(X_test[0:30])

# testデータ30件の画像と正解値&予測値を出力
plt.figure(figsize=(16, 6))
for i in range(30):
    plt.subplot(3, 10, i+1)
    plt.axis("off")
    pred = round(preds[i][0],1)
    true = y_test[i]
    if abs(pred - true) < 5.4:
        plt.title(str(true) + '\n' + str(pred))
    else:
        plt.title(str(true) + '\n' + str(pred), color = "red")
    plt.imshow(X_test[i])
plt.show()

各画像上段の数値が実年齢、下段が予測年齢
実年齢に対する予測年齢の誤差が年齢誤差の平均値よりも大きかった画像は赤字で出力

今後は、
- 誤差の大小の比率
- 大きな誤差は、どの年代で多いか
- 人種・性別による誤差の違い
- 実年齢よりも上で予測しているのか、下なのか
を検証したい。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした