半年前に作ったものが出てきました.LoRAとSFTをそれぞれ自前で実装し,「つくよみちゃん会話AI育成計画」というデータセットを使って訓練しています.データセットの詳細は▼
TRL,Datasets,PEFTを一切使わず,PyTorchの生機能だけでやり切っています(というか当時の私は知らなかった).色々と原始的な書き方が出てきて面白いので,ここに残しておこうと思います.これらのライブラリでスマートに書ける所については🤖が口出しします.
🤖<ライブラリを使えばいいのに...
本稿の最後に学習後モデルの応答例を示しています.かなりつくよみちゃんっぽくなっているので是非ご覧ください.
🗳️インポート
先述したとおり,trl
,datasets
,peft
はインポートしません!
インポート
import transformers
from transformers import GPT2LMHeadModel, T5Tokenizer, TextStreamer
import torch
from torch import ones, zeros, randn, tensor, Tensor, float32, no_grad, concat, nn, ones_like, cuda
from torch.nn import Module, Parameter, init, Conv1d
from torch.nn import functional as F
from torch.optim import Adam
import time
import json
import random
from pathlib import Path
from tqdm import tqdm
device = torch.device("cuda:1")
print(f"{device=}")
🏗️ベースモデルを読み込む
rinna/japanese-gpt-1b
を使用します.
ベースモデルを読み込む
tokenizer = T5Tokenizer.from_pretrained("./rinna_japanese-gpt-1b/", padding_side="left")
streamer = TextStreamer(tokenizer)
ivocab = {v: k for k, v in tokenizer.get_vocab().items()}
base = GPT2LMHeadModel.from_pretrained("./rinna_japanese-gpt-1b/").to(device)
🔥LoRAの自前実装
論文を読みながら頑張りました(とはいっても簡単な構造だから簡単だった.).
- 自前の
Module
内にて,パラメータを表す行列等をParameter
で囲うと,後でパラメータ管理(e.g. 学習対象/外を切り分ける)をする時に便利です. -
LoRAConv1D
は,元の1次元畳み込み層を貰って,元の1次元畳み込み層+LoRA層を表す層.ベースモデルのアテンション機構がConv1D
で実装されていたため,これに対応する層を作りました. -
重要:
requires_grad_
メソッドを使って,元モデルの重みは学習対象から外し(定数化),LoRAの重みのみ学習対象とします.
LoRAの自前実装
class LoRA(Module):
def __init__(
self,
in_features: int,
out_features: int,
r: int,
device: torch.device,
dtype=float32,
) -> None:
super().__init__()
self.r = r
self.in_features = in_features
self.out_features = out_features
self.A = Parameter(zeros((r, in_features), dtype=dtype, device=device))
self.B = Parameter(zeros((out_features, r), dtype=dtype, device=device))
# init.kaiming_uniform_(self.A, a=5**0.5) # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L124
init.normal_(self.A, std=1./r)
def forward(self, input: Tensor) -> Tensor:
return F.linear(input, self.B @ self.A)/self.r
def __str__(self) -> str:
return f'in_features={self.in_features}, out_features={self.out_features}, r={self.r}'
def extra_repr(self) -> str:
return self.__str__()
class LoRAConv1d(Module):
def __init__(self, original: Conv1d, r: int) -> None:
super().__init__()
self.original = original
original.requires_grad_(False)
in_features, out_features = original.weight.size()
self.lora = LoRA(
in_features, out_features, r, original.weight.device, original.weight.dtype
)
self.lora.requires_grad_(True)
def forward(self, input: Tensor) -> Tensor:
return self.original(input) + self.lora(input)
🤖<PEFTを使えばいいのに...
🔥LoRA適用後モデルを作る
元モデルを読み込んだ後,元モデルの畳み込み層Conv1d
をLoRAConv1d
に置き換えます.
LoRA適用後モデルを作る
def create_m1(r:int=3) -> GPT2LMHeadModel:
m1 = GPT2LMHeadModel.from_pretrained(
"./rinna_japanese-gpt-1b/").to(device)
m1.requires_grad_(False)
gpt2blocks = m1.transformer.h
for k in range(len(gpt2blocks)):
lora_c_attn = LoRAConv1d(
m1.transformer.h[k].attn.c_attn,
r
)
m1.transformer.h[k].attn.c_attn = lora_c_attn
return m1
なんかここでも元モデルの読み込み処理をしてますね...まあ小さいことは気にするな
🤖<PEFTを使えばいいのに...
🗞️ここで用語を定義しておく
ここから色々とデータの加工・変形が続きます.私のコードで出てくる用語は一般的なものとはちょっとズレているのでここで整理しておきます.
Conversation Tuple[文字列, 文字列]
: アシスタントによる応答直前とアシスタントによる応答終了後それぞれの文字列ペア.文字列の形式は,チャットテンプレート適用後&特別トークン無し.それぞれstart
とfinal
と呼ぶ.
start = "user:お腹が鳴る\nassistant:"
final = "user:お腹が鳴る\nassistant:何か召し上がりますか?"
conversation = (start, final)
Sample Tuple[ベクトル, 整数]
: テキストを表すトークンID列とその後に続くトークンIDのペア.それぞれcontext
とnext_id
と呼んでおきます.
full = tokenizer.encode("user:お腹が鳴る\nassistant:何か召し")
context = full[:-1] # お腹が〜何か
next_id = full[-1] # 召し
sample = (context, next_id)
Conversationからはアシスタントによる応答文のトークン数だけのSampleが得られますね.
Batch Tuple[行列, 行列, ベクトル]
: Sampleをいくつか束ねた形式です.モデルが読み込まれるように,長さを揃える処理(後述)がされています.要素はそれぞれcontext
を束ねたもの,アテンションマスク,next_id
をベクトル形式に並べたものです.
🤖<HuggingFaceが推奨するデータセットフォーマットがあるのに...
📚学習データの手配
つくよみちゃん会話AI育成計画のデータを,次のようにタブ区切りのテキストファイルにしておきます.
- まずこれらのテキストファイルからConversationデータを作ります.
- その次に,ConversationデータからSampleデータを作ります.
学習データの手配
USER_TEMPLATE = "user:{}"
ASSISTANT_TEMPLATE = "assistant:{}"
VALIDATION_RATE = 0.0
data_raw = Path("./tsukuyomi-conversations.txt").read_text(encoding="utf-8")
CONVERSATION_LIST = []
for row in data_raw.split("\n"):
user_text, assistant_text = row.split("\t")
start = "\n".join([
USER_TEMPLATE.format(user_text),
ASSISTANT_TEMPLATE.format("")
])
final = "\n".join([
USER_TEMPLATE.format(user_text),
ASSISTANT_TEMPLATE.format(assistant_text)
])
CONVERSATION_LIST.append((start, final))
boundary = int(len(CONVERSATION_LIST)*VALIDATION_RATE)
VAL_CONVERSATION_LIST = CONVERSATION_LIST[:boundary]
TRAIN_CONVERSATION_LIST = CONVERSATION_LIST[boundary:]
🤖<tokenizer
のapply_chat_template
を使えばいいのに...
def conversation2sample_list(conversation: tuple, add_pad_token: bool = False) -> list:
start, final = conversation
if add_pad_token:
start = tokenizer.pad_token + start
final = tokenizer.pad_token + final
start_ids = tokenizer.encode(
start, return_tensors="pt", add_special_tokens=False).to(device)
final_ids = tokenizer.encode(
final+tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(device)
start_len = len(start_ids[0])
final_len = len(final_ids[0])
sample_list = []
for e in range(start_len, final_len):
sample = (final_ids[:, :e], final_ids[:, e])
sample_list.append(sample)
return sample_list
train_sample_list = []
for conversation in TRAIN_CONVERSATION_LIST:
sample_list_i = conversation2sample_list(conversation)
train_sample_list += (sample_list_i)
random.shuffle(train_sample_list)
🤖<datasetsを使えばいいのに...
🧺バッチ化
モデルに渡すバッチは,それぞれのサンプルのデータ長が揃ってないといけません(そうしないとそもそもTensor
オブジェクトとか作れないからね).しかし,文字列というのはサンプルにとってバラバラなんで,どうにかして揃えないといけなくなります.
そこで,サンプルのデータ長を揃えてアテンションマスクも出力するメソッドを自作しました.
バッチ化
まあやってることはDataCollatorと同じですね.しかし当時の僕はDataCollatorも知りませんでした.def tieup_inputs(input_list: Tensor) -> tuple:
N = len(input_list)
L_list = [len(input[0]) for input in input_list]
L_max = max(L_list)
input_mat = tokenizer.pad_token_id * ones((N, L_max), dtype=torch.int64).to(device)
attention_mask = ones((N, L_max), dtype=torch.int64).to(device)
for i, (T_i, L_i) in enumerate(zip(input_list, L_list)):
input_mat[i, L_max - L_i :] = T_i
attention_mask[i, : L_max - L_i] = 0
return input_mat, attention_mask
BATCH_SIZE = 16
train_batch_list = []
for i in range(0, len(train_sample_list), BATCH_SIZE):
sample_list = train_sample_list[i:i+BATCH_SIZE]
input_list = [s[0] for s in sample_list]
inputs, attention_mask = tieup_inputs(input_list)
nexttokens = tensor([s[1] for s in sample_list]).to(device)
train_batch_list.append((inputs, attention_mask, nexttokens))
len(train_batch_list)
🤖<SFTTrainerとDataCollatorForCompletionOnlyLMを使えばいいのに...
🛌(2回目以降)訓練済み重みの読み込み
load_state_dict
メソッドは,.pt形式に記載された辞書データのKeyと自分のモデルの層名を照らし合わせて,自分のモデルの重みを更新します.LoRAの重みだけが保存されている想定なので,この処理では元モデルの重みは一歳変更を受けず,LoRAの重みのみが読み込みされます.
(2回目以降)訓練済み重みの読み込み
m = create_m1(r=4)
m.load_state_dict(torch.load("./2lora-r4-epoch4.pt"), strict=False)
n_param = sum([p.numel() for p in m.parameters()])
n_trainable_param = sum([p.numel() for p in m.parameters() if p.requires_grad])
print(f"{n_param=:e}, {n_trainable_param=:e}")
_IncompatibleKeys(missing_keys=['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.original.weight', 'transformer.h.0.attn.c_attn.original.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.original.weight', 'transformer.h.1.attn.c_attn.original.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.original.weight', 'transformer.h.2.attn.c_attn.original.bias', 'transformer.h.2.attn.c_proj.weight', 'transformer.h.2.attn.c_proj.bias', 'transformer.h.2.ln_2.weight', 'transformer.h.2.ln_2.bias', 'transformer.h.2.mlp.c_fc.weight', 'transformer.h.2.mlp.c_fc.bias', 'transformer.h.2.mlp.c_proj.weight', 'transformer.h.2.mlp.c_proj.bias', 'transformer.h.3.ln_1.weight', 'transformer.h.3.ln_1.bias', 'transformer.h.3.attn.c_attn.original.weight', 'transformer.h.3.attn.c_attn.original.bias', 'transformer.h.3.attn.c_proj.weight', 'transformer.h.3.attn.c_proj.bias', 'transformer.h.3.ln_2.weight', 'transformer.h.3.ln_2.bias', 'transformer.h.3.mlp.c_fc.weight', 'transformer.h.3.mlp.c_fc.bias', 'transformer.h.3.mlp.c_proj.weight', 'transformer.h.3.mlp.c_proj.bias', 'transformer.h.4.ln_1.weight', 'transformer.h.4.ln_1.bias', 'transformer.h.4.attn.c_attn.original.weight', 'transformer.h.4.attn.c_attn.original.bias', 'transformer.h.4.attn.c_proj.weight', 'transformer.h.4.attn.c_proj.bias', 'transformer.h.4.ln_2.weight', 'transformer.h.4.ln_2.bias', 'transformer.h.4.mlp.c_fc.weight', 'transformer.h.4.mlp.c_fc.bias', 'transformer.h.4.mlp.c_proj.weight', 'transformer.h.4.mlp.c_proj.bias', 'transformer.h.5.ln_1.weight', 'transformer.h.5.ln_1.bias', 'transformer.h.5.attn.c_attn.original.weight', 'transformer.h.5.attn.c_attn.original.bias', 'transformer.h.5.attn.c_proj.weight', 'transformer.h.5.attn.c_proj.bias', 'transformer.h.5.ln_2.weight', 'transformer.h.5.ln_2.bias', 'transformer.h.5.mlp.c_fc.weight', 'transformer.h.5.mlp.c_fc.bias', 'transformer.h.5.mlp.c_proj.weight', 'transformer.h.5.mlp.c_proj.bias', 'transformer.h.6.ln_1.weight', 'transformer.h.6.ln_1.bias', 'transformer.h.6.attn.c_attn.original.weight', 'transformer.h.6.attn.c_attn.original.bias', 'transformer.h.6.attn.c_proj.weight', 'transformer.h.6.attn.c_proj.bias', 'transformer.h.6.ln_2.weight', 'transformer.h.6.ln_2.bias', 'transformer.h.6.mlp.c_fc.weight', 'transformer.h.6.mlp.c_fc.bias', 'transformer.h.6.mlp.c_proj.weight', 'transformer.h.6.mlp.c_proj.bias', 'transformer.h.7.ln_1.weight', 'transformer.h.7.ln_1.bias', 'transformer.h.7.attn.c_attn.original.weight', 'transformer.h.7.attn.c_attn.original.bias', 'transformer.h.7.attn.c_proj.weight', 'transformer.h.7.attn.c_proj.bias', 'transformer.h.7.ln_2.weight', 'transformer.h.7.ln_2.bias', 'transformer.h.7.mlp.c_fc.weight', 'transformer.h.7.mlp.c_fc.bias', 'transformer.h.7.mlp.c_proj.weight', 'transformer.h.7.mlp.c_proj.bias', 'transformer.h.8.ln_1.weight', 'transformer.h.8.ln_1.bias', 'transformer.h.8.attn.c_attn.original.weight', 'transformer.h.8.attn.c_attn.original.bias', 'transformer.h.8.attn.c_proj.weight', 'transformer.h.8.attn.c_proj.bias', 'transformer.h.8.ln_2.weight', 'transformer.h.8.ln_2.bias', 'transformer.h.8.mlp.c_fc.weight', 'transformer.h.8.mlp.c_fc.bias', 'transformer.h.8.mlp.c_proj.weight', 'transformer.h.8.mlp.c_proj.bias', 'transformer.h.9.ln_1.weight', 'transformer.h.9.ln_1.bias', 'transformer.h.9.attn.c_attn.original.weight', 'transformer.h.9.attn.c_attn.original.bias', 'transformer.h.9.attn.c_proj.weight', 'transformer.h.9.attn.c_proj.bias', 'transformer.h.9.ln_2.weight', 'transformer.h.9.ln_2.bias', 'transformer.h.9.mlp.c_fc.weight', 'transformer.h.9.mlp.c_fc.bias', 'transformer.h.9.mlp.c_proj.weight', 'transformer.h.9.mlp.c_proj.bias', 'transformer.h.10.ln_1.weight', 'transformer.h.10.ln_1.bias', 'transformer.h.10.attn.c_attn.original.weight', 'transformer.h.10.attn.c_attn.original.bias', 'transformer.h.10.attn.c_proj.weight', 'transformer.h.10.attn.c_proj.bias', 'transformer.h.10.ln_2.weight', 'transformer.h.10.ln_2.bias', 'transformer.h.10.mlp.c_fc.weight', 'transformer.h.10.mlp.c_fc.bias', 'transformer.h.10.mlp.c_proj.weight', 'transformer.h.10.mlp.c_proj.bias', 'transformer.h.11.ln_1.weight', 'transformer.h.11.ln_1.bias', 'transformer.h.11.attn.c_attn.original.weight', 'transformer.h.11.attn.c_attn.original.bias', 'transformer.h.11.attn.c_proj.weight', 'transformer.h.11.attn.c_proj.bias', 'transformer.h.11.ln_2.weight', 'transformer.h.11.ln_2.bias', 'transformer.h.11.mlp.c_fc.weight', 'transformer.h.11.mlp.c_fc.bias', 'transformer.h.11.mlp.c_proj.weight', 'transformer.h.11.mlp.c_proj.bias', 'transformer.h.12.ln_1.weight', 'transformer.h.12.ln_1.bias', 'transformer.h.12.attn.c_attn.original.weight', 'transformer.h.12.attn.c_attn.original.bias', 'transformer.h.12.attn.c_proj.weight', 'transformer.h.12.attn.c_proj.bias', 'transformer.h.12.ln_2.weight', 'transformer.h.12.ln_2.bias', 'transformer.h.12.mlp.c_fc.weight', 'transformer.h.12.mlp.c_fc.bias', 'transformer.h.12.mlp.c_proj.weight', 'transformer.h.12.mlp.c_proj.bias', 'transformer.h.13.ln_1.weight', 'transformer.h.13.ln_1.bias', 'transformer.h.13.attn.c_attn.original.weight', 'transformer.h.13.attn.c_attn.original.bias', 'transformer.h.13.attn.c_proj.weight', 'transformer.h.13.attn.c_proj.bias', 'transformer.h.13.ln_2.weight', 'transformer.h.13.ln_2.bias', 'transformer.h.13.mlp.c_fc.weight', 'transformer.h.13.mlp.c_fc.bias', 'transformer.h.13.mlp.c_proj.weight', 'transformer.h.13.mlp.c_proj.bias', 'transformer.h.14.ln_1.weight', 'transformer.h.14.ln_1.bias', 'transformer.h.14.attn.c_attn.original.weight', 'transformer.h.14.attn.c_attn.original.bias', 'transformer.h.14.attn.c_proj.weight', 'transformer.h.14.attn.c_proj.bias', 'transformer.h.14.ln_2.weight', 'transformer.h.14.ln_2.bias', 'transformer.h.14.mlp.c_fc.weight', 'transformer.h.14.mlp.c_fc.bias', 'transformer.h.14.mlp.c_proj.weight', 'transformer.h.14.mlp.c_proj.bias', 'transformer.h.15.ln_1.weight', 'transformer.h.15.ln_1.bias', 'transformer.h.15.attn.c_attn.original.weight', 'transformer.h.15.attn.c_attn.original.bias', 'transformer.h.15.attn.c_proj.weight', 'transformer.h.15.attn.c_proj.bias', 'transformer.h.15.ln_2.weight', 'transformer.h.15.ln_2.bias', 'transformer.h.15.mlp.c_fc.weight', 'transformer.h.15.mlp.c_fc.bias', 'transformer.h.15.mlp.c_proj.weight', 'transformer.h.15.mlp.c_proj.bias', 'transformer.h.16.ln_1.weight', 'transformer.h.16.ln_1.bias', 'transformer.h.16.attn.c_attn.original.weight', 'transformer.h.16.attn.c_attn.original.bias', 'transformer.h.16.attn.c_proj.weight', 'transformer.h.16.attn.c_proj.bias', 'transformer.h.16.ln_2.weight', 'transformer.h.16.ln_2.bias', 'transformer.h.16.mlp.c_fc.weight', 'transformer.h.16.mlp.c_fc.bias', 'transformer.h.16.mlp.c_proj.weight', 'transformer.h.16.mlp.c_proj.bias', 'transformer.h.17.ln_1.weight', 'transformer.h.17.ln_1.bias', 'transformer.h.17.attn.c_attn.original.weight', 'transformer.h.17.attn.c_attn.original.bias', 'transformer.h.17.attn.c_proj.weight', 'transformer.h.17.attn.c_proj.bias', 'transformer.h.17.ln_2.weight', 'transformer.h.17.ln_2.bias', 'transformer.h.17.mlp.c_fc.weight', 'transformer.h.17.mlp.c_fc.bias', 'transformer.h.17.mlp.c_proj.weight', 'transformer.h.17.mlp.c_proj.bias', 'transformer.h.18.ln_1.weight', 'transformer.h.18.ln_1.bias', 'transformer.h.18.attn.c_attn.original.weight', 'transformer.h.18.attn.c_attn.original.bias', 'transformer.h.18.attn.c_proj.weight', 'transformer.h.18.attn.c_proj.bias', 'transformer.h.18.ln_2.weight', 'transformer.h.18.ln_2.bias', 'transformer.h.18.mlp.c_fc.weight', 'transformer.h.18.mlp.c_fc.bias', 'transformer.h.18.mlp.c_proj.weight', 'transformer.h.18.mlp.c_proj.bias', 'transformer.h.19.ln_1.weight', 'transformer.h.19.ln_1.bias', 'transformer.h.19.attn.c_attn.original.weight', 'transformer.h.19.attn.c_attn.original.bias', 'transformer.h.19.attn.c_proj.weight', 'transformer.h.19.attn.c_proj.bias', 'transformer.h.19.ln_2.weight', 'transformer.h.19.ln_2.bias', 'transformer.h.19.mlp.c_fc.weight', 'transformer.h.19.mlp.c_fc.bias', 'transformer.h.19.mlp.c_proj.weight', 'transformer.h.19.mlp.c_proj.bias', 'transformer.h.20.ln_1.weight', 'transformer.h.20.ln_1.bias', 'transformer.h.20.attn.c_attn.original.weight', 'transformer.h.20.attn.c_attn.original.bias', 'transformer.h.20.attn.c_proj.weight', 'transformer.h.20.attn.c_proj.bias', 'transformer.h.20.ln_2.weight', 'transformer.h.20.ln_2.bias', 'transformer.h.20.mlp.c_fc.weight', 'transformer.h.20.mlp.c_fc.bias', 'transformer.h.20.mlp.c_proj.weight', 'transformer.h.20.mlp.c_proj.bias', 'transformer.h.21.ln_1.weight', 'transformer.h.21.ln_1.bias', 'transformer.h.21.attn.c_attn.original.weight', 'transformer.h.21.attn.c_attn.original.bias', 'transformer.h.21.attn.c_proj.weight', 'transformer.h.21.attn.c_proj.bias', 'transformer.h.21.ln_2.weight', 'transformer.h.21.ln_2.bias', 'transformer.h.21.mlp.c_fc.weight', 'transformer.h.21.mlp.c_fc.bias', 'transformer.h.21.mlp.c_proj.weight', 'transformer.h.21.mlp.c_proj.bias', 'transformer.h.22.ln_1.weight', 'transformer.h.22.ln_1.bias', 'transformer.h.22.attn.c_attn.original.weight', 'transformer.h.22.attn.c_attn.original.bias', 'transformer.h.22.attn.c_proj.weight', 'transformer.h.22.attn.c_proj.bias', 'transformer.h.22.ln_2.weight', 'transformer.h.22.ln_2.bias', 'transformer.h.22.mlp.c_fc.weight', 'transformer.h.22.mlp.c_fc.bias', 'transformer.h.22.mlp.c_proj.weight', 'transformer.h.22.mlp.c_proj.bias', 'transformer.h.23.ln_1.weight', 'transformer.h.23.ln_1.bias', 'transformer.h.23.attn.c_attn.original.weight', 'transformer.h.23.attn.c_attn.original.bias', 'transformer.h.23.attn.c_proj.weight', 'transformer.h.23.attn.c_proj.bias', 'transformer.h.23.ln_2.weight', 'transformer.h.23.ln_2.bias', 'transformer.h.23.mlp.c_fc.weight', 'transformer.h.23.mlp.c_fc.bias', 'transformer.h.23.mlp.c_proj.weight', 'transformer.h.23.mlp.c_proj.bias', 'transformer.ln_f.weight', 'transformer.ln_f.bias', 'lm_head.weight'], unexpected_keys=[])
なんかめっちゃくちゃエラーが出ていますが,大丈夫です.
n_param=1.303499e+09, n_trainable_param=7.864320e+05
🤖<PEFTを使えばいいのに...
🎒学習準備
オプティマイザを起動します.もしもあれば,オプティマイザの内部状態を復元します.Adamなので慣性を表す内部状態が2種類あったはず...
学習準備
optimizer = Adam(m.parameters(), lr=1e-4)
optimizer.load_state_dict(torch.load("2optimizer-r4-epoch1.pt"))
💦学習
LoRAパラメータの学習データに対する勾配を取り,勾配を用いてオプティマイザがLoRAパラメータの更新を行います.GPUでやると3分くらいで終わります!
学習
for batch in tqdm(train_batch_list):
while torch.cuda.temperature(device) > 70 # GPUが熱々になるのを防ぐ
time.sleep(1)
inputs, attention_mask, nexttokens = batch
o = m(inputs, attention_mask=attention_mask)
pred_output = o.logits[:, -1]
loss = F.cross_entropy(pred_output, nexttokens)
loss.backward()
optimizer.step()
🤖<SFTTrainerを使えばいいのに...
💾保存
名前に"lora"が入っている層だけを取り出し,それらの重みを.ptファイルに保存します.
保存
state_dict = m.state_dict()
lora_state_dict = {k:v for k,v in state_dict.items() if "lora" in k}
torch.save(lora_state_dict, f"2lora-r{r}-epoch{epoch}.pt")
torch.save(optimizer.state_dict(),f"2optimizer-r{r}-epoch{epoch}.pt")
🤖<save_pretrained
を呼べばいいのに...
🗣️生成と生成結果
生成
さてさて,学習したモデルはちゃんとつくよみちゃんっぽくなってるんでしょうか?
USER_INPUTS = [
"こんにちは。",
"ねえねえ、昨日あったことの話、聞いてくれる?",
"大学院生の社会の中での役割は何だと思う?",
"かわいいね!大好き!!",
"大好き!結婚して!",
]
user_input = USER_INPUTS[4]
input = tokenizer.encode(
USER_TEMPLATE.format(user_input)+"\n"+ASSISTANT_TEMPLATE.format(""),
add_special_tokens=False,return_tensors="pt"
).to(device)
with no_grad():
output_ids = m.generate(
input,
max_length=100,
do_sample=True,
top_p=0.8,
top_k=20,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
bad_words_ids=[[tokenizer.unk_token_id]],
streamer=streamer
)
こんにちは。
元モデル
こんにちは。 id:marimoです。 今回は、私が実際に利用している「SimpleTranslator」という翻訳ソフトをご紹介したいと思います。...
学習後
初めまして。よろしくおねがいします。
大学院生の社会の中での役割は何だと思う?
元モデル
社会の中での役割は何だと思いますか?
学習後
研究者として社会に貢献することは難しいと思いますが、私は、社会の中で自分の専門性を生かすことができると思っています。
ねえねえ、昨日あったことの話、聞いてくれる?
元モデル
ああ、そうなんだ。 assistant:ああ、それはね、俺と彼女の間にあった出来事なんだけど、 assistant:いや、ちょっと待って、それは、ちょっと、 assistant:ああ、そう、ちょっと待って、それはね、ちょっと、 assistant:いや、それは、ちょっと、 assistant:えっとね、ちょっと、 assistant
学習後
私もお話を聞きたいです。
かわいいね!大好き!!
元モデル
かわいい!!大好き!! list: かわいい!!大好き!! user:かわいい!!大好き!! assistant: かわいい!!大好き!! list: かわいい!!大好き!! list: かわいい!!大好き!! user:かわいい!!大好き!! assistant: かわいい!!大好き!! list: かわいい!!大好き!! list: かわいい!!大好き!
学習後
ありがとうございます!
僕の今日の運勢を占って
元モデル
僕の誕生日占い!
学習後
おみくじの「凶」を引いてしまいました......。
レポートが終わらないいいぃぃぃぃ
元モデル
レポートの期限が過ぎてるよぅぅぅぅぅ user:レポートが終わらないいいぃぃぃぃぃぃぃぃ assistant:レポートの
学習後
お忙しいとは思いますが、どうかお身体をお大事に!
好きな食べ物は?
元モデル
好きなスポーツは? list:好きな映画は何? list:好きなアーティストは? list:好きな動物は? list:好きな色は何? li
学習後
お寿司!
大好き!結婚して!
元モデル
大好き!結婚してください! like:大好き!結婚したい!
学習後
はい、結婚しました。おめでとうございました!
制御工学って何?分かりやすく教えて!
元モデル
制御工学で出てくる微分方程式ってどうやって解いたらいいの? link:制御工学では「非線形」ってどういう意味?
学習後
それは難しいですね。私には、難しいことは分かりません。
面白い話をしてよ
元モデル
一緒に遊ぼうよ id:snsで仲良くなった友達
学習後
私も、最近、この方のブログを発見しました。
はぁ。。まぢ最悪。。。
元モデル
はぁ。。最悪。。 assistant:の言ってる事には、なんの説得力もない。 assistant:の言ってる事は、すべて嘘だ。 assistant:は、自分の都合のいいことだけを、さも真実かのように、 話す。 assistant:の言うことなんて、一切、信用できない。 assistant:の言ってる事は、すべて嘘だ。
学習後
大事にされてくださいね。
学習によって,愛想らしい受け答えができるようになりました!
元モデルはいわゆる"Foundation Model",文字列の続きを予想する学習だけ受けていて,受け答えをするようには作られていません.なので,「好きな食べ物は?」などのプロンプトに対しても,この質問文に似たような文字列を返すのみにとどまっています.
一方で,学習後のモデルはそれっぽい答えを返してくれます.しかもちょっと可愛らしい...「僕の今日の運勢を占って」に対して「おみくじの「凶」を引いてしまいました......。」と返ってきたのにはびっくりしました.つくよみちゃんのデータセットで,受け答えするような能力と「愛らしさ」という人格が得られたと考えられます.
つくよみちゃん可愛いな