LoginSignup
1
0

Huggingface pipelineの中身を見てみた

Last updated at Posted at 2024-05-01

Huggingface公式によるPipeineクラスの使い方の一例をもとに、内部の処理を追いかけてみた。
https://huggingface.co/blog/llama3#using-%F0%9F%A4%97-transformers

import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

prompt = pipeline.tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = pipeline(
    prompt,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
print(outputs[0]["generated_text"][len(prompt):])

Pipelineクラスのインスタンス化

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

Pipelineクラスを親クラスとした派生クラスが、タスクごとにたくさん用意されている。
https://huggingface.co/docs/transformers/ja/main_classes/pipelines#transformers.pipeline.task
指定したタスク(上の例では"text-generation")に対応する派生クラスが使用される。
tokenizerを指定していない場合、モデル(strの場合)のデフォルトトークナイザが設定される(ref)。

Promptの作成

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

prompt = pipeline.tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
)

tokenizerクラス()が持つapply_chat_templateメソッドでプロンプトを整える。

apply_chat_template: Converts a list of dictionaries with "role" and "content" keys to a list of token ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting. When chat_template is None, it will fall back to the default_chat_template specified at the class level.

すなわち、上例のmessagesのような形式の入力を、対応するトークナイザのchat_templateに従ってプロンプトに整える。今回使うLlama-3-8B-Instructの場合、tokenizer_config.jsonで定義されているchat_templateが適用される。

add_generation_prompt (bool, optional): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the template for this argument to have any effect.

add_generation_prompt=Trueにすることで、<|start_header_id|>assistant<|end_header_id|>\n\nまでがプロンプトテンプレートに含まれる。

Pipelineの呼び出し

outputs = pipeline(
    prompt,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
print(outputs[0]["generated_text"][len(prompt):])

今回の入力はstrなので、親クラスの__call__が呼ばれる(関係ないけどchat chat chatの🐈絵文字コメントアウトかわいい)。
スクリーンショット 2024-05-01 15.27.11.png

親クラスの__call__では、inputsの型によって処理を分けている。strなので一番下のself.run_singleが呼ばれる。
スクリーンショット 2024-05-01 15.21.24.png

run_singleは、inputsにself.preprocess,self.forward,self.postprocessの順で処理をかける。
スクリーンショット 2024-05-01 15.22.15.png

preprocessでは、プロンプトの形式により処理が分けられる。今回はstrなので、tokenizerがそのまま適用される。
スクリーンショット 2024-05-01 15.39.18.png

forward内では実質_forwardが呼ばれる。
スクリーンショット 2024-05-01 15.44.58.png

_forward内で実際にmodel.generateが実行される。
スクリーンショット 2024-05-01 15.45.47.png

postprocessでは、出力トークンのdecodeが行われる。
スクリーンショット 2024-05-01 15.52.38.png

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0