8
6

More than 1 year has passed since last update.

ジョジョの奇妙な PEFT LoRA ファインチューニング - 究極生命体カーズとは? -

Last updated at Posted at 2023-08-29

大規模言語モデルをファインチューニングして、究極の大規模言語モデルにするのが本日のミッション。究極の大規模言語モデルで解きたい問題は、

question = "究極生命体カーズとは?"

環境設定

# パッケージのインストール
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git
import torch

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base model をそのまま使う

model_name = "cyberagent/open-calm-large"

ベースになる事前学習済みモデルとトークナイザーを読み込む。

# モデルの読み込み
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import bitsandbytes as bnb
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    #load_in_8bit=True,
    device_map="auto",
    #torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
2023-08-29 01:31:20.625774: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-29 01:31:20.669403: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-29 01:31:21.535199: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

「究極生命体カーズとは?」という問いをインプットにする。

inputs = tokenizer(question, return_tensors="pt").to(model.device)

準備はできた。ではいよいよ、問いに答える文章を10個生成するッ!

for _ in range(10):
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.9,
            top_p=0.75,
            top_k=40,
            num_beams=10,
            repetition_penalty=5.0,
            pad_token_id=tokenizer.pad_token_id,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    print(output)
    print()
究極生命体カーズとは?
【ネタバレ注意】劇場版「Fate/stay night [Heaven’s Feel]」III.spring song 感想(前編)

究極生命体カーズとは?
【ネタバレ】映画『アベンジャーズ/インフィニティ・ウォー』のあらすじと感想!MCU(マーベル・シネマティック・ユニバース)シリーズの完結編

究極生命体カーズとは?
【tdr0361】東京ディズニーリゾートのCMには現役の学生キャストも出演していた!
【tdr0314】超低額宿泊許可制度・特例第一種(優先)指定制度って何?

究極生命体カーズとは?
【ネタバレ】映画『ファンタスティック・ビーストと黒い魔法使いの誕生』感想(あらすじ&評価)

究極生命体カーズとは?
【tdr0351】ジャングル・クルーズに登場するリサルガって何者?

究極生命体カーズとは?
【tdr0601】東京ディズニーリゾートのCMには○○の表示がない?!

究極生命体カーズとは?
【動画あり】若々しく歳を重ねたいなら筋トレを!パーソナルトレーナーが徹底解説

究極生命体カーズとは?
【ネタバレ】『恋んトス』男性新メンバーの有力候補は誰!? 第1話レビュー

究極生命体カーズとは?
【ネタバレ】映画『スパイダーマン:ファー・フロム・ホーム』感想(レビュー)

究極生命体カーズとは?
【ネタバレ】『恋んトス』男性新メンバーのプロフィール解禁!気になる相手には“亜麻色の髪”で触れようとする!?

し...質問に答えてねえエェェーーーッ!!

「答えよう」という「姿勢」すらねえエェェーーーッ!!

ファインチューニング用のデータセット

ということで、この「答えようという姿勢すらねェ」言語モデルをファインチューニングするッ! 訓練に用いるのはこのQ&Aデータセット!

dataset_name = "kunishou/databricks-dolly-15k-ja"
import transformers
from datasets import load_dataset

data = load_dataset(dataset_name)
data = data.map(lambda samples: tokenizer(samples["output"]), batched=True)
Map:   0%|          | 0/15015 [00:00<?, ? examples/s]

PEFTを用いたLoRAファインチューニング

では、ここから huggingface (ハギング・フェイス)にある PEFT というライブラリでLoRAファインチューニングを行う!

for param in model.parameters():
    param.requires_grad = False  # モデルをフリーズ
    if param.ndim == 1:
        # 安定のためにレイヤーノルムをfp32にキャスト
        param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()


class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)


