LoginSignup
2
2

FaissでKNN実装&Streamlitで超お手軽(手抜き)文字認識

Last updated at Posted at 2023-09-28

概要

先日のブログ記事に関連し、近似最近傍探索ライブラリ Faiss による類似データ検索の実装例を紹介します。
さらに、Streamlitを使って超お手軽(手抜き)文字認識アプリで遊んでみます。単純にMNIST画像のピクセル値を比較するだけですが、そこそこ判定できそうです(改めて深層学習のありがたみも感じました)。
興味ある方、トライしてみてください!

はじめに

類似データ検索は、商品検索や顔識別など広く活用されている技術です。
検索対象から特徴量を抽出(ベクトル化)し、ベクトル同士の類似度をコサイン類似度やユークリッド距離などの指標に基づいて算出することにより、類似度の高いデータを検索できます。
ベクトル化は深層学習モデルが得意としており、画像や文章など複雑なデータにおいても高度な特徴抽出が可能になってきました。

なお、類似データ検索ではベクトル化だけでなく類似度計算の手法も重要になります。
特徴量の次元数をD、検索先のサンプル数をN、距離指標をユークリッド距離とした場合、最近傍探索では計算量はDとNに依存します。サンプル数が多い場合は処理時間を要するため、近年は検索精度を低下させる代わりに計算量を削減できる近似最近傍探索の利用が一般的です。
代表的な近似最近傍探索のPythonライブラリには、Faissやnmslibなどがあり、今回はFaissを使用します。

検証条件

  • 環境、ライブラリ
    • python : 3.10.11
    • faiss-cpu==1.7.4
    • keras==2.13.1
    • matplotlib==3.8.0
    • numpy==1.24.3
    • Pillow==9.5.0
    • streamlit==1.27.0
    • streamlit-drawable-canvas==0.9.3
    • tensorflow==2.13.0

データセットの用意

MNISTデータで動作確認します。

from typing import Union

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
import faiss

# MNISTデータ読込
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train.shape, X_test.shape
>>> ((60000, 28, 28), (10000, 28, 28))

