(目次はこちら)

はじめに

これまでの記事 で見てきた、Deep Neural Networks / Convolutional Neural Networks では、精度を高めるために中間層を増やすと、それに伴い、パラメータ数も多くなり、それを学習するためにさらに多くのデータが必要とされる。画像認識のコンペティションなどで上位に来るようなものは、複数のGPUを利用したとしても、当然のように学習に数週間以上かかっている。

正直、こんなに時間がかかるものなんて、頻繁に学習したくない。カテゴリ識別において、たとえば、カテゴリを1つ追加したくなった時、また数週間待つんだろうか?? 嫌だ。こんなときに用いられるのが、転移学習 (transfer learning) / 深層特徴 (deep features)である。

転移学習 (transfer learning) / 深層特徴 (deep features)

これは、何かというと、学習したモデルにデータを入力して、出力層での結果を利用するのではなくて、中間層の出力を特徴量として利用するもの。

そうやって得られた特徴量を使ってSVMなどの識別器を学習することを、転移学習とよび、Deep Neural Networksから得られた特徴量は、deep featuresと呼ばれている。全結合層の出力を利用することが多いが、もっと前段の出力を使っているケースもある。

たとえば、「画像を1000カテゴリに分類する」という問題設定で学習されたモデルがあったとして、カテゴリが1つ増えましたって時に、最初から学習し直すのは時間がかかるので面倒、ってときに非常に有効。

さらに、カテゴリが増えた時だけでなく、別の用途に使っても、わりとよい結果を返してくれたりするのでまた面白い。

やってみる

転移学習ではなく、deep featuresを使って、類似画像を検索してみる。

データ

google image searchで、1408枚の電車の画像を集めた。
いちおう、"非営利目的での再使用が許可された画像"ってオプションで。
以下、img_0000.jpg 〜 img_1407.jpg というファイル名で保存している前提。

方法

tensorflowで、inception-v3モデルってのが公開されているのでそれを利用する。

学習済みモデル (inception v3)

http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

から、classify_image_graph_def.pbを抽出。

このモデルは、 ILSVRC 2012という1000カテゴリ含まれるデータセットを使っていて、5つカテゴリ選択した時の正解率は、96.54%らしい。ちなみに、Andrej Karpathyって人は、人間が同じデータで頑張ると、94.9%だったそうな。

コード

deep_features.py

deep_features.py
from helper import *

IMG_DIR = '/path/to/img'
MODEL_PATH = '/path/to/classify_image_graph_def.pb'
IMG_NUM = 1408
QUERY_IMG = 22
CANDIDATES = 5

with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
    features = []
    for i in range(IMG_NUM):
        image_data = tf.gfile.FastGFile('%s/img_%04d.jpg' % (IMG_DIR, i), 'rb').read()
        pool3_features = sess.run(pool3,{'DecodeJpeg/contents:0': image_data})
        features.append(np.squeeze(pool3_features))

query_feat = features[QUERY_IMG]
sims = [(k, round(1 - spatial.distance.cosine(query_feat, v), 3)) for k,v in enumerate(features)]
print(sorted(sims, key=operator.itemgetter(1), reverse=True)[:CANDIDATES + 1])

コードの説明

モデルの読み込み

with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

出力箇所の選択

pool_3という出力層の一つ手前の出力を使う。

    pool3 = sess.graph.get_tensor_by_name('pool_3:0')

deep features

全画像に対して特徴抽出。各画像において2048次元の特徴ベクトルが得られる。

    for i in range(IMG_NUM):
        image_data = tf.gfile.FastGFile('%s/img_%04d.jpg' % (IMG_DIR, i), 'rb').read()
        pool3_features = sess.run(pool3,{'DecodeJpeg/contents:0': image_data})
        features.append(np.squeeze(pool3_features))

類似画像検索

特徴ベクトル同士のコサイン類似度を、クエリ画像とその他の画像で算出し、大きい方から並べる。
ここでは、クエリ画像を除かずに全画像に対して類似度計算しているので、クエリ画像とクエリ画像を比較した時の結果は1.0(厳密には、1.0に限りなく近い値)になる。

query_feat = features[QUERY_IMG]
sims = [(k, round(1 - spatial.distance.cosine(query_feat, v), 3)) for k,v in enumerate(features)]
print(sorted(sims, key=operator.itemgetter(1), reverse=True)[:CANDIDATES + 1])

実行

使っているtensorflowのバージョンによっては、Warningが大量に出るけど、無視。

W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().

こんな感じの出力が得られるはず。img_0022 が類似度1.0で、img_0120 が類似度0.911で、以下同様。

[(22, 1.0), (120, 0.911), (1363, 0.901), (516, 0.882), (974, 0.870), (809, 0.867)]

上記コードではクエリ画像を検索対象から抜いていないので、類似度1.0のものが出ている。
以下、いろいろな結果を見やすく加工したものを。

結果

1枚ずつ画像をアップロードするのが面倒だったので、deep_features_retrieval.pdfに。

↓こんな感じ。最左列がクエリ画像で、それ以外は、似ていると認識された画像。小数の数値は、類似度(0-1)で、大きいほど類似性が高い。
deep_features_retrieval

何を以って似ていると認識されたのかわからない結果もあるが、まぁそれはしょうがない。

↓いい感じ。
deep_features_retrieval

↓これは、あまりうまくいってないんじゃないかと思ったが、実は、路面電車が集まっていて面白い。
deep_features_retrieval

1000カテゴリを識別することを目的に学習されたモデルなのに、中間層の出力が流用できる(汎用性がある)ことが非常に興味深い。

おわりに

今回の記事では、転移学習の紹介とともに、学習済みのCNNのモデルを流用したdeep featuresを使って、類似画像検索を試してみた。中間層の出力に汎用性があるなんて興味深い。

追記:
Deep FeaturesとFacebook ResearchのFaissを使った類似画像検索。
TensorFlowでDeep Neural Networks (15) Deep Features と Faiss

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.