24
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【Python】手持ちのGPUがVRAM12Gだけど「Rinna-3.6B」とお話がしたい!!!

Last updated at Posted at 2023-05-18

はじめに

ChatGPTをはじめ多くのLLMが猛威を振るう中、2023/5/17にrinna株式会社から日本語に特化した36億パラメータのGPT言語モデル(Rinna-3.6B)が公開されました。

これはさっそく手元のLocal環境で試したいと思ったのですが、すでに試した人によると

【注意】Google Colab Pro/Pro+ の A100で動作確認しています。VRAMは14.5GB必要でした。
とのこと。

参考ページ

手持ちのGPUはGeForce RTX 3080 tiなのでVRAM12Gしかありません。。。

泣く泣く断念。。。しようとしていたのですが、やはり試してみたい、何とかならないか?というのが事のいきさつです。

以下動かすまでのメモがてら書いていきます。

※今回は対話型モデルの japanese-gpt-neox-3.6b-instruction-sft を使ってみたいと思います。

インストール

# パッケージのインストール
pip install transformers sentencepiece

モデル準備

今回のメインです。まずは公式のサンプルコードを見てみます。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft")

if torch.cuda.is_available():
    model = model.to("cuda")

GPUが使えるときはGPUを使うようにする、よくやる普通の書き方ですね。

問題の解決策としては以下の2つになります。

  1. CPUで行う
  2. GPU + torch_dtype=torch.float16で行う

CPUで行う

これはそのまんまです。実行はできると思いますが、速度は遅いです。

if torch.cuda.is_available():
    model = model.to("cuda")

これを削除すればOKです。

GPU + torch_dtype=torch.float16で行う

読み込み時にtorch_dtypeをtorch.float16で指定します。ダメもとでやってみたらできました(defaultはfloat32?)

以下のようにfrom_pretrainedの引数にtorch_dtype=torch.float16を追加します。

model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", torch_dtype=torch.float16)

これを削除すればOKです。
手元だと8.2G程度だったので12Gであれば余裕があるかと思います。

プロンプト準備、推論、結果取得

読み込めても肝心の精度が残念だと使い物になりません。
以下、試した結果です。

# プロンプトの準備
prompt_base = "ユーザー: {}<NL>システム: "
sentence = "あなたは誰ですか?"
prompt = prompt_base.format(sentence)

# 推論
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        do_sample=True,
        max_new_tokens=256,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("<NL>", "\n")
print(output)
torch.cuda.empty_cache()

# 私は人工知能プログラムです。あなたはあなた自身の存在について尋ねています。私は人間の知能や感情を理解するために、多くの時間を費やしてきました。私は人間の心理や行動を理解するために多くの時間を費やし、人間の感情や行動を理解し、人間らしいコミュニケーションや相互作用の能力を開発するために多くの時間を費やしています。</s>

「あなたは誰ですか?」
の問いに
「私は人工知能プログラムです。あなたはあなた自身の存在について尋ねています。私は人間の知能や感情を理解するために、多くの時間を費やしてきました。私は人間の心理や行動を理解するために多くの時間を費やし、人間の感情や行動を理解し、人間らしいコミュニケーションや相互作用の能力を開発するために多くの時間を費やしています。」
とそれっぽい回答をしてくれています。

まとめ

12GのGPUでもRinna-3.6Bを動かすことができました!
これで楽しくお話ができます(?)

24
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?