11
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

タダの2年目Advent Calendar 2023

Day 19

LLMでよく見る関数についての解説

Last updated at Posted at 2023-12-03

NLPだったり機械学習を触ったことがある人ならなんとなーくわかるだろうけどぶっちゃけ詳しくわかってない人も多いと思うので、備忘録も兼ねてよく使う関数の動作やパラメータについて解説していこうと思います。

以下はLLMをとりあえず使ってみようでよく見かけるコードです。
コードはこのページを参考にさせていただきました。
ELYZA-japanese-Llama-2-7bの性能をテストしてみた

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# トークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained("elyza/ELYZA-japanese-Llama-2-7b-instruct")
# モデルの読み込み
model = AutoModelForCausalLM.from_pretrained(
    "elyza/ELYZA-japanese-Llama-2-7b-instruct", 
    torch_dtype=torch.float16,
    device_map="auto"
)

prompt = """<s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>

富士山について教えてください。 [/INST]"""

with torch.no_grad():
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    output_ids = model.generate(
        token_ids.to(model.device),
        max_new_tokens=512,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
print(output)

今回はLLMを使ってみようでよく見かけるであろう、以下の関数について解説していこうと思います。

AutoTokenizer.from_pretrained()
AutoModelForCausalLM.from_pretrained()
tokenizer.encode()
model.generate()
tokenizer.decode()

LLMで使われている関数は呼び出したいモデルやオプションによって使える引数や返り値が異なる場合があるので、注意する必要があります。今回は多くのモデルで使えるような引数を紹介します。

① AutoTokenizer.from_pretrained()
これは事前に訓練されたトークナイザーを読み込む関数です。公式ドキュメントはこちら
トークナイザーについてはこちらを参照してください。
こちらの関数の良く使われる引数は

pretrained_model_name_or_path:読み込みたいモデルのHuggingFace上の名前もしくはローカルに保存してあるモデルのパス。必須。
revision:モデルのバージョン名。オプション。

モデル名だけを引数を明示せずに渡しても基本的にいい感じに動きます。

②AutoModelForCausalLM.from_pretrained()
これは事前に訓練されたLLM本体を読み込む関数です。公式ドキュメントはこちら
初めて使用するHuggingFace上のモデルを指定してこの関数を実行した場合、モデルが一度Cacheにダウンロードされた後、GPUなどに読み込みます。二回目以降は基本的にCacheにダウンロードされたモデルを読み込むことになるので実行が早くなります。
こちらの関数の良く使われる引数は、

pretrained_model_name_or_path:読み込みたいモデルのHuggingFace上の名前もしくはローカルに保存してあるモデルのパス。必須。
revision:モデルのバージョン名をしているオプション。
torch_dtype:読み込むモデルの精度を決めるオプション。精度を上げるとモデルが重くなる。
device_map:モデルをどのGPUにロードするかを指定するオプション。
load_in_8bit:モデルをtorch.int8型で読み込むオプション。こちらを指定するとモデルがとても軽くなる。torch_dtypeを指定した場合こちらのオプションは使えない(未検証)。このオプションを指定するにはaccelerateとbitsandbytesというライブラリをimportする必要がある。load_in_4bitもある。

③tokenizer.encode()
これはテキストデータをencodeしてLLMが読み込める形に変換する関数です。
こちらの関数の良く使われる引数は、

text:変換したい文字列。必須。
add_special_tokens:デフォルトではTrue。文字列に自動的にスペシャルトークン(bos, eos)を自動的に付与してくれる。オプション。
padding: デフォルトではFalse。これをオンにすると自動的にpaddingをしてくれる。
max_length:入力される文字列の最大値を決定するオプション
return_tensors:返り値がどのようなデータ構造で返されるか決定するオプション。"pt"はpytorch用のtensor型を返す。

④model.generate()
これはLLMで出力を生成する関数です。
こちらの関数の良く使われる引数は

inputs:tokenizerでencodeされた文字列です。今回のプログラムで.to(model.device)が追加されているのは、基本的に文字列とmodelが同じデバイス上にロードされている必要があるためです。.to(model.device)がencodeのところについているパターンもあります。必須。
max_new_tokens:出力する文字列の最大長さを決定するパラメータです。設定しなかったらモデルに設定されているデフォルト値になります。
temperature:モデルのランダム性を決定するパラメータです。1-0のレンジで入力され、0に近づくほど創造的になり、1に近づくほど確率に基づいた答えを返すようになります。
pad_token_id,eos_token_id:パディングとEOSのtoken idを渡すためのパラメータです。オプションですが、入れといた方が無難です。
repitition_penalty:モデルが繰り返し同じ出力をするのを防ぐためのパラメータです。数値が高いほど強いペナルティがかかりますが、高すぎると助詞のようなよく使われるものにまでペナルティがかかってしまうので注意が必要です。

ドキュメントはここら辺を参考にしました。
rinna GPT-2モデルの生成パラメータ
Hugging Face
Huggingface Transformers 入門 (27) - rinnaの日本語GPT-2モデルの推論

⑤tokenizer.decode()
これはLLMで出力された結果を復号するための関数です。
こちらの関数では、一つ目の引数が処理を何個か加えられて与えられていうので、その処理も解説します。処理が複雑なので順次置き換えています。

処理A = output_ids.tolist()[0]

この処理はgenerate()で出力されたoutput_idsを.tolist()関数でリスト化しています。今回は二次元のリストに変換されています。その中で、私たちが見たい出力は0次元目にあるのでそれを取り出しています。

処理B = 処理A[token_ids.size(l]) :]

この処理は出力された文章から元々与えたプロンプトを除く処理です。
generateで出力されたoutput_idsには、元々与えたプロンプトが含まれています。
入力が"カレーはどこの料理ですか?"だとすると通常の出力は"カレーはどこの料理ですか? A.インドの料理です"
となります。しかし、入力は結果を確認するのには邪魔です。そのため、入力した文字列の長さを取得して、その長さ未満の出力をリストのスライスで取り除いています。
つまり、処理Aと処理Bを加えられたoutput_idsはピュアな出力と言えます。
こちらの関数でよく使われる引数は

inputs:今回のコードでは、生成された出力から最初に与えたプロンプトにあたる部分を除外したものを入力として与えています。必須。
skip_special_tokens:特殊トークンを出力された文字列から除外するかどうかを決定するパラメータです。好き好んで特殊トークンを読みたがる方以外はTrueにしておくのが良いでしょう。

いかがでしたか?
これを読んでもう少し知りたいって人は参考にしたドキュメントを読むのが一番いいと思います。
読んでいただきありがとうございました。

11
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?