model.lm_head = CastOutputToFloat(model.embed_out)
def print_trainable_parameters(model):
    """
    モデル内の学習可能なパラメータ数を出力
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    # target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)
trainable params: 2359296 || all params: 842357760 || trainable%: 0.28008242008716105

訓練開始ッ!

trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # 警告を黙らせます。 推論のために再度有効にしてください。
trainer.train()
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.




<div>

  <progress value='200' max='200' style='width:300px; height:20px; vertical-align: middle;'></progress>
  [200/200 02:24, Epoch 0/1]
</div>
<table border="1" class="dataframe">
Step Training Loss 1 3.763700 2 3.280300 3 3.034100 4 3.464900 5 3.155600 6 3.547800 7 3.444300 8 3.649600 9 3.376100 10 3.108800 11 3.226000 12 4.007900 13 3.232700 14 3.412200 15 3.515400 16 3.246000 17 2.965100 18 3.515900 19 3.429800 20 4.067200 21 3.390800 22 3.307400 23 3.568200 24 3.423400 25 3.318000 26 3.470600 27 3.366900 28 3.443400 29 3.222800 30 3.555100 31 3.087100 32 3.550800 33 3.526700 34 3.326400 35 3.025200 36 2.831000 37 3.794000 38 3.287500 39 3.234800 40 3.658000 41 3.281300 42 3.419900 43 3.246300 44 3.517200 45 2.909100 46 3.263900 47 3.497800 48 3.555900 49 3.374100 50 3.539200 51 3.449700 52 3.543500 53 3.152900 54 3.586600 55 3.439200 56 3.301100 57 3.088100 58 3.308500 59 3.389100 60 3.305100 61 3.106700 62 3.375400 63 3.938300 64 3.520700 65 3.049000 66 3.375400 67 3.082500 68 3.312400 69 3.538900 70 3.448900 71 2.978700 72 3.486800 73 3.217400 74 3.372400 75 3.369900 76 2.880600 77 3.129500 78 3.104000 79 3.033800 80 3.569300 81 3.164900 82 3.162200 83 3.335400 84 3.274300 85 2.953800 86 3.517600 87 3.330300 88 3.531300 89 3.157800 90 3.348100 91 3.547000 92 3.272600 93 3.443300 94 3.435500 95 3.304300 96 2.908200 97 3.194600 98 3.406600 99 3.129400 100 3.324200 101 3.486200 102 2.799900 103 3.501100 104 3.342600 105 3.288700 106 3.568300 107 3.368600 108 3.341900 109 3.275300 110 2.866300 111 3.409100 112 3.190300 113 3.484400 114 3.159600 115 3.373800 116 3.200800 117 3.215100 118 3.040700 119 3.364300 120 3.553900 121 3.037000 122 3.467100 123 3.318700 124 3.043100 125 3.283600 126 3.771500 127 3.209900 128 3.498100 129 3.322300 130 2.963200 131 3.181600 132 3.170100 133 3.613200 134 3.184200 135 3.044100 136 3.339600 137 3.098500 138 2.887900 139 3.397600 140 2.750500 141 3.353500 142 3.304300 143 3.008700 144 3.128500 145 3.456000 146 3.111500 147 3.265000 148 3.170000 149 3.490500 150 3.173100 151 3.485900 152 3.114800 153 3.145900 154 3.405300 155 3.305400 156 2.818800 157 3.155200 158 3.188900 159 3.327900 160 3.385000 161 3.285300 162 3.161700 163 3.008000 164 2.968700 165 2.921000 166 3.233200 167 3.522600 168 3.256600 169 3.195000 170 3.651500 171 3.254200 172 3.303700 173 3.153100 174 2.899100 175 3.081000 176 2.918000 177 3.282900 178 3.352800 179 3.427300 180 3.413900 181 3.114600 182 3.120700 183 3.117800 184 3.139100 185 3.014600 186 3.097700 187 3.391100 188 3.188200 189 3.418200 190 3.257300 191 3.208800 192 3.226200 193 3.090200 194 3.116800 195 2.832700 196 3.350500 197 3.221900 198 3.577800 199 3.098700 200 2.987700

『ハギング・フェイス』ッ!
ディ・モールト ディ・モールト
(非常に 非常に) 良いぞッ!
良く学習してるぞッ!

TrainOutput(global_step=200, training_loss=3.288494335412979, metrics={'train_runtime': 145.2264, 'train_samples_per_second': 22.035, 'train_steps_per_second': 1.377, 'total_flos': 2511526591881216.0, 'train_loss': 3.288494335412979, 'epoch': 0.21})

学習結果を確認しよう。質問に答える文章を10個生成するッ!

batch = tokenizer(question, return_tensors="pt").to(device)

for _ in range(10):
    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=50)
    
    print("\n\n", tokenizer.decode(output_tokens[0], skip_special_tokens=True))
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

主人公のレーシングカーに乗って、レースをするという、

とてもシンプルなストーリーです。

しかし、このキャラクターには、

様々な秘密が隠されています。




Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、アメリカで大ヒットした映画です。

カーズの主人公は、主人公であるライトニング・マックィーンです。

マックィーンは、カーズの主人公


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、アメリカの漫画家、スティーブ・マックイーンによって描かれた、自動車を題材とした漫画です。

主人公のカーズは、自動車を運転する能力を持ち、自動車を運転する能力で、自動車を運転する能力で、自動車を


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、アメリカで大ヒットした映画です。

カーズの主人公は、主人公であるレーサーのライトニング・マックィーンです。

マックィーンは


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、車好きの少年カーと、レーサーとしての才能を持つ少女レーサーのカーとの友情を描いた作品です。

カーとレーサーの2人は


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、1989年に公開された映画です。

この映画は、車好きの人なら誰でも知っている、車好きにはたまらない映画です。

この映画は、車


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、1989年に公開された映画です。

この映画は、アメリカで大ヒットし、続編も製作されました。

続編は、カーズ2と3です


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

主人公のレーシングカーに乗って、レースをするという、

とてもシンプルなストーリーです。

しかし、このキャラクターには、

様々な秘密が隠されています。




Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.




 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、車好きの少年カーと、レーサーとしての才能を持つ少女レーサーのカーとの友情を描いた作品です。

カーとレーサーの2人は


 究極生命体カーズとは?
カーズとは、映画「カーズ」に登場するキャラクターです。

映画「カーズ」は、アメリカで大ヒットした映画です。

カーズの主人公は、主人公であるレーサーのライトニング・マックィーンです。

マックィーンは

良お~~~~しよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよしよし、

ファインチューニングして「質問に答える」ことができるようになったぞッ!答えが正しいかどうかは別として!

謝辞

Google Colab で PEFT による大規模言語モデルのファインチューニングを試すをめちゃくちゃ参考にしました。ありがとうございます!

8
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
6