# 可視化
fig, axs = plt.subplots(4,4, figsize=(8, 8))
for i, (image, label) in enumerate(zip(X_train, y_train)):
    axs[i//4][i%4].imshow(image)
    axs[i//4][i%4].set_title(label)
    if i == 15:
        break
plt.tight_layout()
plt.show()

image.png

K近傍探索の実装

Faissは検索先のデータを index に登録し、searchメソッドにより近傍探索を実施します。
前処理など含め scikit-learnライクに実行出来るよう、fitpredictメソッドでラップし、
引数 kで探索数、metricで距離指標をeuclid(ユークリッド距離)cosine(コサイン距離)に設定します。

class FaissKNeighbors:
    def __init__(self, k: int = 20, metric: Union["euclid", "cosine"] = "euclid"):
        self.index = None
        self.d = None
        self.k = k
        self.metric = metric

    def fit(self, X: np.ndarray):
        X = X.copy(order="C")
        self.d = X.shape[1]
        X = X.astype(np.float32)
        if self.metric == "cosine":
            self.index = faiss.IndexFlatIP(self.d)  # cosine
            faiss.normalize_L2(X)
        elif self.metric == "euclid":
            self.index = faiss.IndexFlatL2(self.d)  # euclid
        self.index.add(X)

    def predict(self, X: np.ndarray):
        X = X.copy(order="C")
        X = np.reshape(X, (-1, self.d))
        X = X.astype(np.float32)
        if self.metric == "cosine":
            faiss.normalize_L2(X)
        distances, indices = self.index.search(X, k=self.k)
        if self.metric == "euclid":
            distances = np.sqrt(distances)
        if X.shape[0] == 1:
            return distances[0], indices[0]
        else:
            return distances, indices

動作確認

knn = FaissKNeighbors(k=5, metric='euclid')
# X_trainデータを検索先として登録
knn.fit(X_train.reshape(-1, 28*28))
# X_testデータ100件分 検索
pred_dists, pred_indexs = knn.predict(X_test[:100].reshape(-1, 28*28))

# 検索結果の可視化
for i in range(5):
    fig, axs = plt.subplots(1,6)
    axs[0].imshow(X_test[i])
    axs[0].set_title(f'Target : {y_test[i]}')
    for col, p_i in enumerate(pred_indexs[i]):
        axs[col+1].imshow(X_train[p_i])
        axs[col+1].set_title(f'pred{col+1} : {y_train[p_i]}')
    plt.tight_layout()
    plt.show()

image-1.png

検索対象(Target)と類似した画像が検索できています(pred1~5)。

おまけ

MNISTは手書き文字データセットなので、実際に自分の手書き文字も検索・判定出来るか気になってきました。
そこで、ChatGPTに相談しつつ手書き文字認識アプリをStreamlitで実装してみました。文字入力はStreamlit拡張機能のstreamlit-drawable-canvasを利用しています。数字の判定は、類似度上位5件の多数決で決めてます。

前述のFaissKNeighborクラスをimport出来るよう、faiss_kneighbors.pyとしてスクリプト化

faiss_kneighbors.py
from typing import Union

import numpy as np
import faiss


class FaissKNeighbors:
    def __init__(self, k: int = 20, metric: Union["euclid", "cosine"] = "euclid"):
        # 上記「K近傍探索の実装」と同様のため割愛
        # ...

アプリ情報をdemo.pyに実装

demo.py
import faiss
import numpy as np
from PIL import Image
from keras.datasets import mnist
import streamlit as st
from streamlit_drawable_canvas import st_canvas

from faiss_kneighbors import FaissKNeighbors

K_NUM = 5

_, (X_test, y_test) = mnist.load_data()
knn = FaissKNeighbors(k=K_NUM)
knn.fit(X_test.reshape(-1, 28*28))

# Streamlitアプリケーションの開始
st.title("手書き文字認識アプリ")

# マウスで文字を描画するキャンバスを作成
st.write("マウスで文字を描いてください。")
canvas_result = st_canvas(
    stroke_width=20,
    update_streamlit=False,
    height=200,
    width=200,
    drawing_mode='freedraw',
)

# 描画した文字を整形し、KNNを使用して分類
if canvas_result.image_data is not None:
    drawn_image = canvas_result.image_data
    drawn_image_gray = drawn_image[:, :, 3]
    if np.sum(drawn_image_gray) > 0:
        drawn_image_gray = Image.fromarray(drawn_image_gray)
        resized_image = drawn_image_gray.resize((28, 28))
        resized_image = np.array(resized_image)

        # KNNで分類
        dists, indexs = knn.predict(resized_image)
        votes = y_test[indexs]
        predictions = np.flatnonzero(np.bincount(votes) == np.bincount(votes).max())

        st.write(f"判定結果: {predictions}")
        st.write(f"類似度上位")
        cols = st.columns(K_NUM)
        for i, idx in enumerate(indexs):
            pred_image = Image.fromarray(X_test[idx])
            pred_image = pred_image.resize((100,100))
            pred_image = np.array(pred_image)
            cols[i].image(pred_image, clamp=True, caption=f'label = {y_test[idx]}')

アプリ起動

streamlit run demo.py

こんな感じで動きます。
単純な仕組みなため、4、5、7、9は他の数字と間違えてしまってますね...。
数字を書く位置にとても敏感なため、実用化は厳しいですね。CNNの位置不変性が活躍してくれそうです。
output.gif

最後まで読んでいただき、ありがとうございました!
今後も機械学習の活用を始め、開発環境やシミュレーションなど幅広く技術情報発信をしていく予定です!

最後になりますが、本記事の内容に誤りなどあれば、コメントにてご教授お願いいたします。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2