概要
HuggingFaceのモデルのロード速度を3つの方法で比較しました。
- ローカルパス:model_nameにダウンロードしたものモデルフォルダのパス(e.g. "./mulitlingual-e5-large")を指定してロード
- キャッシュ:model_nameにURLのモデル名(e.g. "intfloat/multilingual-e5-large")を指定してキャッシュ済みのモデルをロードする
- pickle: pickleでダンプしたインスタンスをロード
結果としてはpickleが一番早かったです。ローカルパスとキャッシュは同じくらいでした。
背景
HuggingFaceのモデルをなるべく早くロードできる方法が何か知りたくて調べました。
環境
- M1 MacBookAirのメモリ16GB、SSD 1TB
実装
intfloat/multilingual-e5-large
の読込み速度を比較します。2GB以上あるので、より軽量なモデルで試したい場合は適宜、変更ください。
ローカルパス
huggingface_hubを使ってモデルをダウンロードします。
参考: https://wonderhorn.net/programming/hfdownload.html
import huggingface_hub
MODEL_NAME = "intfloat/multilingual-e5-large"
LOCAL_DIR = "./model/me5_local"
model_id = MODEL_NAME # 落とすモデル名
local_dir = LOCAL_DIR # 保存先フォルダ名
huggingface_hub.snapshot_download(model_id, local_dir=local_dir, local_dir_use_symlinks=False)
時間を計測します。楽をするためにモデルのロードにはlangchainのHuggingFaceEmbeddingsを使います。
from langchain.embeddings import HuggingFaceEmbeddings
import time
# 時間を測る
start = time.time()
# モデルを読み込み
model = HuggingFaceEmbeddings(model_name=LOCAL_DIR)
# 時間を出力
print("local:", time.time() - start)
local: 7.256836175918579
キャッシュ
モデル名をそのまま指定してロード速度を測ります。注意点として、もしキャッシュが存在していない場合は、一度ロードを実施してキャッシュを作成しておきます。
from langchain.embeddings import HuggingFaceEmbeddings
import os
MODEL_NAME = "intfloat/multilingual-e5-large"
# もしキャッシュがなければ、モデルをダウンロードしてキャッシュする
if not os.path.exists(".cache/huggingface/hub/models--intfloat--multilingual-e5-large"):
model = HuggingFaceEmbeddings(model_name=MODEL_NAME)
# 時間を測る
start = time.time()
# モデルを読み込み
model = HuggingFaceEmbeddings(model_name=MODEL_NAME)
# 時間を出力
print("cache:", time.time() - start)
cache: 5.98180627822876
pickle
ダンプ済みのインスタンスをロードします。
from langchain.embeddings import HuggingFaceEmbeddings
import os
import pickle
MODEL_NAME = "intfloat/multilingual-e5-large"
PICKLE_FILE = "me5.pkl"
model = HuggingFaceEmbeddings(model_name=MODEL_NAME)
# pickleで保存する
with open(PICKLE_FILE, "wb") as f:
pickle.dump(model, f)
# 時間を測る
start = time.time()
# モデルを読み込み
with open(PICKLE_FILE, "rb") as f:
model = pickle.load(f)
# 時間を出力
print("pickle:", time.time() - start)
pickle: 2.64107608795166
結果
出力だけまとめると以下のとおりです。
local: 7.256836175918579
cache: 5.98180627822876
pickle: 2.64107608795166
1回しか計測していないため誤差の範囲かもしれませんが、pickleが一番早かったです。バージョンやOSが固定で安定性が確保できるなら、pickleを使うのはありかもしれません。