9
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

顔画像から人種を判定

Last updated at Posted at 2018-11-20

概要

Keras学習済みモデルのXceptionをUTKFaceデータセットでFine-tuningさせ、人種分類モデルを構築する
白人、黒人、アジア、インド、その他(ヒスパニック、ラテン、中東など)の5つに分類

UTKFace
https://susanqq.github.io/UTKFace/

  • 20,000以上の顔画像データセット
  • 性別、年齢、人種のラベリング

実行環境
Google Colaboratory(GPU)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = {'png', 'retina'}
import os, zipfile, io, re
from PIL import Image
from sklearn.model_selection import train_test_split
from keras.applications.xception import Xception
from keras.models import Model, load_model
from keras.layers.core import Dense
from keras.layers.pooling import GlobalAveragePooling2D
from keras.optimizers import Adam, RMSprop, SGD
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator

データ取得

画像は100にリサイズ
ZIPファイルからデータセットを取得し、配列に変換
クラスはWhite、Black、Asian、Indian、Othersの5クラス(各画像ファイル名にラベリングされている)
取得したデータセットはtrain、valid、testに分割

classes = ["White", "Black", "Asian", "Indian", "Others"]
num_classes = len(classes)
image_size = 100
%%time
# ZIP読み込み
z = zipfile.ZipFile('../dataset/UTKFace.zip')
# 画像ファイルパスのみ取得
imgfiles = [ x for x in z.namelist() if re.search(r"^UTKFace.*jpg$", x)]

X = []
Y = []
for imgfile in imgfiles:
    # ZIPから画像読み込み
    image = Image.open(io.BytesIO(z.read(imgfile)))
    # RGB変換
    image = image.convert('RGB')
    # リサイズ
    image = image.resize((image_size, image_size))
    # 画像から配列に変換
    data = np.asarray(image)
    file = os.path.basename(imgfile)
    file_split = [i for i in file.split('_')]
    X.append(data)
    Y.append(file_split[2])
z.close()
del z, imgfiles

X = np.array(X)
Y = np.array(Y)
print(X.shape, Y.shape)

(23708, 100, 100, 3) (23708,)
CPU times: user 23.9 s, sys: 387 ms, total: 24.3 s
Wall time: 24.3 s

人種がラベリングされていない3つのデータを削除

# df_Y=pd.DataFrame(Y)
# pd.DataFrame(df_Y[0].value_counts()).sort_index().rename(columns={0 :'num'}).T

index = np.where(
    (Y == "20170109142408075.jpg.chip.jpg") | 
    (Y == "20170109150557335.jpg.chip.jpg") | 
    (Y == "20170116174525125.jpg.chip.jpg")
)
X=np.delete(X, index, axis=0)
Y=np.delete(Y, index, axis=0)
print(X.shape, Y.shape)

(23705, 100, 100, 3) (23705,)

# trainデータとtestデータに分割
X_train, X_test, y_train, y_test = train_test_split(
    X, Y,
    random_state = 0,
    stratify = Y,
    test_size = 0.2
)
del X,Y
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

(18964, 100, 100, 3) (18964,) (4741, 100, 100, 3) (4741,)

# データ型の変換&正規化
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# one-hot変換
y_train = to_categorical(y_train, num_classes = num_classes)
y_test = to_categorical(y_test, num_classes = num_classes)
# trainデータからvalidデータを分割
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train,
    y_train,
    random_state = 0,
    stratify = y_train,
    test_size = 0.2
)
print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape) 

(15171, 100, 100, 3) (15171, 5) (3793, 100, 100, 3) (3793, 5)

モデル構築

Xception読み込み
Keras学習済みモデルXceptionを読み込む
その際、ネットワーク出力層側にある全結合層を除去

base_model = Xception(
    include_top = False,
    weights = "imagenet",
    input_shape = None
)

全結合層の新規構築

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation = 'relu')(x)
predictions = Dense(num_classes, activation = 'softmax')(x)

Data Augmentation

datagen = ImageDataGenerator(
    featurewise_center = False,
    samplewise_center = False,
    featurewise_std_normalization = False,
    samplewise_std_normalization = False,
    zca_whitening = False,
    rotation_range = 0,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True,
    vertical_flip = False
)

