2
1

HuggingFaceの読込速度の比較(ローカルパス、キャッシュ、pickle)

Last updated at Posted at 2024-01-09

概要

HuggingFaceのモデルのロード速度を3つの方法で比較しました。

  • ローカルパス:model_nameにダウンロードしたものモデルフォルダのパス(e.g. "./mulitlingual-e5-large")を指定してロード
  • キャッシュ:model_nameにURLのモデル名(e.g. "intfloat/multilingual-e5-large")を指定してキャッシュ済みのモデルをロードする
  • pickle: pickleでダンプしたインスタンスをロード

結果としてはpickleが一番早かったです。ローカルパスとキャッシュは同じくらいでした。

背景

HuggingFaceのモデルをなるべく早くロードできる方法が何か知りたくて調べました。

環境

  • M1 MacBookAirのメモリ16GB、SSD 1TB

実装

intfloat/multilingual-e5-largeの読込み速度を比較します。2GB以上あるので、より軽量なモデルで試したい場合は適宜、変更ください。

ローカルパス

huggingface_hubを使ってモデルをダウンロードします。
参考: https://wonderhorn.net/programming/hfdownload.html

import huggingface_hub

MODEL_NAME = "intfloat/multilingual-e5-large"
LOCAL_DIR = "./model/me5_local"

model_id = MODEL_NAME  # 落とすモデル名
local_dir = LOCAL_DIR  # 保存先フォルダ名
huggingface_hub.snapshot_download(model_id, local_dir=local_dir, local_dir_use_symlinks=False)

時間を計測します。楽をするためにモデルのロードにはlangchainのHuggingFaceEmbeddingsを使います。

from langchain.embeddings import HuggingFaceEmbeddings
import time

# 時間を測る
start = time.time()
# モデルを読み込み
model = HuggingFaceEmbeddings(model_name=LOCAL_DIR)
# 時間を出力
print("local:", time.time() - start)
local: 7.256836175918579

キャッシュ

モデル名をそのまま指定してロード速度を測ります。注意点として、もしキャッシュが存在していない場合は、一度ロードを実施してキャッシュを作成しておきます。

from langchain.embeddings import HuggingFaceEmbeddings
import os

MODEL_NAME = "intfloat/multilingual-e5-large"
# もしキャッシュがなければ、モデルをダウンロードしてキャッシュする
if not os.path.exists(".cache/huggingface/hub/models--intfloat--multilingual-e5-large"):
    model = HuggingFaceEmbeddings(model_name=MODEL_NAME)

# 時間を測る
start = time.time()
# モデルを読み込み
model = HuggingFaceEmbeddings(model_name=MODEL_NAME)
# 時間を出力
print("cache:", time.time() - start)
cache: 5.98180627822876

pickle

ダンプ済みのインスタンスをロードします。

from langchain.embeddings import HuggingFaceEmbeddings
import os
import pickle

MODEL_NAME = "intfloat/multilingual-e5-large"
PICKLE_FILE = "me5.pkl"

model = HuggingFaceEmbeddings(model_name=MODEL_NAME)
# pickleで保存する
with open(PICKLE_FILE, "wb") as f:
    pickle.dump(model, f)
    

# 時間を測る
start = time.time()
# モデルを読み込み
with open(PICKLE_FILE, "rb") as f:
    model = pickle.load(f)
# 時間を出力
print("pickle:", time.time() - start)
pickle: 2.64107608795166

結果

出力だけまとめると以下のとおりです。

local: 7.256836175918579
cache: 5.98180627822876
pickle: 2.64107608795166

1回しか計測していないため誤差の範囲かもしれませんが、pickleが一番早かったです。バージョンやOSが固定で安定性が確保できるなら、pickleを使うのはありかもしれません。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1