#0. 概要
実は結構前から地味に研究されているVQA。昔はキャプション生成等結構隆盛であったが、最近は余り目にも耳にもしない。しかし、DeepLearningの信頼性を確認するという意味ではよいメカニズムなのではないかと考えている。
どうのようなタスクかというと「DeepLearningさんに画像を与えて質問をすると、その画像を基に機械が答えてくれる」というもの。
以下の画像が分かりやすいであろう。
簡単にいうとCNNのタスクと文章生成タスクを両方行うようなものである。
事前に1枚の画像に対して、いくつかの質問セットを用意して学習させるというもので、難しい仕組みではない。
ImpelementationはTensorflowがメインっぽい。といっても3年近くも前の話である。
このため、Kerasへの移植版を使うのがよいだろう。ということで、インストールから実装までを進めて行きたい。
#1. Installation
https://github.com/anantzoid/VQA-Keras-Visual-Question-Answering
このGitが一番実装が分かりやすい。
しかしPython2.7かつKerasも古いバージョンでの実装なのでリファクタリングが結構必要。
まずGitからソースコードを持ってくる
$ git https://github.com/anantzoid/VQA-Keras-Visual-Question-Answering
次に学習に必要なデータを持ってくる。どうやらBlob等は作られている模様。
$ cd data
$ cat README.md
- [GloVe vectors](http://nlp.stanford.edu/data/glove.6B.zip)
- [Pre-trained data](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/vqaRelease/train_only/data_train_val.zip)
- [Annotations for validation set](http://visualqa.org/data/mscoco/vqa/Annotations_Val_mscoco.zip)
- [Model Weights](https://drive.google.com/file/d/0B3b69xdtpDT8U2dwajNKOEhUWUU/view?usp=sharing)
$ wget http://nlp.stanford.edu/data/glove.6B.zip
$ wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Train_mscoco.zip
$ wget https://filebox.ece.vt.edu/~jiasenlu/codeRelease/vqaRelease/train_only/data_train_val.zip
$ unzip data_train_val.zip
$ unzip glove.6B.zip
$ mkdir Questions_Train_mscoco
$ mv Questions_Train_mscoco.zip Questions_Train_mscoco/
$ cd Questions_Train_mscoco
$ unzip Questions_Train_mscoco.zip
取り合えず、MSCOCOのURLが切れているのでAWSの方に書き換える。
また、ファイルパスの設定がやや異なるので置き換える。
後、Modelの書き方が古すぎて、書き換えるので既存のWeightは使えないためダウンロードしない。
これでデータセットが揃った。次に$cd ..
ルートフォルダに戻る。
Python2.7で書かれているので、print xxx
-> print (xxx)
のように全部書き換える。
次にmodel.py
が古いので以下のように書き換える。取り合えず、論文通りに実装してみた。
import keras
import keras.backend as K
from keras.models import Sequential
from keras.models import Model
from keras.layers import Input, Dense, Activation, Dropout, LSTM, Flatten, Embedding, Multiply, Lambda
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
import h5py
def vqa_model(embedding_matrix, num_words, embedding_dim, seq_length, dropout_rate, num_classes):
#####################################################
# Word2Vec model
#####################################################
print ("Creating text model...")
input_txt = Input(shape=(seq_length,), name='text_input')
x = Embedding(num_words, embedding_dim, weights=[embedding_matrix], trainable=False)(input_txt)
x = LSTM(units=512, return_sequences=True, input_shape=(seq_length, embedding_dim))(x)
x = Dropout(dropout_rate)(x)
x = LSTM(units=512, return_sequences=False)(x)
x = Dropout(dropout_rate)(x)
output_txt = Dense(1024, activation='tanh')(x)
txt_model = Model(input_txt, output_txt)
txt_model.summary()
#####################################################
# Image model
#####################################################
print ("Creating image model...")
input_img = Input(shape=(4096,), name='image_input')
output_img = Dense(1024, activation='tanh')(input_img)
img_model = Model(input_img, output_img)
img_model.summary()
#####################################################
# VQA model
#####################################################
print ("Creating vqa model...")
input_intermediate_img = Input(shape=(1024,), name='intermediate_image_input')
input_intermediate_txt = Input(shape=(1024,), name='intermediate_text_input')
x = Multiply()([input_intermediate_img, input_intermediate_txt])
x = Dropout(dropout_rate)(x)
x = Dense(1024, activation='tanh')(x)
x = Dropout(dropout_rate)(x)
vqa = Dense(num_classes, activation='softmax')(x)
vqa_model = Model([input_intermediate_img, input_intermediate_txt], vqa)
vqa_model.summary()
# internal connection
output_vqa = vqa_model([img_model(input_img), txt_model(input_txt)])
#####################################################
# Pack model
#####################################################
print ("Packing multi model...")
multiModel = Model([input_img, input_txt], output_vqa, name='multiModel')
multiModel.summary()
# optimizer
multiModel.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
return multiModel
Embeddingのそもそもの意味は単語や文をベクトル空間に写像することである。
今回のVQAで用いられているEmbeddingも同様である。
例えば、Word2Vecで単語のベクトル空間を作成したとする。しかし、これは文のベクトルではない。文は複数の単語で成り立っている。
ここで、仮に単語を5つ使った文章を考えたとする。この時、Word2Vecで分類した単語数は全部で50、単語ベクトル(分散ベクトル)の数は10とする。その場合は、Embedding(50,10, input_length=5)
となる。
Embeddingレイヤーは殆どWord2Vecと同じ動きをするが、後続に続くネットワークにより最適化されたベクトル空間を提供できる。
もちろんWord2Vecで得られたベクトル空間をEmmedingレイヤーのWeightとして使うこともできる。
2. Training
以下でトレーニングを実行できる。
$ python train.py
実行中はこんな感じ。そこそこなGPUを使っていれば、10分くらいで終わる。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
text_input (InputLayer) (None, 26) 0
_________________________________________________________________
embedding_1 (Embedding) (None, 26, 300) 3780600
_________________________________________________________________
lstm_1 (LSTM) (None, 26, 512) 1665024
_________________________________________________________________
dropout_1 (Dropout) (None, 26, 512) 0
_________________________________________________________________
lstm_2 (LSTM) (None, 512) 2099200
_________________________________________________________________
dropout_2 (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 1024) 525312
=================================================================
Total params: 8,070,136
Trainable params: 4,289,536
Non-trainable params: 3,780,600
_________________________________________________________________
Creating image model...
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
image_input (InputLayer) (None, 4096) 0
_________________________________________________________________
dense_2 (Dense) (None, 1024) 4195328
=================================================================
Total params: 4,195,328
Trainable params: 4,195,328
Non-trainable params: 0
_________________________________________________________________
Creating vqa model...
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
intermediate_image_input (Input (None, 1024) 0
__________________________________________________________________________________________________
intermediate_text_input (InputL (None, 1024) 0
__________________________________________________________________________________________________
multiply_1 (Multiply) (None, 1024) 0 intermediate_image_input[0][0]
intermediate_text_input[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 1024) 0 multiply_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 1024) 1049600 dropout_3[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, 1024) 0 dense_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 1000) 1025000 dropout_4[0][0]
==================================================================================================
Total params: 2,074,600
Trainable params: 2,074,600
Non-trainable params: 0
__________________________________________________________________________________________________
Creating multi model...
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
image_input (InputLayer) (None, 4096) 0
__________________________________________________________________________________________________
text_input (InputLayer) (None, 26) 0
__________________________________________________________________________________________________
model_2 (Model) (None, 1024) 4195328 image_input[0][0]
__________________________________________________________________________________________________
model_1 (Model) (None, 1024) 8070136 text_input[0][0]
__________________________________________________________________________________________________
model_3 (Model) (None, 1000) 2074600 model_2[1][0]
model_1[1][0]
==================================================================================================
Total params: 14,340,064
Trainable params: 10,559,464
Non-trainable params: 3,780,600
__________________________________________________________________________________________________
Epoch 1/10
215359/215359 [==============================] - 95s 443us/step - loss: 3.3621 - acc: 0.3163
Epoch 00001: saving model to data/ckpts/model_weights.h5
Epoch 2/10
215359/215359 [==============================] - 92s 426us/step - loss: 2.3662 - acc: 0.4073
Epoch 00002: saving model to data/ckpts/model_weights.h5
Epoch 3/10
215359/215359 [==============================] - 92s 429us/step - loss: 2.1415 - acc: 0.4464
Epoch 00003: saving model to data/ckpts/model_weights.h5
#3. Validation
Validation用のデータをダウンロードして設置する。
$ cd data
$ mkdir validation_data
$ cd validation_data
$ wget https://vision.ece.vt.edu/vqa/release_data/mscoco/vqa/Annotations_Val_mscoco.zip
$ unzip Annotations_Val_mscoco.zip
以下を実行して、Validationを実行
python train.py --type val
以下のように出力されればValidation終了
Loading Weights...
Evaluating Accuracy on validation set:
121512/121512 [==============================] - 139s 1ms/step
loss is 2.4434539682030545
acc is 0.4561936269668839
true positive rate: 0.5853002172624926
学習をマッハで行ったので精度は低い!
Appendix
dataset
トークン化
各単語を全てIDで表現する
"melted", "6738": "neat", "6739": "motorist", "6740": "itinerary", "6741": "balance", "6742": "study",
質問文を作る(最大26単語)
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 11106.
6749. 3713. 12308. 12548. 10945. 6069.]
11106 = waht
、6749 = is
のような形にする
後にこれを各単語のBlobを参照して単語ベクトルに置き換えて、それをさらに文章ベクトルに変換する
glob
例えば以下がwhat
の分散ベクトル
what -0.20017 0.14302 0.052055 -0.00080884 0.017009 0.014899 -0.25524 -0.17907 -0.046713 -2.0547 0.22617 0.082 849 -0.2119 0.19906 0.30946 0.22688 -0.060026 -0.033334 0.038108 0.22626 0.52159 0.59871 0.45327 -0.041098 -0. 40293 -0.079128 0.0025339 -0.36042 0.065823 -0.010745 0.054722 0.50756 -0.64612 -0.0045895 -1.0173 0.30218 -0. 25403 0.095647 -0.047533 -0.32479 0.1473 -0.2024 0.011759 0.11088 0.015631 0.112 0.033769 0.16056 -0.39738 0.0 93925 0.22837 -0.31089 0.035846 0.044164 -0.30598 0.61444 -0.047129 0.23316 0.20707 0.25667 0.11389 -0.11335 0 .37689 0.42253 -0.12522 -0.25177 -0.042779 -0.067905 -0.020935 -0.18254 0.068618 -0.0090734 0.10255 0.33506 -0 .11719 -0.17783 0.027547 0.23482 -0.51104 -0.097399 -0.33855 -0.014775 0.35922 0.16305 0.094341 0.18024 -0.008 4479 0.423 0.24639 0.31141 -0.54777 0.52251 -0.3672 -0.21798 -0.13245 -0.12492 -0.49334 -0.0058997 -0.012498 - 0.49214 0.15075 0.17764 -0.16562 0.025349 0.096398 0.074563 0.30484 0.25288 -0.13054 0.50759 0.17547 -0.23331 -0.043889 -0.28066 0.18836 0.18232 0.24571 -0.042019 0.096973 -0.30044 -0.0054186 -0.31779 0.011232 0.24753 0. 0083373 -0.45666 -0.047196 0.39851 0.0059699 0.037121 0.15238 -0.19107 -0.005559 -0.4424 -0.11946 0.06862 -0.1 778 0.017349 -0.11998 0.16616 0.014128 0.2064 0.10171 0.13337 -0.49099 -0.29886 0.20528 0.21264 -0.14518 0.137 42 -0.61539 -0.093476 0.33104 -0.10262 0.13709 -0.067552 0.092931 0.092051 0.11208 0.017574 0.35888 -0.38706 0 .24634 -0.18565 0.12121 0.19534 -0.10963 0.26456 -0.19938 0.088868 0.11595 -0.23044 -0.62528 -0.27601 0.111 -0 .058878 -0.3095 0.14474 0.087063 0.24053 0.23826 0.075055 0.08908 -0.30241 -0.088252 0.28571 -0.070077 -0.0303 38 0.38133 0.45064 -0.12722 0.2314 0.311 -0.046189 0.10391 -0.012487 0.1838 0.080527 -0.031989 -0.45986 1.2992 -0.46095 -0.043503 -0.049583 0.05659 -0.054186 0.095496 0.19565 -0.13754 -0.39788 -0.26887 0.26798 0.27644 0. 08585 -0.18111 0.17159 -0.0094278 0.20505 -0.27989 -0.12635 0.17923 0.17281 -0.044121 0.20526 0.50476 0.20821 0.22695 0.10859 0.019256 -0.1453 0.085831 0.062539 0.21132 -0.05678 -0.34562 -0.0078592 0.26591 -0.14611 -0.09 4381 -0.11839 0.094688 -0.088096 0.51552 0.2753 -0.67519 -0.41352 -0.15537 0.10511 0.041168 0.17814 0.28327 -0 .24044 -0.32312 0.22443 0.40035 0.040625 -0.072271 -0.21041 0.026456 -0.39834 -0.0099483 -0.023906 0.047874 -0 .16029 -0.072023 0.15884 -0.30268 0.040359 0.091823 0.23122 0.040294 -0.027471 0.2447 0.29567 0.069906 0.21981 -2.3806 -0.029845 0.72655 -0.16307 -0.05433 0.0087763 0.0036759 0.036295 0.023036 -0.057012 0.006363 -0.05500 3 -0.10056 0.14143 0.045239 -0.35298 0.3335 0.28104 0.20338 -0.4788 -0.039697 0.034939 -0.12599 0.21863
正解は1単語
{"1": "yes", "2 ": "no", "3": "2", "4": "1", "5": "white", "6": "3", "7": "red", "8": "blue", "9": "4", "10": "green"}