Callback

# EarlyStopping
early_stopping = EarlyStopping(
    monitor = 'val_loss',
    patience = 10,
    verbose = 1
)

# ModelCheckpoint
weights_dir = './weights/'
if os.path.exists(weights_dir) == False:os.mkdir(weights_dir)
model_checkpoint = ModelCheckpoint(
    weights_dir + "val_loss{val_loss:.3f}.hdf5",
    monitor = 'val_loss',
    verbose = 1,
    save_best_only = True,
    save_weights_only = True,
    period = 3
)

# reduce learning rate
reduce_lr = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.1,
    patience = 3,
    verbose = 1
)

# log for TensorBoard
logging = TensorBoard(log_dir = "log/")

モデル学習

XceptionをFine-tuning

# ネットワーク定義
model = Model(inputs = base_model.input, outputs = predictions)

#108層までfreeze
for layer in model.layers[:108]:
    layer.trainable = False

    # Batch Normalizationのfreeze解除
    if layer.name.startswith('batch_normalization'):
        layer.trainable = True
    if layer.name.endswith('bn'):
        layer.trainable = True

#109層以降、学習させる
for layer in model.layers[108:]:
    layer.trainable = True
    
# layer.trainableの設定後にcompile
model.compile(
    optimizer = Adam(),
    loss = 'categorical_crossentropy',
    metrics = ["accuracy"]
)
%%time
hist = model.fit_generator(
    datagen.flow(X_train, y_train, batch_size = 32),
    steps_per_epoch = X_train.shape[0] // 32,
    epochs = 50,
    validation_data = (X_valid, y_valid),
    callbacks = [early_stopping, reduce_lr],
    shuffle = True,
    verbose = 1
)
Epoch 1/50
474/474 [==============================] - 121s 256ms/step - loss: 1.3566 - acc: 0.4722 - val_loss: 1.4610 - val_acc: 0.6011
Epoch 2/50
474/474 [==============================] - 112s 237ms/step - loss: 0.8982 - acc: 0.6872 - val_loss: 0.8991 - val_acc: 0.7037
Epoch 3/50
474/474 [==============================] - 112s 236ms/step - loss: 0.7573 - acc: 0.7388 - val_loss: 0.7338 - val_acc: 0.7611

〜省略〜

Epoch 00018: ReduceLROnPlateau reducing learning rate to 1.0000000656873453e-06.
Epoch 19/50
474/474 [==============================] - 111s 235ms/step - loss: 0.2864 - acc: 0.9008 - val_loss: 0.6332 - val_acc: 0.8070
Epoch 20/50
474/474 [==============================] - 111s 235ms/step - loss: 0.2869 - acc: 0.9000 - val_loss: 0.6365 - val_acc: 0.8089
Epoch 21/50
474/474 [==============================] - 111s 235ms/step - loss: 0.2845 - acc: 0.9029 - val_loss: 0.6342 - val_acc: 0.8081

Epoch 00021: ReduceLROnPlateau reducing learning rate to 1.0000001111620805e-07.
Epoch 22/50
474/474 [==============================] - 111s 235ms/step - loss: 0.2831 - acc: 0.9007 - val_loss: 0.6333 - val_acc: 0.8073
Epoch 00022: early stopping
CPU times: user 52min 7s, sys: 9min 12s, total: 1h 1min 19s
Wall time: 41min 7s

学習曲線をプロット

plt.figure(figsize = (18,6))

# accuracy
plt.subplot(1, 2, 1)
plt.plot(hist.history["acc"], label = "acc", marker = "o")
plt.plot(hist.history["val_acc"], label = "val_acc", marker = "o")
#plt.xticks(np.arange())
#plt.yticks(np.arange())
plt.xlabel("epoch")
plt.ylabel("accuracy")
#plt.title("")
plt.legend(loc = "best")
plt.grid(color = 'gray', alpha = 0.2)

# loss
plt.subplot(1, 2, 2)
plt.plot(hist.history["loss"], label = "loss", marker = "o")
plt.plot(hist.history["val_loss"], label = "val_loss", marker = "o")
#plt.xticks(np.arange())
#plt.yticks(np.arange())
plt.xlabel("epoch")
plt.ylabel("loss")
#plt.title("")
plt.legend(loc = "best")
plt.grid(color = 'gray', alpha = 0.2)

