(目次はこちら)
はじめに
これまでの記事 で見てきた、Deep Neural Networks / Convolutional Neural Networks では、精度を高めるために中間層を増やすと、それに伴い、パラメータ数も多くなり、それを学習するためにさらに多くのデータが必要とされる。画像認識のコンペティションなどで上位に来るようなものは、複数のGPUを利用したとしても、当然のように学習に数週間以上かかっている。
正直、こんなに時間がかかるものなんて、頻繁に学習したくない。カテゴリ識別において、たとえば、カテゴリを1つ追加したくなった時、また数週間待つんだろうか?? 嫌だ。こんなときに用いられるのが、転移学習 (transfer learning) / 深層特徴 (deep features)である。
転移学習 (transfer learning) / 深層特徴 (deep features)
これは、何かというと、学習したモデルにデータを入力して、出力層での結果を利用するのではなくて、中間層の出力を特徴量として利用するもの。
そうやって得られた特徴量を使ってSVMなどの識別器を学習することを、転移学習とよび、Deep Neural Networksから得られた特徴量は、deep featuresと呼ばれている。全結合層の出力を利用することが多いが、もっと前段の出力を使っているケースもある。
たとえば、「画像を1000カテゴリに分類する」という問題設定で学習されたモデルがあったとして、カテゴリが1つ増えましたって時に、最初から学習し直すのは時間がかかるので面倒、ってときに非常に有効。
さらに、カテゴリが増えた時だけでなく、別の用途に使っても、わりとよい結果を返してくれたりするのでまた面白い。
やってみる
転移学習ではなく、deep featuresを使って、類似画像を検索してみる。
データ
google image searchで、1408枚の電車の画像を集めた。
いちおう、"非営利目的での再使用が許可された画像"ってオプションで。
以下、img_0000.jpg 〜 img_1407.jpg というファイル名で保存している前提。
方法
tensorflowで、inception-v3モデルってのが公開されているのでそれを利用する。
学習済みモデル (inception v3)
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
から、classify_image_graph_def.pb
を抽出。
このモデルは、 ILSVRC 2012
という1000カテゴリ含まれるデータセットを使っていて、5つカテゴリ選択した時の正解率は、96.54%らしい。ちなみに、Andrej Karpathyって人は、人間が同じデータで頑張ると、94.9%だったそうな。
コード
from helper import *
IMG_DIR = '/path/to/img'
MODEL_PATH = '/path/to/classify_image_graph_def.pb'
IMG_NUM = 1408
QUERY_IMG = 22
CANDIDATES = 5
with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
features = []
for i in range(IMG_NUM):
image_data = tf.gfile.FastGFile('%s/img_%04d.jpg' % (IMG_DIR, i), 'rb').read()
pool3_features = sess.run(pool3,{'DecodeJpeg/contents:0': image_data})
features.append(np.squeeze(pool3_features))
query_feat = features[QUERY_IMG]
sims = [(k, round(1 - spatial.distance.cosine(query_feat, v), 3)) for k,v in enumerate(features)]
print(sorted(sims, key=operator.itemgetter(1), reverse=True)[:CANDIDATES + 1])
コードの説明
モデルの読み込み
with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
出力箇所の選択
pool_3という出力層の一つ手前の出力を使う。
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
deep features
全画像に対して特徴抽出。各画像において2048次元の特徴ベクトルが得られる。
for i in range(IMG_NUM):
image_data = tf.gfile.FastGFile('%s/img_%04d.jpg' % (IMG_DIR, i), 'rb').read()
pool3_features = sess.run(pool3,{'DecodeJpeg/contents:0': image_data})
features.append(np.squeeze(pool3_features))
類似画像検索
特徴ベクトル同士のコサイン類似度を、クエリ画像とその他の画像で算出し、大きい方から並べる。
ここでは、クエリ画像を除かずに全画像に対して類似度計算しているので、クエリ画像とクエリ画像を比較した時の結果は1.0(厳密には、1.0に限りなく近い値)になる。
query_feat = features[QUERY_IMG]
sims = [(k, round(1 - spatial.distance.cosine(query_feat, v), 3)) for k,v in enumerate(features)]
print(sorted(sims, key=operator.itemgetter(1), reverse=True)[:CANDIDATES + 1])
実行
使っているtensorflowのバージョンによっては、Warning
が大量に出るけど、無視。
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
こんな感じの出力が得られるはず。img_0022 が類似度1.0で、img_0120 が類似度0.911で、以下同様。
[(22, 1.0), (120, 0.911), (1363, 0.901), (516, 0.882), (974, 0.870), (809, 0.867)]
上記コードではクエリ画像を検索対象から抜いていないので、類似度1.0のものが出ている。
以下、いろいろな結果を見やすく加工したものを。
結果
1枚ずつ画像をアップロードするのが面倒だったので、deep_features_retrieval.pdfに。
↓こんな感じ。最左列がクエリ画像で、それ以外は、似ていると認識された画像。小数の数値は、類似度(0-1)で、大きいほど類似性が高い。
何を以って似ていると認識されたのかわからない結果もあるが、まぁそれはしょうがない。
↓これは、あまりうまくいってないんじゃないかと思ったが、実は、路面電車が集まっていて面白い。
1000カテゴリを識別することを目的に学習されたモデルなのに、中間層の出力が流用できる(汎用性がある)ことが非常に興味深い。
おわりに
今回の記事では、転移学習の紹介とともに、学習済みのCNNのモデルを流用したdeep featuresを使って、類似画像検索を試してみた。中間層の出力に汎用性があるなんて興味深い。
追記:
Deep FeaturesとFacebook ResearchのFaissを使った類似画像検索。
TensorFlowでDeep Neural Networks (15) Deep Features と Faiss