10
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

顔画像から性別を判定

Last updated at Posted at 2018-11-16

概要

Keras学習済みモデルのXceptionをUTKFaceデータセットでFine-tuningさせ、性別分類モデルを構築する

UTKFace
https://susanqq.github.io/UTKFace/

  • 20,000以上の顔画像データセット
  • 性別、年齢、人種のラベリング

実行環境
Google Colaboratory(GPU)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = {'png', 'retina'}
import os, zipfile, io, re
from PIL import Image
from sklearn.model_selection import train_test_split
from keras.applications.xception import Xception
from keras.models import Model, load_model
from keras.layers.core import Dense
from keras.layers.pooling import GlobalAveragePooling2D
from keras.optimizers import Adam, RMSprop, SGD
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator

データ取得

画像は100にリサイズ
ZIPファイルからデータセットを取得し、配列に変換
クラスはmale、femaleの2クラス(各画像ファイル名にラベリングされている)
取得したデータセットはtrain、valid、testに分割

image_size = 100
classes = ["male", "female"]
num_classes = len(classes)
%%time
# ZIP読み込み
z = zipfile.ZipFile('../dataset/UTKFace.zip')
# 画像ファイルパスのみ取得
imgfiles = [ x for x in z.namelist() if re.search(r"^UTKFace.*jpg$", x)]

X = []
Y = []
for imgfile in imgfiles:
    # ZIPから画像読み込み
    image = Image.open(io.BytesIO(z.read(imgfile)))
    # RGB変換
    image = image.convert('RGB')
    # リサイズ
    image = image.resize((image_size, image_size))
    # 画像から配列に変換
    data = np.asarray(image)
    file = os.path.basename(imgfile)
    file_split = [i for i in file.split('_')]
    X.append(data)
    Y.append(file_split[1])
z.close()
del z, imgfiles

X = np.array(X)
Y = np.array(Y)
print(X.shape, Y.shape)

(23708, 100, 100, 3) (23708,)
CPU times: user 22.8 s, sys: 1.28 s, total: 24.1 s
Wall time: 24.1 s

# trainデータとtestデータに分割
X_train, X_test, y_train, y_test = train_test_split(
    X, Y,
    random_state = 0,
    stratify = Y,
    test_size = 0.2
)
del X,Y
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

(18966, 100, 100, 3) (18966,) (4742, 100, 100, 3) (4742,)

# データ型の変換&正規化
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# one-hot変換
y_train = to_categorical(y_train, num_classes = num_classes)
y_test = to_categorical(y_test, num_classes = num_classes)
# trainデータからvalidデータを分割
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train,
    y_train,
    random_state = 0,
    stratify = y_train,
    test_size = 0.2
)
print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape) 

(15172, 100, 100, 3) (15172, 2) (3794, 100, 100, 3) (3794, 2)

モデル構築

Xception読み込み
Keras学習済みモデルXceptionを読み込む
その際、ネットワーク出力層側にある全結合層を除去

base_model = Xception(
    include_top = False,
    weights = "imagenet",
    input_shape = None
)

全結合層の新規構築

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation = 'relu')(x)
predictions = Dense(num_classes, activation = 'softmax')(x)

Data Augmentation

datagen = ImageDataGenerator(
    featurewise_center = False,
    samplewise_center = False,
    featurewise_std_normalization = False,
    samplewise_std_normalization = False,
    zca_whitening = False,
    rotation_range = 0,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True,
    vertical_flip = False
)

Callback

# EarlyStopping
early_stopping = EarlyStopping(
    monitor = 'val_loss',
    patience = 10,
    verbose = 1
)

# ModelCheckpoint
weights_dir = './weights/'
if os.path.exists(weights_dir) == False:os.mkdir(weights_dir)
model_checkpoint = ModelCheckpoint(
    weights_dir + "val_loss{val_loss:.3f}.hdf5",
    monitor = 'val_loss',
    verbose = 1,
    save_best_only = True,
    save_weights_only = True,
    period = 3
)

