PythonとFlaskで類似画像を検索できるアプリを作ってみようPart1
概要
せっかくなのでgithubリポジトリ作りました 星ください(乞食)
SSAM-SimilaritySearchAppwithMNIST
タイトルの通りです。今回は簡単にデータが手に入ってかつデータサイズも小さいご存知MNISTに対してMNISTの画像を入れると似たMNISTの画像を見つけてくれる、そんなアプリを作ることで近似近傍探索についての話とFlaskを使ったアプリ作りについての話を自分のメモもかねて書いていきます。
手順
作業ディレクトリをつくろう!
- 作業ディレクトリ
適当に名前を決めていいですが、私は調子に乗って SSAM-SimilaritySearchAppwithMNIST なんて洒落た名前をつけてみました。
- MNIST用ディレクトリ
次にMNISTのデータを保存しておくフォルダを作成しましょう。そうしないと毎回ダウンロードしてくる必要があって面倒です。
static/mnist_dataみたいな名前でフォルダを作っておきましょう。
- 現在のディレクトリ構成はこんな感じです
SSAM-SimilaritySearchAppwithMNIST
└── static
└── mnist_data
MNISTデータをダウンロードしてこよう!
MNISTって何?
有名なので、今更説明もいらないと思いますが知らない人のために説明すると、28x28サイズの手書きの数字の画像です。
手軽に使えるデータなので、よく機械学習のモデルの精度の評価に使われたりしますが、まぁそれはそれは綺麗なデータなので大体精度は高く出ます。
(この記事みたいに)MNISTで高い精度の出るモデルを作れました!みたいな記事はそこんところ注意した方がいいです。おーおー、またMNISTごときでイキってる記事が上がってるぞくらいのスタンスで見た方が安全です。(これはただの自虐ネタのつもりなので他のMNISTを扱う記事を煽ってるわけではないです、本当です。)
使用するライブラリ
コードって細かく分割して逐一解説入れながら書いてくのが分かりやすいのか、一気に書いてしまって後から補足説明入れる方が分かりやすいのか少し悩みますね。今後ちょっと考えたいので参考意見があると嬉しいです。(Qiitaにコメント機能があるのかすら知りませんが。)
とりあえず少なくとも、今回使うライブラリはまとめて書いておいた方が親切だと思うのでそうします。pipインストールは適宜各自しておいてください。
import os
import gzip
import pickle
import numpy as np
import urllib.request
from mnist import MNIST
from annoy import AnnoyIndex
from sklearn.metrics import accuracy_score
MNISTをダウンロードする
早速コードを書いていきます。
上でも書きましたがコードって細かく分割して逐一解説入れながら書いてくのが親切なのか、一気に書いてしまって後から補足説明入れる方が親切なのか少し悩みます。そこそこコメント丁寧に書いているので、今回は一気に書いてみますが、ここら辺意見欲しいです。
まずMNISTをダウンロードしてくるコードです。
def load_mnist():
# MNISTのデータをダウンロードしてくる
url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
}
# ファイルをダウンロード(.gz形式)
for filename in key_file.values():
file_path = f'./static/mnist_data/{filename}' # 読み込んでくる.gzファイルのパス
if os.path.isfile(file_path.replace('.gz', '')): continue # すでにファイルがあるときは飛ばす
urllib.request.urlretrieve(url_base + filename, file_path)
# .gzの解凍と.gzファイルの削除
with gzip.open(file_path, mode='rb') as f:
mnist = f.read()
# 解凍して保存
with open(file_path.replace('.gz', ''), 'wb') as w:
w.write(mnist)
os.remove(file_path) # .gzファイルの削除
# mnistのデータをnp.arrayの形に読み込んで返す
mndata = MNIST('./static/mnist_data/')
# train
images, labels = mndata.load_training()
train_images, train_labels = np.reshape(np.array(images), (-1,28,28)), np.array(labels) # np.arrayに変換する mnistの画像は28x28
# test
images, labels = mndata.load_testing()
test_images, test_labels = np.reshape(np.array(images), (-1,28,28)), np.array(labels) # np.arrayに変換する mnistの画像は28x28
return train_images, train_labels, test_images, test_labels
上記コードを実行すると、まずMNISTのデータが置いてある場所から.gzファイルという圧縮されたファイルをダウンロードしてstatic/mnist_data/の下におきます。(./static/mnist_data/は事前にフォルダを作っておかないとフォルダがないよとエラーが出るかもしれません。すみません。)
そしたら.gzを解凍して.gzファイルはいらないので削除しておきます。実はこの解凍してきたファイルはバイナリ形式なので扱いが面倒なのですが、
from mnist import MNIST
というなんとも間抜けな名前のライブラリを使うとtrain用のデータとtest用のデータで分けてarray型にしてくれる様です。実はここでだいぶ悩んで数時間かけてたりするのはここだけの秘密です。
そんな感じでいつの間にかMNISTデータをダウンロードしてくる関数が完成しました。
近似最近傍探索
最近傍探索とは
wikipediaから定義を引用してみると、下の様に書いています。
最近傍探索(英: Nearest neighbor search, NNS)は、距離空間における最も近い点を探す最適化問題の一種、あるいはその解法。近接探索(英: proximity search)、類似探索(英: similarity search)、最近点探索(英: closest point search)などとも呼ぶ。問題はすなわち、距離空間 M における点の集合 S があり、クエリ点 q ∈ M があるとき、S の中で q に最も近い点を探す、という問題である。多くの場合、M には d次元のユークリッド空間が採用され、距離はユークリッド距離かマンハッタン距離で測定される。低次元の場合と高次元の場合で異なるアルゴリズムがとられる。 ~wikipediaより
これを日本語に訳すと、なんらかの類似度を測る関数を決めて、それに基づいて類似度の高い点を見つけてくるアルゴリズムということが書いてあります。つまりそういうことです。
近似?
今回使うアルゴリズムはannoyというものを使うのですが、これには"近似"最近傍探索とついています。この"近似"とはなんなのでしょうか?
実はこの近傍探索アルゴリズム、きちんと計算しようとするとかなり計算資源を食います。例えば最も単純に全画素値の差を計算してみるかwなどと思ってしまった日には大変なことになります。特に画像の場合は縦x横xチャンネル数あるので爆発的に計算時間が伸びます。
今回のMNISTは所詮28x28でしかもグレイスケール(白黒画像)なので大したことはないのですが、よっしゃーお父さん今日は張り切ってフルHD1920×1080の類似画像を愚直に探索するぞーなんてした日にはお父さんは絶望して寝れなくなります。
指数的爆発の怖さについては、計算量おねえさんが体を張って教えてくれます。面白い上に勉強になる最強のコンテンツなので是非知らない人は一度見ておいて、お姉さんの執念に涙しましょう。
『フカシギの数え方』 おねえさんといっしょ! みんなで数えてみよう!
今回使うアルゴリズム
似た画像を見つけてくるのに近似最近傍探索ライブラリのannoyと言うものを使います。
これを使う理由は自分が慣れているからというのと、比較的コードが読みやすいと思うからです。
アルゴリズムの詳しい内容についてはannoy作者のこのブログNearest neighbors and vector models – part 2 – algorithms and data structuresの解説と、日本語では近似最近傍探索の最前線このスライドシェアの解説がわかりやすいです。この記事を見る前にこっちを見た方がためになります。
面倒くさいやという人のためにここで簡単に説明すると、データポイントが存在する空間を再帰的に区切っていって2分木をいくつか作っておくことで高速にO(logn)近傍を探索できる様にしたアルゴリズムとのことです。とはいえ決定木をビルドする必要はあるのですが。
実はannoyの作者がいくつかの近似最近傍探索ライブラリの比較をして他のライブラリの方がいいぞと勧めてくれてたりします。New approximate nearest neighbor benchmarks
本当に速度を追い求めるならおそらくFacebook謹製のfaissを使うのがいいっぽいのですが、こいつはcondaからしかインストールできないらしくそれだけで使う気をなくします。
annoyの使用
少し説明が長くなりました。さっそくannoyを使った近傍探索を試してみましょう。
def make_annoy_db(train_imgs):
# 近似近傍探索のライブラリannoyを使う
# 入力するデータのshapeとmetricを決めて、データを突っ込んでいく
annoy_db = AnnoyIndex((28*28), metric='euclidean') # 入力されるデータと類似度の計算方法を与える、MNISTは28x28のサイズなのでそのまま28*28と書いちゃう
for i, train_img in enumerate(train_imgs):
annoy_db.add_item(i, train_img.flatten()) # インデックスとそれと対応するデータを入れていく
annoy_db.build(n_trees=10) # ビルド
annoy_db.save('./static/mnist_db.ann') # static配下に作成したデータベースを保存する
こんな感じになるでしょうか。データベースと言う言葉を使っていますが、これは僕が他にいい適切な言葉を知らないだけで厳密にはデータベースとは違います、多分。
ちなみにAnnoyIndexには一次元の配列しか与えられないので画像データを入れたいときはflatten()を使ったり、reshapeしたり、今回の様にハードコーディングしたりしてください。
精度をみてみる
近似最近傍探索と言うだけあって厳密な類似度を計算しているわけではないので、もちろん多少の誤差が生じます。(annoyではこれを複数の木を使ったりなんだりして解決してます。詳しくは上記URLで)
本当に精度があるのか、少し確かめてみましょうか。
幸いMNISTには各画像とそれに対応する正解ラベルがすでに用意されています。テストデータに対して類似した画像を一つ引っ張って来てもらって、正解と同じかどうか確認してみましょう。
train_imgs, train_lbls, test_imgs, test_lbls = load_mnist()
if not os.path.isfile('./static/mnist_db.ann'):
make_annoy_db(train_imgs) # .annファイルがまだないなら、annoydbのビルド
annoy_db = AnnoyIndex((28*28), metric='euclidean')
annoy_db.load('./static/mnist_db.ann') # annoyのデータベースをロードする
# テストデータを入力して近い近傍を取ってきて実際と比較することで試しに精度をみてみる
y_pred = [train_lbls[annoy_db.get_nns_by_vector(test_img.flatten(), 1)[0]] for test_img in test_imgs]
score = accuracy_score(test_lbls, y_pred)
print('acc:', score)
# 出力 acc: 0.9595
なんとも高い精度が出ました。せっかくなら0.2525と出た方がニコ厨かよwとネタにできるのですが現実はそう甘くありません。(いったい僕は何をいってるのでしょうか?)
実際の画像の場合、正解のラベルが同じだからといってそれが本当に似た画像かどうかはなんともいえないのですが(例えばどっちも猫画像だとしても背景が違ったり、黒猫と白猫だったりした場合、人の脳は似た画像だとは判断しない)、このMNISTに限っていえばきれいなデータセットなので多分似た画像が出てきているのでしょう。
全体のコード
ここまでの全体を通してのコードが以下の様になります。
import os
import gzip
import pickle
import numpy as np
import urllib.request
from mnist import MNIST
from annoy import AnnoyIndex
from sklearn.metrics import accuracy_score
def load_mnist():
# MNISTのデータをダウンロードしてくる
url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
}
# ファイルをダウンロード(.gz形式)
for filename in key_file.values():
file_path = f'./static/mnist_data/{filename}' # 読み込んでくる.gzファイルのパス
if os.path.isfile(file_path.replace('.gz', '')): continue # すでにファイルがあるときは飛ばす
urllib.request.urlretrieve(url_base + filename, file_path)
# .gzの解凍と.gzファイルの削除
with gzip.open(file_path, mode='rb') as f:
mnist = f.read()
# 解凍して保存
with open(file_path.replace('.gz', ''), 'wb') as w:
w.write(mnist)
os.remove(file_path) # .gzファイルの削除
# mnistのデータをnp.arrayの形に読み込んで返す
mndata = MNIST('./static/mnist_data/')
# train
images, labels = mndata.load_training()
train_images, train_labels = np.reshape(np.array(images), (-1,28,28)), np.array(labels) # np.arrayに変換する mnistの画像は28x28
# test
images, labels = mndata.load_testing()
test_images, test_labels = np.reshape(np.array(images), (-1,28,28)), np.array(labels) # np.arrayに変換する mnistの画像は28x28
return train_images, train_labels, test_images, test_labels
def make_annoy_db(train_imgs):
# 近似近傍探索のライブラリannoyを使う
# 入力するデータのshapeとmetricを決めて、データを突っ込んでいく
annoy_db = AnnoyIndex((28*28), metric='euclidean')
for i, train_img in enumerate(train_imgs):
annoy_db.add_item(i, train_img.flatten())
annoy_db.build(n_trees=10) # ビルド
annoy_db.save('./static/mnist_db.ann')
def main():
# mnist画像の読み込み処理
train_imgs, train_lbls, test_imgs, test_lbls = load_mnist()
print(train_imgs.shape, train_lbls.shape, test_imgs.shape, test_lbls.shape) # MNISTがどのくらいの枚数データがあるのか確認したかった
if not os.path.isfile('./static/mnist_db.ann'):
make_annoy_db(train_imgs) # .annファイルがまだないなら、annoydbのビルド
annoy_db = AnnoyIndex((28*28), metric='euclidean')
annoy_db.load('./static/mnist_db.ann') # annoyのデータベースをロードする
# テストデータを入力して近い近傍を取ってきて実際と比較することで試しに精度をみてみる
y_pred = [train_lbls[annoy_db.get_nns_by_vector(test_img.flatten(), 1)[0]] for test_img in test_imgs]
score = accuracy_score(test_lbls, y_pred)
print('acc:', score)
if __name__ == "__main__":
main()
次の目標
なんだかんだここまでで、既に分量が多くなったので続きは次回に回します。
今回はannoyを使って類似画像を探索できそうだと言うところまでやりました。
次回は、Flaskを使って画像を選んだら似た画像が出てくるアプリを作っていきましょう。
(僕が飽きたり忙しくなかったら)続く:-> 次回
おまけ:本当にどうでもいい愚痴なのですが、僕はEmacsのorgmodeが大好きで今回の記事もorgmodeで書くぞーと勇んで書き始めたのですが、悲しいかな互換性の弱さには対応しきれず結局Markdownで書いてしまっていました。
単体でいえばorgmodeが文書作成ソフトでは最強だと思ってるのですが(というかMarkdownって使いにくくないですか?どこの馬鹿が考えたら半角スペース二回で改行にしようとか思いつくのか)、どうしても他との連携を考えるとそうもいっていられないのが悲しいところです。
orgmodeにもexport as markdownみたいなものがあり、それでなんとかしようと思っていたのですが、それをさらにQiitaに載せるぞとなると思ったよりきれいにならなかったりしてねぇ...何とも悲しいものです。
Emacsもキーバインドが独特でVScodeの楽さになれちゃうとどうも使いづらくてねー、一応VScodeにもorgmode extension的なのはありますが、一番重要なtabキーで目だしを折り畳んだする表示状態を切換える機能と、htmlとかへのexport機能がないじゃないかとかねえ、VHS対ベータマックスみたいなものを連想してしまいました、世知辛いものです。
そういえば、Software Design 2020年8月号の表紙を見て悲しくなりました。Vimと戦うのはEmacs、昔からそうだったじゃねえかよ。もう無理なのかEmacs?そんな感じで記事を閉めたいと思います。