plt.show()

モデル評価

score = model.evaluate(X_test, y_test, verbose = 1)
print("evaluate loss: {[0]:.4f}".format(score))
print("evaluate acc: {[1]:.1%}".format(score))

4741/4741 [==============================] - 13s 3ms/step
evaluate loss: 0.6409054181666798
evaluate acc: 0.8089010757475658

モデル保存

model_dir = './model/'
if os.path.exists(model_dir) == False:os.mkdir(model_dir)

model.save(model_dir + 'model.hdf5')

# optimizerのない軽量モデルを保存(学習や評価は不可だが、予測は可能)
model.save(model_dir + 'model-opt.hdf5', include_optimizer = False)

モデル予測

testデータ30件の画像と正解ラベルを出力

# testデータ30件の正解ラベル
true_classes = np.argmax(y_test[0:30], axis = 1)

# testデータ30件の画像と正解ラベルを出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    plt.title(classes[true_classes[i]])
    plt.imshow(X_test[i])
plt.show()

testデータ30件の画像と予測ラベル・予測確率を出力

# testデータ30件の予測ラベル
pred_classes = np.argmax(model.predict(X_test[0:30]), axis = 1)

# testデータ30件の予測確率
pred_probs = np.max(model.predict(X_test[0:30]), axis = 1)
pred_probs = ['{:.4f}'.format(i) for i in pred_probs]

# testデータ30件の画像と予測ラベル・予測確率を出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    if pred_classes[i] == true_classes[i]:
        plt.title(classes[pred_classes[i]]+'\n'+pred_probs[i])
    else:
        plt.title(classes[pred_classes[i]]+'\n'+pred_probs[i], color = "red")
    plt.imshow(X_test[i])
plt.show()

モデル評価でaccuracy 80.9%、loss 0.6409を計測
決して高い精度とは言えず、いくつか誤りが見られる

予測検証

testデータ1000件を抽出して、低い確率で正解した画像高い確率で間違えた画像を目視で確認

testデータ1000件抽出

# testデータ1000件の正解ラベル
true_classes = np.argmax(y_test[0:1000], axis = 1)

# testデータ1000件の予測ラベル
pred_classes = np.argmax(model.predict(X_test[0:1000]), axis = 1)

# testデータ1000件の予測確率
pred_probs = np.max(model.predict(X_test[0:1000]), axis = 1)
pred_probs = np.round(pred_probs, 2)
pred_probs = ['{:.4f}'.format(i) for i in pred_probs]

正解・不正解リスト作成
(画像データ, 予測ラベル, 予測確率)の正解・不正解それぞれのリストを作成

correct=[]
incorrect=[]

for i in range(1000):
    if pred_classes[i] == true_classes[i]:
        correct.append((X_test[i], classes[pred_classes[i]], pred_probs[i]))
    else:
        incorrect.append((X_test[i], classes[pred_classes[i]], pred_probs[i]))
        
print("number of correct:",len(correct))
print("number of incorrect:",len(incorrect))

number of correct : 803
number of incorrect : 197

低確率の正解画像
確率の低い順に正解画像を出力

# 正解画像を確率の低い順に並び替え
correct.sort(key = lambda x:x[2])

# ワースト30件の画像と予測ラベル・予測確率を出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    plt.title(correct[i][1]+'\n'+correct[i][2])
    plt.imshow(correct[i][0])
plt.show()

高確率の不正解画像
確率の高い順に不正解画像を出力

# 不正解画像を確率の高い順に並び替え
incorrect.sort(key = lambda x:x[2], reverse = True)

# ベスト30件の画像と予測ラベル・予測確率を出力
plt.figure(figsize = (16, 6))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.axis("off")
    plt.title(incorrect[i][1]+'\n'+incorrect[i][2], color = 'red')
    plt.imshow(incorrect[i][0])
plt.show()

今回は目視での検証だが、今後は正誤の人種比や年齢・性別による正誤の違いを検証したい。

9
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?