# reduce learning rate
reduce_lr = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.1,
    patience = 3,
    verbose = 1
)

# log for TensorBoard
logging = TensorBoard(log_dir = "log/")

モデル学習

XceptionをFine-tuning

# ネットワーク定義
model = Model(inputs = base_model.input, outputs = predictions)

#108層までfreeze
for layer in model.layers[:108]:
    layer.trainable = False

    # Batch Normalizationのfreeze解除
    if layer.name.startswith('batch_normalization'):
        layer.trainable = True
    if layer.name.endswith('bn'):
        layer.trainable = True

#109層以降、学習させる
for layer in model.layers[108:]:
    layer.trainable = True
    
# layer.trainableの設定後にcompile
model.compile(
    optimizer = Adam(),
    loss = 'categorical_crossentropy',
    metrics = ["accuracy"]
)
%%time
hist = model.fit_generator(
    datagen.flow(X_train, y_train, batch_size = 32),
    steps_per_epoch = X_train.shape[0] // 32,
    epochs = 50,
    validation_data = (X_valid, y_valid),
    callbacks = [early_stopping, reduce_lr],
    shuffle = True,
    verbose = 1
)
Epoch 1/50
474/474 [==============================] - 118s 250ms/step - loss: 0.4055 - acc: 0.8159 - val_loss: 0.3959 - val_acc: 0.8743
Epoch 2/50
474/474 [==============================] - 111s 233ms/step - loss: 0.2998 - acc: 0.8685 - val_loss: 0.3020 - val_acc: 0.8888
Epoch 3/50
474/474 [==============================] - 110s 232ms/step - loss: 0.3072 - acc: 0.8735 - val_loss: 0.3128 - val_acc: 0.8980

〜省略〜

Epoch 00013: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
Epoch 14/50
474/474 [==============================] - 110s 232ms/step - loss: 0.1084 - acc: 0.9596 - val_loss: 0.2715 - val_acc: 0.9135
Epoch 15/50
474/474 [==============================] - 110s 232ms/step - loss: 0.1069 - acc: 0.9595 - val_loss: 0.2687 - val_acc: 0.9143
Epoch 16/50
474/474 [==============================] - 110s 232ms/step - loss: 0.1065 - acc: 0.9591 - val_loss: 0.2706 - val_acc: 0.9151

Epoch 00016: ReduceLROnPlateau reducing learning rate to 1.0000000656873453e-06.
Epoch 17/50
474/474 [==============================] - 110s 232ms/step - loss: 0.1045 - acc: 0.9587 - val_loss: 0.2695 - val_acc: 0.9151
Epoch 00017: early stopping
CPU times: user 39min 56s, sys: 7min 1s, total: 46min 57s
Wall time: 31min 26s

学習曲線をプロット

plt.figure(figsize = (18,6))

# accuracy
plt.subplot(1, 2, 1)
plt.plot(hist.history["acc"], label = "acc", marker = "o")
plt.plot(hist.history["val_acc"], label = "val_acc", marker = "o")
#plt.xticks(np.arange())
#plt.yticks(np.arange())
plt.xlabel("epoch")
plt.ylabel("accuracy")
#plt.title("")
plt.legend(loc = "best")
plt.grid(color = 'gray', alpha = 0.2)

# loss
plt.subplot(1, 2, 2)
plt.plot(hist.history["loss"], label = "loss", marker = "o")
plt.plot(hist.history["val_loss"], label = "val_loss", marker = "o")
#plt.xticks(np.arange())
#plt.yticks(np.arange())
plt.xlabel("epoch")
plt.ylabel("loss")
#plt.title("")
plt.legend(loc = "best")
plt.grid(color = 'gray', alpha = 0.2)

plt.show()

モデル評価

