LoginSignup
6
2

LangChainを用いてHuggingFaceモデルをローカルで使う

Posted at

概要

HuggingFace Hubに登録されているモデルをローカルにダウンロードして、LangChain経由で対話型のプログラムを作成する。

前提条件

ランタイムは Python 3.11.6 を使用する。今回使用するモデルには、比較的汎用的な計算機でも動作する "TinyLlama/TinyLlama-1.1B-Chat-v1.0" を採用した

使用する Python パッケージは以下の通り。

パッケージ名 バージョン
langchain 0.0.322
transformers 4.35.2
huggingface_hub 0.20.2

実装コード

実装は以下の通り。

snapshot_download関数でリポジトリのスナップショットを撮り、そのスナップショットにあるファイルを元に tokenizer と model を準備する。次にHuggingFacePipelineクラスをインスタンス化し、これとChatPromptTemplateでチェーンを作成し、chain.invokeメソッドで実行する。

app.py
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import ChatPromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from huggingface_hub import snapshot_download

# 1.HuggingFaceからモデルをダウンロード
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
download_path = snapshot_download(repo_id=model_id)

# 2.TokenizerとModelをスナップショットからインスタンス化
tokenizer = AutoTokenizer.from_pretrained(download_path)
model = AutoModelForCausalLM.from_pretrained(download_path)

# 3.HuggingFacePipelineを作成
# ref. https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
hf = HuggingFacePipeline(pipeline=pipe)

# 4.ChatPromptTemplateを作成
prompt = ChatPromptTemplate.from_template(
    "<|system|>You are a chatbot who can help code!</s><|user|>{question}</s><|assistant|>"
)
chain = prompt | hf

# 5.Chainの実行
question = "Write me a function to print out 'Hello, World!' to the CLI in Python."
answer = chain.invoke({"question": question})
print(answer)

実行結果

実行した結果は以下の通り。
スナップショットから "TinyLlama/TinyLlama-1.1B-Chat-v1.0" のデータをフェッチし、LangChainを介してモデルとの対話を行えている事がわかる。

bash
$ python app.py
Fetching 10 files: 100%|█████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 10860.45it/s]

Certainly! Here's a Python function that prints 'Hello, World!' to the CLI:

```python
def print_hello_world():
    print("Hello, World!")
```

To use this function, simply call it with no arguments:

```python
print_hello_world()
```

This will print 'Hello, World!' to the CLI. Note that this function assumes that you have already imported the `print` function from the `sys` module. If you're not using Python 3.6 or later, you may need to import the `print` function explicitly:

```python
from sys import print as print_
```

補足:GPUによる推論時間の短縮

GPUが利用できるのであれば、引数にdevice=0を指定することで、計算時間をかなり短縮できる。

app.py (抜粋)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256, device=0)

以下は i7-13700F + RTX 4070 の構成で、10回推論を行ったときの計算時間の平均である。
※ かなり雑な計測なので、実際に手元で試してほしい

計測ケース 計測時間(秒)
CPU のみ 24.3
GPU サポート 2.3

参考文献

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2