8
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

話題のrinna-3.6bをColab無料枠で動かしたい!(Hugging Face load_in_8bitを使ったサンプルコード)

Last updated at Posted at 2023-05-19

最近話題のLLMってやつを触ってみたい

最近めちゃくちゃ大規模言語モデルって人気ですよね。でも触ってみようとするとColabとかKaggleだとかの無料のやつだとRAMが足りないとかVRAMが足りないとかでとっても苦労します。でもどうにかして触りたい・・・

対策

対策としてはシンプルにサブスクとか自分でマシン買うとか「課金する」というのが考えられます。でも課金するほどのモチベはないなあって人多いと思います、自分もそうです。
そこで無料枠で頑張って動かす方法を考えます。

手法

結論から言うとモデルを量子化して読み込めばいけ(る場合があり)ます。以下コードです。

rinna-instruction-3.6b-8bit.ipynb
# 必要なライブラリのインストール
!pip install accelerate
!pip install  -i https://test.pypi.org/simple/ bitsandbytes
!pip install transformers[sentencepiece]
import gc
import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, logging
torch.cuda.empty_cache()
#モデルの定義
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-sft"
#重みなしで読み込む(modelを定義するため)
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(model_name)
gc.collect()
#メモリの制限をかける(Out of memoryの対策)
device_map = infer_auto_device_map(
                                  model, max_memory={0: "11GiB", "cpu": "7GiB"}, 
                                   no_split_module_classes=["GPTNeoXLayer"]
                                   ,dtype=torch.int8
                              )
#load_in_8bitをTrueにして量子化する
model = AutoModelForCausalLM.from_pretrained(model_name,
                                          device_map = device_map, load_in_8bit=True,
                                          low_cpu_mem_usage=True,
                                          )
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = [
    {
        "speaker": "ユーザー",
        "text": "りんなは人間ですか?それともAIですか?。"
    }
]
prompt = [
    f"{uttr['speaker']}: {uttr['text']}"
    for uttr in prompt
]
prompt = "<NL>".join(prompt)
prompt = (
    prompt
    + "<NL>"
    + "システム: "
)

token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        do_sample=True,
        max_new_tokens=128,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("<NL>", "\n")
print(output)

これで動きます。りんなちゃんは人間なのか?AIなのか?

出力

image.png

。。。IBM製のAIだった:rolling_eyes:

まとめ

  • 大規模言語モデルはでっかいメモリがいる
  • でっかいメモリを使おうとすると金がかかる
  • 金をかけたくない-> 量子化すればワンチャン読み込めるかも?

最後にColabのリンクを共有します!prompt部分いじったら色々遊べます。
他のモデルでも3.6b程度までなら省メモリで動かせると思われます!

8
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?