LoginSignup
68
62

More than 5 years have passed since last update.

転移学習 / Deep Features [TensorFlowでDeep Learning 12]

Last updated at Posted at 2016-06-17

(目次はこちら)

はじめに

これまでの記事 で見てきた、Deep Neural Networks / Convolutional Neural Networks では、精度を高めるために中間層を増やすと、それに伴い、パラメータ数も多くなり、それを学習するためにさらに多くのデータが必要とされる。画像認識のコンペティションなどで上位に来るようなものは、複数のGPUを利用したとしても、当然のように学習に数週間以上かかっている。

正直、こんなに時間がかかるものなんて、頻繁に学習したくない。カテゴリ識別において、たとえば、カテゴリを1つ追加したくなった時、また数週間待つんだろうか?? 嫌だ。こんなときに用いられるのが、転移学習 (transfer learning) / 深層特徴 (deep features)である。

転移学習 (transfer learning) / 深層特徴 (deep features)

これは、何かというと、学習したモデルにデータを入力して、出力層での結果を利用するのではなくて、中間層の出力を特徴量として利用するもの。

そうやって得られた特徴量を使ってSVMなどの識別器を学習することを、転移学習とよび、Deep Neural Networksから得られた特徴量は、deep featuresと呼ばれている。全結合層の出力を利用することが多いが、もっと前段の出力を使っているケースもある。

たとえば、「画像を1000カテゴリに分類する」という問題設定で学習されたモデルがあったとして、カテゴリが1つ増えましたって時に、最初から学習し直すのは時間がかかるので面倒、ってときに非常に有効。

さらに、カテゴリが増えた時だけでなく、別の用途に使っても、わりとよい結果を返してくれたりするのでまた面白い。

やってみる

転移学習ではなく、deep featuresを使って、類似画像を検索してみる。

データ

google image searchで、1408枚の電車の画像を集めた。
いちおう、"非営利目的での再使用が許可された画像"ってオプションで。
以下、img_0000.jpg 〜 img_1407.jpg というファイル名で保存している前提。

方法

tensorflowで、inception-v3モデルってのが公開されているのでそれを利用する。

学習済みモデル (inception v3)

http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

から、classify_image_graph_def.pbを抽出。

このモデルは、 ILSVRC 2012という1000カテゴリ含まれるデータセットを使っていて、5つカテゴリ選択した時の正解率は、96.54%らしい。ちなみに、Andrej Karpathyって人は、人間が同じデータで頑張ると、94.9%だったそうな。

コード

deep_features.py

deep_features.py
from helper import *

IMG_DIR = '/path/to/img'
MODEL_PATH = '/path/to/classify_image_graph_def.pb'
IMG_NUM = 1408
QUERY_IMG = 22
CANDIDATES = 5

with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
    features = []
    for i in range(IMG_NUM):
        image_data = tf.gfile.FastGFile('%s/img_%04d.jpg' % (IMG_DIR, i), 'rb').read()
        pool3_features = sess.run(pool3,{'DecodeJpeg/contents:0': image_data})
        features.append(np.squeeze(pool3_features))

query_feat = features[QUERY_IMG]
sims = [(k, round(1 - spatial.distance.cosine(query_feat, v), 3)) for k,v in enumerate(features)]
print(sorted(sims, key=operator.itemgetter(1), reverse=True)[:CANDIDATES + 1])

コードの説明

モデルの読み込み

with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

出力箇所の選択

pool_3という出力層の一つ手前の出力を使う。

    pool3 = sess.graph.get_tensor_by_name('pool_3:0')

deep features

全画像に対して特徴抽出。各画像において2048次元の特徴ベクトルが得られる。

    for i in range(IMG_NUM):
        image_data = tf.gfile.FastGFile('%s/img_%04d.jpg' % (IMG_DIR, i), 'rb').read()
        pool3_features = sess.run(pool3,{'DecodeJpeg/contents:0': image_data})
        features.append(np.squeeze(pool3_features))

類似画像検索

特徴ベクトル同士のコサイン類似度を、クエリ画像とその他の画像で算出し、大きい方から並べる。
ここでは、クエリ画像を除かずに全画像に対して類似度計算しているので、クエリ画像とクエリ画像を比較した時の結果は1.0(厳密には、1.0に限りなく近い値)になる。

query_feat = features[QUERY_IMG]
sims = [(k, round(1 - spatial.distance.cosine(query_feat, v), 3)) for k,v in enumerate(features)]
print(sorted(sims, key=operator.itemgetter(1), reverse=True)[:CANDIDATES + 1])

実行

使っているtensorflowのバージョンによっては、Warningが大量に出るけど、無視。

W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().

こんな感じの出力が得られるはず。img_0022 が類似度1.0で、img_0120 が類似度0.911で、以下同様。

[(22, 1.0), (120, 0.911), (1363, 0.901), (516, 0.882), (974, 0.870), (809, 0.867)]

上記コードではクエリ画像を検索対象から抜いていないので、類似度1.0のものが出ている。
以下、いろいろな結果を見やすく加工したものを。

結果

1枚ずつ画像をアップロードするのが面倒だったので、deep_features_retrieval.pdfに。

↓こんな感じ。最左列がクエリ画像で、それ以外は、似ていると認識された画像。小数の数値は、類似度(0-1)で、大きいほど類似性が高い。
deep_features_retrieval

何を以って似ていると認識されたのかわからない結果もあるが、まぁそれはしょうがない。

↓いい感じ。
deep_features_retrieval

↓これは、あまりうまくいってないんじゃないかと思ったが、実は、路面電車が集まっていて面白い。
deep_features_retrieval

1000カテゴリを識別することを目的に学習されたモデルなのに、中間層の出力が流用できる(汎用性がある)ことが非常に興味深い。

おわりに

今回の記事では、転移学習の紹介とともに、学習済みのCNNのモデルを流用したdeep featuresを使って、類似画像検索を試してみた。中間層の出力に汎用性があるなんて興味深い。

追記:
Deep FeaturesとFacebook ResearchのFaissを使った類似画像検索。
TensorFlowでDeep Neural Networks (15) Deep Features と Faiss

68
62
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
68
62