score = model.evaluate(X_test, y_test, verbose = 1)
print("evaluate loss: {[0]:.4f}".format(score))
print("evaluate acc: {[1]:.1%}".format(score))

4742/4742 [==============================] - 12s 3ms/step
evaluate loss: 0.2665
evaluate acc: 91.9%

モデル保存

model_dir = './model/'
if os.path.exists(model_dir) == False:os.mkdir(model_dir)

model.save(model_dir + 'model.hdf5')

# optimizerのない軽量モデルを保存(学習や評価は不可だが、予測は可能)
model.save(model_dir + 'model-opt.hdf5', include_optimizer = False)

モデル予測

testデータ30件の画像と正解ラベルを出力

# testデータ30件の正解ラベル
true_classes = np.argmax(y_test[0:30], axis = 1)

# testデータ30件の画像と正解ラベルを出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    plt.title(classes[true_classes[i]])
    plt.imshow(X_test[i])
plt.show()

testデータ30件の画像と予測ラベル・予測確率を出力

# testデータ30件の予測ラベル
pred_classes = np.argmax(model.predict(X_test[0:30]), axis = 1)

# testデータ30件の予測確率
pred_probs = np.max(model.predict(X_test[0:30]), axis = 1)
pred_probs = ['{:.4f}'.format(i) for i in pred_probs]

# testデータ30件の画像と予測ラベル・予測確率を出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    if pred_classes[i] == true_classes[i]:
        plt.title(classes[pred_classes[i]]+'\n'+pred_probs[i])
    else:
        plt.title(classes[pred_classes[i]]+'\n'+pred_probs[i], color = "red")
    plt.imshow(X_test[i])
plt.show()

モデル評価でaccuracy 91.9%、loss 0.2665を計測
決して高い精度とは言えず、いくつか誤りが見られる
ただ目視でも見誤ってしまいそう

予測検証

testデータ1000件を抽出して、低い確率で正解した画像高い確率で間違えた画像を目視で確認

testデータ1000件抽出

# testデータ1000件の正解ラベル
true_classes = np.argmax(y_test[0:1000], axis = 1)

# testデータ1000件の予測ラベル
pred_classes = np.argmax(model.predict(X_test[0:1000]), axis = 1)

# testデータ1000件の予測確率
pred_probs = np.max(model.predict(X_test[0:1000]), axis = 1)
pred_probs = np.round(pred_probs, 2)
pred_probs = ['{:.4f}'.format(i) for i in pred_probs]

正解・不正解リスト作成
(画像データ, 予測ラベル, 予測確率)の正解・不正解それぞれのリストを作成

correct=[]
incorrect=[]

for i in range(1000):
    if pred_classes[i] == true_classes[i]:
        correct.append((X_test[i], classes[pred_classes[i]], pred_probs[i]))
    else:
        incorrect.append((X_test[i], classes[pred_classes[i]], pred_probs[i]))
        
print("number of correct:",len(correct))
print("number of incorrect:",len(incorrect))

number of correct: 922
number of incorrect: 78

低確率の正解画像
確率の低い順に正解画像を出力

# 正解画像を確率の低い順に並び替え
correct.sort(key = lambda x:x[2])

# ワースト30件の画像と予測ラベル・予測確率を出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    plt.title(correct[i][1]+'\n'+correct[i][2])
    plt.imshow(correct[i][0])
plt.show()

赤ちゃんが多い気がする
赤ちゃんの画像の性別は、明確に判断することが難しい?

高確率の不正解画像
確率の高い順に不正解画像を出力

# 不正解画像を確率の高い順に並び替え
incorrect.sort(key = lambda x:x[2], reverse = True)

# ベスト30件の画像と予測ラベル・予測確率を出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    plt.title(incorrect[i][1]+'\n'+incorrect[i][2], color = 'red')
    plt.imshow(incorrect[i][0])
plt.show()

今回は目視での検証だが、今後は正誤の男女比や年齢・人種による正誤の違いを検証したい。

10
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?