ANN(Approximate Nearest Neighbor)のPythonパッケージである faissを動かしてみました。いくつかあるANNのPythonパッケージの中でfaissを選んだのには、特に深い理由はありません(たまたま仕事で関係あったから)。あまり時間かけて調べてないので、書いている内容に間違いあるかもしれません。
環境
普通にpipでPythonパッケージはインストールし、ローカルのMac PCでJupyterで動かしています。
種類 | バージョン | 備考 |
---|---|---|
Python | 3.9.13 | 今どき少し古め |
faiss-cpu | 1.7.4 | CPUで動かしいます |
numpy | 1.26.3 | |
pandas | 2.1.4 | Jupyter表示に少しだけ使用 |
動かした内容
フラットインデックスでの総当りの検索と、"IndexIVFFlat"というインデックスでの近傍検索のインデックス作成時間、検索時間およびリコールを調べています。
処理速度結果
まずは結果から。512次元×100万レコードに対する10クエリのTop 10を検索。
インデックス構築時間
インデックスIndexIVFFlat
で3000のボロノイ領域(nlist)にすると28秒と少しだけ遅いです。
検索時間と再現率
分割しすぎると再現率が非常に悪いですね。少し期待しすぎました。
処理
0. Package Import
time
パッケージは時間計測に使っています。
import time
import numpy as np
import faiss
import pandas as pd
1. 定数定義
DIMENSION = 512
NUM_NEIGHBOR = 100 # Top N Nearest Neighbors
NUM_QUERY = 10 # Number of Query for search
2. 乱数の配列生成
以下の配列を0から1までの乱数で生成します。
- vectors: 検索先(インデックス)の配列
- query: 検索元の配列
def get_random_array(dimension, size, seed=1234):
np.random.seed(seed)
vectors = np.random.random((size, dimension)).astype('float32') #from 0 to 1
print(vectors.shape)
return vectors
vectors = get_random_array(DIMENSION, 1_000_000)
query = get_random_array(DIMENSION, NUM_QUERY)
(1000000, 512)
(10, 512)
3. インデックス作成
インデックスを作成する関数create_index
です。
パラメータnlist
を渡すとIndexIVFFlat
のインデックスを作成します。パラメータnlist
を渡さないとIndexFlatL2
というフラットなインデックスを作成します。
細かく調べていませんが、ベクトル空間をボロノイ領域に分割してANNを実現しているようです。パラメータnlist
はその分割数。
※「ボロノイ領域」は「はじパタ全力解説: 第5章 k最近傍法(kNN法)」で勉強しました。
実用としては、各レコードにIDが必要なので、add_with_ids
関数を使って配列とIDを登録しています(コード上ではnp.arange
でID生成)。IndexFlatL2
とIndexIVFFlat
で少しお作法が異なるので注意が必要です。「Faiss ID mapping」参照。
def create_index(vectors, nlist=0):
index_flat = faiss.IndexFlatL2(vectors.shape[1]) # build the index
if nlist == 0:
index = faiss.IndexIDMap2(index_flat)
else:
# IndexIVF doesn't need IndexIDMap
# https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#ids-in-the-indexivf
index = faiss.IndexIVFFlat(index_flat, vectors.shape[1], nlist, faiss.METRIC_L2)
index.train(vectors)
index.add_with_ids(vectors, np.arange(vectors.shape[0]))
# print(f'Trained: {index.is_trained}, Size: {index.ntotal}')
return index
faissは他にも色々なインデックスの種類がありますが、公式Tutorialの「Faster search
」に記載のあったIndexFlatL2
を使っています。
4. 検索
検索の関数です。IndexFlatL2
でもIndexIVFFlat
でも検索できるようにしていますが、IndexIVFFlat
だとパラメータnprobe
を渡して、検索対象のボロノイ領域数を指定します。
def search(index, query, k, nprobe=0):
if nprobe > 0:
index.nprobe = nprobe
# 本当は1番目に来る自身のベクトルを除外した方がいいかも
distances, neighbors = index.search(query, k)
#print(f'Distance shape: {distances.shape}, Neighbor shape: {neighbors.shape}')
#print(f'{distances=}')
#print(f'{neighbors=}')
return distances, neighbors
5. フラットインデックス登録と正解データ取得
フラットインデックスを登録して近傍検索の正解データを取得。この正解データから後で再現率を求めます。
def get_truth(vectors, query, k):
start = time.time()
index = create_index(vectors) # create FLAT index
end = time.time()
print(f'build FLAT: {end-start}, {start}, {end}')
start = time.time()
_, neighbors = search(index, query, NUM_NEIGHBOR)
end = time.time()
print(f'search: {end-start}, {start}, {end}')
return neighbors
truth = get_truth(vectors, query, NUM_NEIGHBOR)
print(f'{truth}')
インデックス構築に0.4秒、総当り検索に0.15秒かかっています。512次元の100万レコードなので、こんなものでしょう。
build FLAT: 0.4003911018371582, 1706175708.877467, 1706175709.277858
search: 0.14706897735595703, 1706175709.277965, 1706175709.425034
[[ 0 26726 43139 47473 67963 60394 28560 7757 23261 14442 4391 80261
...中略
34653 49011 58559 24668]]
ちなみに検索結果で出力される距離はデフォルトでL2ノルムです。平方根になっていないので注意しましょう。
6. IndexIVFFlat
インデックス登録と検索
IndexIVFFlat
インデックス登録と検索をして、各処理時間とリコールを一覧出力します。
インデックス登録時には、nlist
(ボロノイ領域数)を変更。検索時にはnprobe
(検索対象ボロノイ領域数)を変更しています。
results = []
for nlist in [10, 100, 500, 1_000, 2_000, 3_000]:
start = time.time()
index = create_index(vectors, nlist)
end = time.time()
build_time = end-start
print(f'build {nlist=}: {build_time}, {start}, {end}')
for nprobe in [1, 3, 10, 30, 100]:
result = {}
result['build_time'] = build_time
result['nlist'] = nlist
start = time.time()
_, neighbors = search(index, query, NUM_NEIGHBOR, nprobe)
end = time.time()
result['nprobe'] = nprobe
result['search_time'] = end-start
recalls = [np.count_nonzero(np.isin(neighbors[i], truth[i]) )
for i in range(NUM_QUERY)]
result['recall'] = sum(recalls) / NUM_NEIGHBOR / NUM_QUERY
print(f'search {nlist=}, {nprobe=}: {end-start}, {result["recall"]}')
results.append(result)
pd.DataFrame(results)
参考リンク
- ANN Benchmarks
- aiss解説シリーズ(第一回)基本編: 細かくやる場合にここを参考にしたい