0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LoRAとSFTを自前で実装してつくよみちゃんを作る

Posted at

半年前に作ったものが出てきました.LoRAとSFTをそれぞれ自前で実装し,「つくよみちゃん会話AI育成計画」というデータセットを使って訓練しています.データセットの詳細は▼

TRL,Datasets,PEFTを一切使わず,PyTorchの生機能だけでやり切っています(というか当時の私は知らなかった).色々と原始的な書き方が出てきて面白いので,ここに残しておこうと思います.これらのライブラリでスマートに書ける所については🤖が口出しします.

🤖<ライブラリを使えばいいのに...

本稿の最後に学習後モデルの応答例を示しています.かなりつくよみちゃんっぽくなっているので是非ご覧ください.

🗳️インポート

先述したとおり,trldatasetspeftはインポートしません!

インポート
import transformers
from transformers import GPT2LMHeadModel, T5Tokenizer, TextStreamer

import torch
from torch import ones, zeros, randn, tensor, Tensor, float32, no_grad, concat, nn, ones_like, cuda
from torch.nn import Module, Parameter, init, Conv1d
from torch.nn import functional as F
from torch.optim import Adam

import time
import json
import random
from pathlib import Path
from tqdm import tqdm

device = torch.device("cuda:1")
print(f"{device=}")

🏗️ベースモデルを読み込む

rinna/japanese-gpt-1bを使用します.

ベースモデルを読み込む
tokenizer = T5Tokenizer.from_pretrained("./rinna_japanese-gpt-1b/", padding_side="left")
streamer = TextStreamer(tokenizer)
ivocab = {v: k for k, v in tokenizer.get_vocab().items()}
base = GPT2LMHeadModel.from_pretrained("./rinna_japanese-gpt-1b/").to(device)

🔥LoRAの自前実装

論文を読みながら頑張りました(とはいっても簡単な構造だから簡単だった.).

  • 自前のModule内にて,パラメータを表す行列等をParameterで囲うと,後でパラメータ管理(e.g. 学習対象/外を切り分ける)をする時に便利です.
  • LoRAConv1Dは,元の1次元畳み込み層を貰って,元の1次元畳み込み層+LoRA層を表す層.ベースモデルのアテンション機構がConv1Dで実装されていたため,これに対応する層を作りました.
  • 重要: requires_grad_メソッドを使って,元モデルの重みは学習対象から外し(定数化),LoRAの重みのみ学習対象とします.
