A memo so I stop forgetting this every time...
Install
- torch
- transformers
- accelerate (needed for 4-bit, 8-bit, or device_map)
- bitsandbytes (needed for 4-bit or 8-bit)
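For reference, everything above installs in one line (a sketch; the right torch build depends on your CUDA setup, and versions are left unpinned here):
pip install torch transformers accelerate bitsandbytes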
Load
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "elyza/ELYZA-japanese-CodeLlama-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float32
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# float16
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
# 8bit
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto")
# 4bit
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map="auto")
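Once any of the variants above has loaded, a quick smoke test (a minimal sketch; the prompt is arbitrary and does not follow the model's official prompt template):
prompt = "def fibonacci(n):"  # arbitrary prompt, not the official ELYZA template
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))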
That said...
The load_in_4bit and load_in_8bit arguments are deprecated and will be removed in the future versions. Please, pass a BitsAndBytesConfig object in quantization_config argument instead.
Apparently so, which means the load_in_8bit / load_in_4bit style above will stop working at some point...
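For the record, here is the replacement the warning asks for, as a minimal sketch (the extra 4-bit option bnb_4bit_compute_dtype is an optional, common choice, not a requirement):
from transformers import BitsAndBytesConfig

# 8bit, via quantization_config instead of load_in_8bit
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# 4bit, via quantization_config instead of load_in_4bit
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # optional; float16 is just a common pick
    ),
    device_map="auto",
)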
That's all.