LoginSignup
5
2

faiss入門: ANN(近似最近傍探索)を動かしてみた

Last updated at Posted at 2024-01-26

ANN(Approximate Nearest Neighbor)のPythonパッケージである faissを動かしてみました。いくつかあるANNのPythonパッケージの中でfaissを選んだのには、特に深い理由はありません(たまたま仕事で関係あったから)。あまり時間かけて調べてないので、書いている内容に間違いあるかもしれません。

環境

普通にpipでPythonパッケージはインストールし、ローカルのMac PCでJupyterで動かしています。

種類 バージョン 備考
Python 3.9.13 今どき少し古め
faiss-cpu 1.7.4 CPUで動かしいます
numpy 1.26.3
pandas 2.1.4 Jupyter表示に少しだけ使用

動かした内容

フラットインデックスでの総当りの検索と、"IndexIVFFlat"というインデックスでの近傍検索のインデックス作成時間、検索時間およびリコールを調べています。

処理速度結果

まずは結果から。512次元×100万レコードに対する10クエリのTop 10を検索。

インデックス構築時間

インデックスIndexIVFFlatで3000のボロノイ領域(nlist)にすると28秒と少しだけ遅いです。
image.png

検索時間と再現率

分割しすぎると再現率が非常に悪いですね。少し期待しすぎました。
image.png

処理

0. Package Import

timeパッケージは時間計測に使っています。

import time

import numpy as np
import faiss
import pandas as pd

1. 定数定義

DIMENSION = 512
NUM_NEIGHBOR = 100 # Top N Nearest Neighbors
NUM_QUERY = 10 # Number of Query for search

2. 乱数の配列生成

以下の配列を0から1までの乱数で生成します。

  • vectors: 検索先(インデックス)の配列
  • query: 検索元の配列
def get_random_array(dimension, size, seed=1234):
    np.random.seed(seed)
    vectors = np.random.random((size, dimension)).astype('float32') #from 0 to 1
    print(vectors.shape)
    return vectors

vectors = get_random_array(DIMENSION, 1_000_000)
query = get_random_array(DIMENSION, NUM_QUERY)
Output(各配列の次元)
(1000000, 512)
(10, 512)

3. インデックス作成

インデックスを作成する関数create_indexです。
パラメータnlistを渡すとIndexIVFFlatのインデックスを作成します。パラメータnlistを渡さないとIndexFlatL2というフラットなインデックスを作成します。
細かく調べていませんが、ベクトル空間をボロノイ領域に分割してANNを実現しているようです。パラメータnlistはその分割数。
※「ボロノイ領域」は「はじパタ全力解説: 第5章 k最近傍法(kNN法)」で勉強しました。

実用としては、各レコードにIDが必要なので、add_with_ids関数を使って配列とIDを登録しています(コード上ではnp.arangeでID生成)。IndexFlatL2IndexIVFFlatで少しお作法が異なるので注意が必要です。「Faiss ID mapping」参照。

def create_index(vectors, nlist=0):
    index_flat = faiss.IndexFlatL2(vectors.shape[1])   # build the index
    if nlist == 0:
        index = faiss.IndexIDMap2(index_flat)
    else:
        # IndexIVF doesn't need IndexIDMap
        # https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#ids-in-the-indexivf
        index = faiss.IndexIVFFlat(index_flat, vectors.shape[1], nlist, faiss.METRIC_L2)
        index.train(vectors)

    index.add_with_ids(vectors, np.arange(vectors.shape[0]))
    # print(f'Trained: {index.is_trained}, Size: {index.ntotal}')
    return index

faissは他にも色々なインデックスの種類がありますが、公式Tutorialの「Faster search
に記載のあったIndexFlatL2を使っています。

4. 検索

検索の関数です。IndexFlatL2でもIndexIVFFlatでも検索できるようにしていますが、IndexIVFFlatだとパラメータnprobeを渡して、検索対象のボロノイ領域数を指定します。

def search(index, query, k, nprobe=0):
    if nprobe > 0:
        index.nprobe = nprobe
    # 本当は1番目に来る自身のベクトルを除外した方がいいかも
    distances, neighbors = index.search(query, k)
    #print(f'Distance shape: {distances.shape}, Neighbor shape: {neighbors.shape}')
    #print(f'{distances=}')
    #print(f'{neighbors=}')
    return distances, neighbors

5. フラットインデックス登録と正解データ取得

フラットインデックスを登録して近傍検索の正解データを取得。この正解データから後で再現率を求めます。

def get_truth(vectors, query, k):
    start = time.time()
    index = create_index(vectors) # create FLAT index
    end = time.time()
    print(f'build FLAT: {end-start}, {start}, {end}')
    start = time.time()
    _, neighbors = search(index, query, NUM_NEIGHBOR)
    end = time.time()
    print(f'search: {end-start}, {start}, {end}')
    return neighbors
truth = get_truth(vectors, query, NUM_NEIGHBOR)
print(f'{truth}')

インデックス構築に0.4秒、総当り検索に0.15秒かかっています。512次元の100万レコードなので、こんなものでしょう。

出力結果
build FLAT: 0.4003911018371582, 1706175708.877467, 1706175709.277858
search: 0.14706897735595703, 1706175709.277965, 1706175709.425034
[[    0 26726 43139 47473 67963 60394 28560  7757 23261 14442  4391 80261
...中略
  34653 49011 58559 24668]]

ちなみに検索結果で出力される距離はデフォルトでL2ノルムです。平方根になっていないので注意しましょう。

6. IndexIVFFlatインデックス登録と検索

IndexIVFFlatインデックス登録と検索をして、各処理時間とリコールを一覧出力します。
インデックス登録時には、nlist(ボロノイ領域数)を変更。検索時にはnprobe(検索対象ボロノイ領域数)を変更しています。

results = []
for nlist in [10, 100, 500, 1_000, 2_000, 3_000]:
    start = time.time()
    index = create_index(vectors, nlist) 
    end = time.time()
    build_time = end-start
    print(f'build {nlist=}: {build_time}, {start}, {end}')
    for nprobe in [1, 3, 10, 30, 100]:
        result = {}
        result['build_time'] = build_time
        result['nlist'] = nlist
        start = time.time()
        _, neighbors = search(index, query, NUM_NEIGHBOR, nprobe)
        end = time.time()
        result['nprobe'] = nprobe
        result['search_time'] = end-start
        recalls = [np.count_nonzero(np.isin(neighbors[i], truth[i]) )
                                      for i in range(NUM_QUERY)]
        result['recall'] = sum(recalls) / NUM_NEIGHBOR / NUM_QUERY
        print(f'search {nlist=}, {nprobe=}: {end-start}, {result["recall"]}')
        results.append(result)

pd.DataFrame(results)

参考リンク

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2