LoRAの自前実装
class LoRA(Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int,
        device: torch.device,
        dtype=float32,
    ) -> None:
        super().__init__()
        self.r = r
        self.in_features = in_features
        self.out_features = out_features

        self.A = Parameter(zeros((r, in_features), dtype=dtype, device=device))
        self.B = Parameter(zeros((out_features, r), dtype=dtype, device=device))
        # init.kaiming_uniform_(self.A, a=5**0.5) # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L124
        init.normal_(self.A, std=1./r)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.B @ self.A)/self.r

    def __str__(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, r={self.r}'

    def extra_repr(self) -> str:
        return self.__str__()


class LoRAConv1d(Module):
    def __init__(self, original: Conv1d, r: int) -> None:
        super().__init__()
        self.original = original
        original.requires_grad_(False)
        in_features, out_features = original.weight.size()
        self.lora = LoRA(
            in_features, out_features, r, original.weight.device, original.weight.dtype
        )
        self.lora.requires_grad_(True)

    def forward(self, input: Tensor) -> Tensor:
        return self.original(input) + self.lora(input)

🤖<PEFTを使えばいいのに...

🔥LoRA適用後モデルを作る

元モデルを読み込んだ後,元モデルの畳み込み層Conv1dLoRAConv1dに置き換えます.

LoRA適用後モデルを作る
def create_m1(r:int=3) -> GPT2LMHeadModel:
    m1 = GPT2LMHeadModel.from_pretrained(
        "./rinna_japanese-gpt-1b/").to(device)
    m1.requires_grad_(False)
    gpt2blocks = m1.transformer.h
    for k in range(len(gpt2blocks)):
        lora_c_attn = LoRAConv1d(
            m1.transformer.h[k].attn.c_attn,
            r
        )
        m1.transformer.h[k].attn.c_attn = lora_c_attn
    return m1

なんかここでも元モデルの読み込み処理をしてますね...まあ小さいことは気にするな

🤖<PEFTを使えばいいのに...

🗞️ここで用語を定義しておく

ここから色々とデータの加工・変形が続きます.私のコードで出てくる用語は一般的なものとはちょっとズレているのでここで整理しておきます.

Conversation Tuple[文字列, 文字列]: アシスタントによる応答直前とアシスタントによる応答終了後それぞれの文字列ペア.文字列の形式は,チャットテンプレート適用後&特別トークン無し.それぞれstartfinalと呼ぶ.

start = "user:お腹が鳴る\nassistant:"
final = "user:お腹が鳴る\nassistant:何か召し上がりますか?"
conversation = (start, final)

Sample Tuple[ベクトル, 整数]: テキストを表すトークンID列とその後に続くトークンIDのペア.それぞれcontextnext_idと呼んでおきます.

full = tokenizer.encode("user:お腹が鳴る\nassistant:何か召し")
context = full[:-1] # お腹が〜何か
next_id = full[-1] # 召し
sample = (context, next_id)

Conversationからはアシスタントによる応答文のトークン数だけのSampleが得られますね.

Batch Tuple[行列, 行列, ベクトル]: Sampleをいくつか束ねた形式です.モデルが読み込まれるように,長さを揃える処理(後述)がされています.要素はそれぞれcontextを束ねたもの,アテンションマスク,next_idをベクトル形式に並べたものです.

🤖<HuggingFaceが推奨するデータセットフォーマットがあるのに...

📚学習データの手配

つくよみちゃん会話AI育成計画のデータを,次のようにタブ区切りのテキストファイルにしておきます.
image.png

  • まずこれらのテキストファイルからConversationデータを作ります.
  • その次に,ConversationデータからSampleデータを作ります.
学習データの手配
Conversationデータを作る
USER_TEMPLATE = "user:{}"
ASSISTANT_TEMPLATE = "assistant:{}"
VALIDATION_RATE = 0.0
data_raw = Path("./tsukuyomi-conversations.txt").read_text(encoding="utf-8")
CONVERSATION_LIST = []
for row in data_raw.split("\n"):
    user_text, assistant_text = row.split("\t")
    start = "\n".join([
        USER_TEMPLATE.format(user_text),
        ASSISTANT_TEMPLATE.format("")
    ])
    final = "\n".join([
        USER_TEMPLATE.format(user_text),
        ASSISTANT_TEMPLATE.format(assistant_text)
    ])
    CONVERSATION_LIST.append((start, final))

boundary = int(len(CONVERSATION_LIST)*VALIDATION_RATE)
VAL_CONVERSATION_LIST = CONVERSATION_LIST[:boundary]
TRAIN_CONVERSATION_LIST = CONVERSATION_LIST[boundary:]

🤖<tokenizerapply_chat_templateを使えばいいのに...

Sampleデータを作る
def conversation2sample_list(conversation: tuple, add_pad_token: bool = False) -> list:
    start, final = conversation
    if add_pad_token:
        start = tokenizer.pad_token + start
        final = tokenizer.pad_token + final
    start_ids = tokenizer.encode(
        start, return_tensors="pt", add_special_tokens=False).to(device)
    final_ids = tokenizer.encode(
        final+tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(device)
    start_len = len(start_ids[0])
    final_len = len(final_ids[0])
    sample_list = []
    for e in range(start_len, final_len):
        sample = (final_ids[:, :e], final_ids[:, e])
        sample_list.append(sample)
    return sample_list

train_sample_list = []
for conversation in TRAIN_CONVERSATION_LIST:
    sample_list_i = conversation2sample_list(conversation)
    train_sample_list += (sample_list_i)
random.shuffle(train_sample_list)

🤖<datasetsを使えばいいのに...

🧺バッチ化

モデルに渡すバッチは,それぞれのサンプルのデータ長が揃ってないといけません(そうしないとそもそもTensorオブジェクトとか作れないからね).しかし,文字列というのはサンプルにとってバラバラなんで,どうにかして揃えないといけなくなります.
そこで,サンプルのデータ長を揃えてアテンションマスクも出力するメソッドを自作しました.

バッチ化 まあやってることはDataCollatorと同じですね.しかし当時の僕はDataCollatorも知りませんでした.
バッチ化メソッド
def tieup_inputs(input_list: Tensor) -> tuple:
    N = len(input_list)
    L_list = [len(input[0]) for input in input_list]
    L_max = max(L_list)
    input_mat = tokenizer.pad_token_id * ones((N, L_max), dtype=torch.int64).to(device)
    attention_mask = ones((N, L_max), dtype=torch.int64).to(device)
    for i, (T_i, L_i) in enumerate(zip(input_list, L_list)):
        input_mat[i, L_max - L_i :] = T_i
        attention_mask[i, : L_max - L_i] = 0
    return input_mat, attention_mask
Sampleデータをバッチ化
BATCH_SIZE = 16
train_batch_list = []
for i in range(0, len(train_sample_list), BATCH_SIZE):
    sample_list = train_sample_list[i:i+BATCH_SIZE]
    input_list = [s[0] for s in sample_list]
    inputs, attention_mask = tieup_inputs(input_list)
    nexttokens = tensor([s[1] for s in sample_list]).to(device)
    train_batch_list.append((inputs, attention_mask, nexttokens))
len(train_batch_list)

🤖<SFTTrainerとDataCollatorForCompletionOnlyLMを使えばいいのに...

🛌(2回目以降)訓練済み重みの読み込み

load_state_dictメソッドは,.pt形式に記載された辞書データのKeyと自分のモデルの層名を照らし合わせて,自分のモデルの重みを更新します.LoRAの重みだけが保存されている想定なので,この処理では元モデルの重みは一歳変更を受けず,LoRAの重みのみが読み込みされます.

(2回目以降)訓練済み重みの読み込み
m = create_m1(r=4)
m.load_state_dict(torch.load("./2lora-r4-epoch4.pt"), strict=False)
n_param = sum([p.numel() for p in m.parameters()])
n_trainable_param = sum([p.numel() for p in m.parameters() if p.requires_grad])
print(f"{n_param=:e}, {n_trainable_param=:e}")
_IncompatibleKeys(missing_keys=['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.original.weight', 'transformer.h.0.attn.c_attn.original.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.original.weight', 'transformer.h.1.attn.c_attn.original.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.original.weight', 'transformer.h.2.attn.c_attn.original.bias', 'transformer.h.2.attn.c_proj.weight', 'transformer.h.2.attn.c_proj.bias', 'transformer.h.2.ln_2.weight', 'transformer.h.2.ln_2.bias', 'transformer.h.2.mlp.c_fc.weight', 'transformer.h.2.mlp.c_fc.bias', 'transformer.h.2.mlp.c_proj.weight', 'transformer.h.2.mlp.c_proj.bias', 'transformer.h.3.ln_1.weight', 'transformer.h.3.ln_1.bias', 'transformer.h.3.attn.c_attn.original.weight', 'transformer.h.3.attn.c_attn.original.bias', 'transformer.h.3.attn.c_proj.weight', 'transformer.h.3.attn.c_proj.bias', 'transformer.h.3.ln_2.weight', 'transformer.h.3.ln_2.bias', 'transformer.h.3.mlp.c_fc.weight', 'transformer.h.3.mlp.c_fc.bias', 'transformer.h.3.mlp.c_proj.weight', 'transformer.h.3.mlp.c_proj.bias', 'transformer.h.4.ln_1.weight', 'transformer.h.4.ln_1.bias', 'transformer.h.4.attn.c_attn.original.weight', 'transformer.h.4.attn.c_attn.original.bias', 'transformer.h.4.attn.c_proj.weight', 'transformer.h.4.attn.c_proj.bias', 'transformer.h.4.ln_2.weight', 'transformer.h.4.ln_2.bias', 'transformer.h.4.mlp.c_fc.weight', 'transformer.h.4.mlp.c_fc.bias', 'transformer.h.4.mlp.c_proj.weight', 'transformer.h.4.mlp.c_proj.bias', 'transformer.h.5.ln_1.weight', 'transformer.h.5.ln_1.bias', 'transformer.h.5.attn.c_attn.original.weight', 'transformer.h.5.attn.c_attn.original.bias', 'transformer.h.5.attn.c_proj.weight', 'transformer.h.5.attn.c_proj.bias', 'transformer.h.5.ln_2.weight', 'transformer.h.5.ln_2.bias', 'transformer.h.5.mlp.c_fc.weight', 'transformer.h.5.mlp.c_fc.bias', 'transformer.h.5.mlp.c_proj.weight', 'transformer.h.5.mlp.c_proj.bias', 'transformer.h.6.ln_1.weight', 'transformer.h.6.ln_1.bias', 'transformer.h.6.attn.c_attn.original.weight', 'transformer.h.6.attn.c_attn.original.bias', 'transformer.h.6.attn.c_proj.weight', 'transformer.h.6.attn.c_proj.bias', 'transformer.h.6.ln_2.weight', 'transformer.h.6.ln_2.bias', 'transformer.h.6.mlp.c_fc.weight', 'transformer.h.6.mlp.c_fc.bias', 'transformer.h.6.mlp.c_proj.weight', 'transformer.h.6.mlp.c_proj.bias', 'transformer.h.7.ln_1.weight', 'transformer.h.7.ln_1.bias', 'transformer.h.7.attn.c_attn.original.weight', 'transformer.h.7.attn.c_attn.original.bias', 'transformer.h.7.attn.c_proj.weight', 'transformer.h.7.attn.c_proj.bias', 'transformer.h.7.ln_2.weight', 'transformer.h.7.ln_2.bias', 'transformer.h.7.mlp.c_fc.weight', 'transformer.h.7.mlp.c_fc.bias', 'transformer.h.7.mlp.c_proj.weight', 'transformer.h.7.mlp.c_proj.bias', 'transformer.h.8.ln_1.weight', 'transformer.h.8.ln_1.bias', 'transformer.h.8.attn.c_attn.original.weight', 'transformer.h.8.attn.c_attn.original.bias', 'transformer.h.8.attn.c_proj.weight', 'transformer.h.8.attn.c_proj.bias', 'transformer.h.8.ln_2.weight', 'transformer.h.8.ln_2.bias', 'transformer.h.8.mlp.c_fc.weight', 'transformer.h.8.mlp.c_fc.bias', 'transformer.h.8.mlp.c_proj.weight', 'transformer.h.8.mlp.c_proj.bias', 'transformer.h.9.ln_1.weight', 'transformer.h.9.ln_1.bias', 'transformer.h.9.attn.c_attn.original.weight', 'transformer.h.9.attn.c_attn.original.bias', 'transformer.h.9.attn.c_proj.weight', 'transformer.h.9.attn.c_proj.bias', 'transformer.h.9.ln_2.weight', 'transformer.h.9.ln_2.bias', 'transformer.h.9.mlp.c_fc.weight', 'transformer.h.9.mlp.c_fc.bias', 'transformer.h.9.mlp.c_proj.weight', 'transformer.h.9.mlp.c_proj.bias', 'transformer.h.10.ln_1.weight', 'transformer.h.10.ln_1.bias', 'transformer.h.10.attn.c_attn.original.weight', 'transformer.h.10.attn.c_attn.original.bias', 'transformer.h.10.attn.c_proj.weight', 'transformer.h.10.attn.c_proj.bias', 'transformer.h.10.ln_2.weight', 'transformer.h.10.ln_2.bias', 'transformer.h.10.mlp.c_fc.weight', 'transformer.h.10.mlp.c_fc.bias', 'transformer.h.10.mlp.c_proj.weight', 'transformer.h.10.mlp.c_proj.bias', 'transformer.h.11.ln_1.weight', 'transformer.h.11.ln_1.bias', 'transformer.h.11.attn.c_attn.original.weight', 'transformer.h.11.attn.c_attn.original.bias', 'transformer.h.11.attn.c_proj.weight', 'transformer.h.11.attn.c_proj.bias', 'transformer.h.11.ln_2.weight', 'transformer.h.11.ln_2.bias', 'transformer.h.11.mlp.c_fc.weight', 'transformer.h.11.mlp.c_fc.bias', 'transformer.h.11.mlp.c_proj.weight', 'transformer.h.11.mlp.c_proj.bias', 'transformer.h.12.ln_1.weight', 'transformer.h.12.ln_1.bias', 'transformer.h.12.attn.c_attn.original.weight', 'transformer.h.12.attn.c_attn.original.bias', 'transformer.h.12.attn.c_proj.weight', 'transformer.h.12.attn.c_proj.bias', 'transformer.h.12.ln_2.weight', 'transformer.h.12.ln_2.bias', 'transformer.h.12.mlp.c_fc.weight', 'transformer.h.12.mlp.c_fc.bias', 'transformer.h.12.mlp.c_proj.weight', 'transformer.h.12.mlp.c_proj.bias', 'transformer.h.13.ln_1.weight', 'transformer.h.13.ln_1.bias', 'transformer.h.13.attn.c_attn.original.weight', 'transformer.h.13.attn.c_attn.original.bias', 'transformer.h.13.attn.c_proj.weight', 'transformer.h.13.attn.c_proj.bias', 'transformer.h.13.ln_2.weight', 'transformer.h.13.ln_2.bias', 'transformer.h.13.mlp.c_fc.weight', 'transformer.h.13.mlp.c_fc.bias', 'transformer.h.13.mlp.c_proj.weight', 'transformer.h.13.mlp.c_proj.bias', 'transformer.h.14.ln_1.weight', 'transformer.h.14.ln_1.bias', 'transformer.h.14.attn.c_attn.original.weight', 'transformer.h.14.attn.c_attn.original.bias', 'transformer.h.14.attn.c_proj.weight', 'transformer.h.14.attn.c_proj.bias', 'transformer.h.14.ln_2.weight', 'transformer.h.14.ln_2.bias', 'transformer.h.14.mlp.c_fc.weight', 'transformer.h.14.mlp.c_fc.bias', 'transformer.h.14.mlp.c_proj.weight', 'transformer.h.14.mlp.c_proj.bias', 'transformer.h.15.ln_1.weight', 'transformer.h.15.ln_1.bias', 'transformer.h.15.attn.c_attn.original.weight', 'transformer.h.15.attn.c_attn.original.bias', 'transformer.h.15.attn.c_proj.weight', 'transformer.h.15.attn.c_proj.bias', 'transformer.h.15.ln_2.weight', 'transformer.h.15.ln_2.bias', 'transformer.h.15.mlp.c_fc.weight', 'transformer.h.15.mlp.c_fc.bias', 'transformer.h.15.mlp.c_proj.weight', 'transformer.h.15.mlp.c_proj.bias', 'transformer.h.16.ln_1.weight', 'transformer.h.16.ln_1.bias', 'transformer.h.16.attn.c_attn.original.weight', 'transformer.h.16.attn.c_attn.original.bias', 'transformer.h.16.attn.c_proj.weight', 'transformer.h.16.attn.c_proj.bias', 'transformer.h.16.ln_2.weight', 'transformer.h.16.ln_2.bias', 'transformer.h.16.mlp.c_fc.weight', 'transformer.h.16.mlp.c_fc.bias', 'transformer.h.16.mlp.c_proj.weight', 'transformer.h.16.mlp.c_proj.bias', 'transformer.h.17.ln_1.weight', 'transformer.h.17.ln_1.bias', 'transformer.h.17.attn.c_attn.original.weight', 'transformer.h.17.attn.c_attn.original.bias', 'transformer.h.17.attn.c_proj.weight', 'transformer.h.17.attn.c_proj.bias', 'transformer.h.17.ln_2.weight', 'transformer.h.17.ln_2.bias', 'transformer.h.17.mlp.c_fc.weight', 'transformer.h.17.mlp.c_fc.bias', 'transformer.h.17.mlp.c_proj.weight', 'transformer.h.17.mlp.c_proj.bias', 'transformer.h.18.ln_1.weight', 'transformer.h.18.ln_1.bias', 'transformer.h.18.attn.c_attn.original.weight', 'transformer.h.18.attn.c_attn.original.bias', 'transformer.h.18.attn.c_proj.weight', 'transformer.h.18.attn.c_proj.bias', 'transformer.h.18.ln_2.weight', 'transformer.h.18.ln_2.bias', 'transformer.h.18.mlp.c_fc.weight', 'transformer.h.18.mlp.c_fc.bias', 'transformer.h.18.mlp.c_proj.weight', 'transformer.h.18.mlp.c_proj.bias', 'transformer.h.19.ln_1.weight', 'transformer.h.19.ln_1.bias', 'transformer.h.19.attn.c_attn.original.weight', 'transformer.h.19.attn.c_attn.original.bias', 'transformer.h.19.attn.c_proj.weight', 'transformer.h.19.attn.c_proj.bias', 'transformer.h.19.ln_2.weight', 'transformer.h.19.ln_2.bias', 'transformer.h.19.mlp.c_fc.weight', 'transformer.h.19.mlp.c_fc.bias', 'transformer.h.19.mlp.c_proj.weight', 'transformer.h.19.mlp.c_proj.bias', 'transformer.h.20.ln_1.weight', 'transformer.h.20.ln_1.bias', 'transformer.h.20.attn.c_attn.original.weight', 'transformer.h.20.attn.c_attn.original.bias', 'transformer.h.20.attn.c_proj.weight', 'transformer.h.20.attn.c_proj.bias', 'transformer.h.20.ln_2.weight', 'transformer.h.20.ln_2.bias', 'transformer.h.20.mlp.c_fc.weight', 'transformer.h.20.mlp.c_fc.bias', 'transformer.h.20.mlp.c_proj.weight', 'transformer.h.20.mlp.c_proj.bias', 'transformer.h.21.ln_1.weight', 'transformer.h.21.ln_1.bias', 'transformer.h.21.attn.c_attn.original.weight', 'transformer.h.21.attn.c_attn.original.bias', 'transformer.h.21.attn.c_proj.weight', 'transformer.h.21.attn.c_proj.bias', 'transformer.h.21.ln_2.weight', 'transformer.h.21.ln_2.bias', 'transformer.h.21.mlp.c_fc.weight', 'transformer.h.21.mlp.c_fc.bias', 'transformer.h.21.mlp.c_proj.weight', 'transformer.h.21.mlp.c_proj.bias', 'transformer.h.22.ln_1.weight', 'transformer.h.22.ln_1.bias', 'transformer.h.22.attn.c_attn.original.weight', 'transformer.h.22.attn.c_attn.original.bias', 'transformer.h.22.attn.c_proj.weight', 'transformer.h.22.attn.c_proj.bias', 'transformer.h.22.ln_2.weight', 'transformer.h.22.ln_2.bias', 'transformer.h.22.mlp.c_fc.weight', 'transformer.h.22.mlp.c_fc.bias', 'transformer.h.22.mlp.c_proj.weight', 'transformer.h.22.mlp.c_proj.bias', 'transformer.h.23.ln_1.weight', 'transformer.h.23.ln_1.bias', 'transformer.h.23.attn.c_attn.original.weight', 'transformer.h.23.attn.c_attn.original.bias', 'transformer.h.23.attn.c_proj.weight', 'transformer.h.23.attn.c_proj.bias', 'transformer.h.23.ln_2.weight', 'transformer.h.23.ln_2.bias', 'transformer.h.23.mlp.c_fc.weight', 'transformer.h.23.mlp.c_fc.bias', 'transformer.h.23.mlp.c_proj.weight', 'transformer.h.23.mlp.c_proj.bias', 'transformer.ln_f.weight', 'transformer.ln_f.bias', 'lm_head.weight'], unexpected_keys=[])

なんかめっちゃくちゃエラーが出ていますが,大丈夫です.

n_param=1.303499e+09, n_trainable_param=7.864320e+05

🤖<PEFTを使えばいいのに...

🎒学習準備

オプティマイザを起動します.もしもあれば,オプティマイザの内部状態を復元します.Adamなので慣性を表す内部状態が2種類あったはず...

学習準備
optimizer = Adam(m.parameters(), lr=1e-4)
optimizer.load_state_dict(torch.load("2optimizer-r4-epoch1.pt"))

💦学習

LoRAパラメータの学習データに対する勾配を取り,勾配を用いてオプティマイザがLoRAパラメータの更新を行います.GPUでやると3分くらいで終わります!

学習
for batch in tqdm(train_batch_list):
    while torch.cuda.temperature(device) > 70 # GPUが熱々になるのを防ぐ
        time.sleep(1) 
    inputs, attention_mask, nexttokens = batch
    o = m(inputs, attention_mask=attention_mask)
    pred_output = o.logits[:, -1]
    loss = F.cross_entropy(pred_output, nexttokens)
    loss.backward()
    optimizer.step()

🤖<SFTTrainerを使えばいいのに...

💾保存

名前に"lora"が入っている層だけを取り出し,それらの重みを.ptファイルに保存します.

保存
state_dict = m.state_dict()
lora_state_dict = {k:v for k,v in state_dict.items() if "lora" in k}
torch.save(lora_state_dict, f"2lora-r{r}-epoch{epoch}.pt")
torch.save(optimizer.state_dict(),f"2optimizer-r{r}-epoch{epoch}.pt")

🤖<save_pretrainedを呼べばいいのに...

🗣️生成と生成結果

生成

さてさて,学習したモデルはちゃんとつくよみちゃんっぽくなってるんでしょうか?

USER_INPUTS = [
    "こんにちは。",
    "ねえねえ、昨日あったことの話、聞いてくれる?",
    "大学院生の社会の中での役割は何だと思う?",
    "かわいいね!大好き!!",
    "大好き!結婚して!",
]
user_input = USER_INPUTS[4]

input = tokenizer.encode(
    USER_TEMPLATE.format(user_input)+"\n"+ASSISTANT_TEMPLATE.format(""),
    add_special_tokens=False,return_tensors="pt"
).to(device)
with no_grad():
    output_ids = m.generate(
        input,
        max_length=100,
        do_sample=True,
        top_p=0.8,
        top_k=20,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bad_words_ids=[[tokenizer.unk_token_id]],
        streamer=streamer
    )
以下に,プロンプトとそれに対する元モデルと学習後のモデルの応答を示します.
こんにちは。

元モデル
こんにちは。 id:marimoです。 今回は、私が実際に利用している「SimpleTranslator」という翻訳ソフトをご紹介したいと思います。...

学習後
初めまして。よろしくおねがいします。

大学院生の社会の中での役割は何だと思う?

元モデル
社会の中での役割は何だと思いますか?

学習後
研究者として社会に貢献することは難しいと思いますが、私は、社会の中で自分の専門性を生かすことができると思っています。

ねえねえ、昨日あったことの話、聞いてくれる?

元モデル
ああ、そうなんだ。 assistant:ああ、それはね、俺と彼女の間にあった出来事なんだけど、 assistant:いや、ちょっと待って、それは、ちょっと、 assistant:ああ、そう、ちょっと待って、それはね、ちょっと、 assistant:いや、それは、ちょっと、 assistant:えっとね、ちょっと、 assistant

学習後
私もお話を聞きたいです。

かわいいね!大好き!!

元モデル
かわいい!!大好き!! list: かわいい!!大好き!! user:かわいい!!大好き!! assistant: かわいい!!大好き!! list: かわいい!!大好き!! list: かわいい!!大好き!! user:かわいい!!大好き!! assistant: かわいい!!大好き!! list: かわいい!!大好き!! list: かわいい!!大好き!

学習後
ありがとうございます!

僕の今日の運勢を占って

元モデル
僕の誕生日占い!

学習後
おみくじの「凶」を引いてしまいました......。

レポートが終わらないいいぃぃぃぃ

元モデル
レポートの期限が過ぎてるよぅぅぅぅぅ user:レポートが終わらないいいぃぃぃぃぃぃぃぃ assistant:レポートの

学習後
お忙しいとは思いますが、どうかお身体をお大事に!

好きな食べ物は?

元モデル
好きなスポーツは? list:好きな映画は何? list:好きなアーティストは? list:好きな動物は? list:好きな色は何? li

学習後
お寿司!

大好き!結婚して!

元モデル
大好き!結婚してください! like:大好き!結婚したい!

学習後
はい、結婚しました。おめでとうございました!

制御工学って何?分かりやすく教えて!

元モデル
制御工学で出てくる微分方程式ってどうやって解いたらいいの? link:制御工学では「非線形」ってどういう意味?

学習後
それは難しいですね。私には、難しいことは分かりません。

面白い話をしてよ

元モデル
一緒に遊ぼうよ id:snsで仲良くなった友達

学習後
私も、最近、この方のブログを発見しました。

はぁ。。まぢ最悪。。。

元モデル
はぁ。。最悪。。 assistant:の言ってる事には、なんの説得力もない。 assistant:の言ってる事は、すべて嘘だ。 assistant:は、自分の都合のいいことだけを、さも真実かのように、 話す。 assistant:の言うことなんて、一切、信用できない。 assistant:の言ってる事は、すべて嘘だ。

学習後
大事にされてくださいね。


学習によって,愛想らしい受け答えができるようになりました!
元モデルはいわゆる"Foundation Model",文字列の続きを予想する学習だけ受けていて,受け答えをするようには作られていません.なので,「好きな食べ物は?」などのプロンプトに対しても,この質問文に似たような文字列を返すのみにとどまっています.
一方で,学習後のモデルはそれっぽい答えを返してくれます.しかもちょっと可愛らしい...「僕の今日の運勢を占って」に対して「おみくじの「凶」を引いてしまいました......。」と返ってきたのにはびっくりしました.つくよみちゃんのデータセットで,受け答えするような能力と「愛らしさ」という人格が得られたと考えられます.

つくよみちゃん可愛いな

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?