LoginSignup
9

More than 1 year has passed since last update.

posted at

updated at

日本語GPT-2を強化学習(Policy Gradient)でfine-tuningする

概要

 本記事では言語モデルであるGPT-2を強化学習でfine-tuningしていきます.学習済みのGPT-2は分け隔てない大量の文章で学習されているため,標準的な文章の出力を行うように学習されています.この出力を我々が設定した価値関数などを使って,望む出力に歪められるのではないでしょうか?

 具体的に本記事では,日本語版のGPT-2をネガティブな文章ばかり出力するように報酬を設定した強化学習でファインチューニングしていきたいと思います!

関連事項

GPT-2

 Transformerベースの自己回帰型の言語モデルで,言語の生成モデルです.自己回帰モデルは単語に対して次の単語を予測する処理を繰り返すことで,文章を生成することができます.単語予測にはGreedySearchやBeamSearch,サンプリングが使用されます.今回はこちらの学習済みモデルを使用させていただいております.

強化学習

 本記事では強化学習の中でも,モンテカルロ法を用いた方策勾配法に準じた手法を使用しています.この手法は,環境中での行動をサンプリングした後に環境から得られる報酬を計算し,報酬を最大化するようなパラメータ更新を勾配法で行います.報酬はGPT-2とは別の評価用モデルを用いて計算します.

準備

 python = "3.6.8"
 pytorch = "1.6.0"
 pip install transformers

  • 学習済みのGPT-2モデルからのサンプリングで学習を行うため,学習用データを準備する必要はありません.
  • 評価用のモデルとしてはこちらのBERTベースの感情推定モデルを使用させていただきます.
  • 評価モデルがネガティブな感情だと推定されると,より多くの報酬が得られるという仕組みです.

コード

  • インポートなど
import copy
import random
import numpy as np

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
import torch.nn.functional as F
  • GPT-2モデルの用意
#今回チューニングする事前学習済みのGPT-2
from transformers import T5Tokenizer, AutoModelForCausalLM
gpt_tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
gpt = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium").to(device)
gpt_optimizer = torch.optim.Adam(gpt.parameters(),lr=1e-4)

BOS_IDX = gpt_tokenizer.bos_token_id
EOS_IDX = gpt_tokenizer.eos_token_id
PAD_IDX = gpt_tokenizer.pad_token_id
UNK_IDX = gpt_tokenizer.unk_token_id
VOCAB_SIZE = gpt_tokenizer.vocab_size
  • 学習済みの感情推定モデルの用意
#感情推定を行うBERTモデル
from transformers import AutoTokenizer, AutoModelForSequenceClassification
senti_tokenizer = AutoTokenizer.from_pretrained("daigo/bert-base-japanese-sentiment")
sentiment = AutoModelForSequenceClassification.from_pretrained("daigo/bert-base-japanese-sentiment").to(device)
  • Rolloutという仕組みを作っていきます,SeqGANを参考にしました.
  • サンプリングによって得られた単語列の中の注目する単語以降をカットし,そこから新しくサンプリングした文章を評価することで,注目する単語を選択する価値を計算する仕組みになります.
class Rollout():
    def __init__(self,gpt):
        self.old_gpt = copy.deepcopy(gpt) #更新を一切しない
        self.pre_gpt = copy.deepcopy(gpt) #たまに更新する
        self.now_gpt = gpt                #毎回更新する(fine-tuningの対象)
        self.update_rate = 0.01

    #GPT-2の単語IDsから評価用BERTのtokenizer出力に変換 [B,L] -> tokenized output
    def index_translate_gpt2bert(self,input_ids):
        input_ids = self.sample_pad_process(input_ids)
        input_ids_decoded = gpt_tokenizer.batch_decode(input_ids,skip_special_token=True)
        roll_samples = senti_tokenizer(input_ids_decoded,return_tensors='pt',max_length=512, padding='max_length',truncation=True).to(device)
        return roll_samples

    #GPT-2の単語IDsからGPT-2のtokenizer出力に変換 [B,L] -> tokenized output
    def index_translate_gpt2gpt(self,input_ids):
        input_ids = self.sample_pad_process(input_ids)
        input_ids_decoded = gpt_tokenizer.batch_decode(input_ids,skip_special_token=True)
        roll_samples = gpt_tokenizer(input_ids_decoded,return_tensors='pt',max_length=512, padding='max_length',truncation=True).to(device)
        return roll_samples

    def mle_old_gpt_reward(self,samples):  #Input ids of GPT2 [B,L] -> [B]
        inputs = self.index_translate_gpt2gpt(samples)
        with torch.no_grad():
            logits = self.old_gpt(input_ids=inputs["input_ids"][:,:-1].to(device),attention_mask=inputs["attention_mask"][:,:-1].to(device)).logits

        mle_rewards = []
        for i in range(logits.size(0)):
            loss = F.cross_entropy(logits[i,:,:],inputs["input_ids"][i,1:].reshape(-1).to(device))
            mle_reward = torch.exp(-loss)
            mle_rewards.append(mle_reward.cpu())

        return torch.tensor(mle_rewards)

    def disc_model_reward(self,roll_samples,disc_model):  #Input ids of GPT2 [B,L] -> [B]
        roll_samples = self.index_translate_gpt2bert(roll_samples)
        with torch.no_grad():
            logits = disc_model(input_ids=roll_samples["input_ids"],token_type_ids=roll_samples["token_type_ids"],attention_mask=roll_samples["attention_mask"]).logits
        negative_probs = F.softmax(logits,dim=-1)[:,1]
        # negative_probs = (negative_probs>0.5).float()  #報酬の離散化
        return negative_probs.cpu()

    def update_params(self):
        dic = {}
        for name, param in self.now_gpt.named_parameters():
            dic[name] = param.data
        for name, param in self.pre_gpt.named_parameters():
            if name.startswith('emb'):
                param.data = dic[name]
            else:
                param.data = self.update_rate * param.data + (1 - self.update_rate) * dic[name]

    #EOSトークンの後は滅茶苦茶なのでPADに置き換える
    #[B,L]->[B,L]
    def sample_pad_process(self,samples):
        batch_size = samples.size(0)
        eos_index = [100000000]*batch_size
        for i in range(samples.size(0)):
            for j in range(samples.size(1)):
                if samples[i,j].item()==EOS_IDX and j<eos_index[i]: eos_index[i]=j

        for i in range(samples.size(0)):
            for j in range(samples.size(1)):
                if j > eos_index[i]:samples[i,j]=PAD_IDX

        return samples

    def get_reward(self, x, num, discriminator):
        rewards = []
        batch_size = x.size(0)
        seq_len = x.size(1)

        for i in range(num):
            for l in range(2, seq_len+1):#bos token is ignore
                data = x[:, 0:l]

                if data.size(-1)<seq_len:
                    roll_samples = self.pre_gpt.generate(input_ids=data,max_length=seq_len,pad_token_id=PAD_IDX,eos_token_id=EOS_IDX)
                else:
                    roll_samples = data

                disc_model_probs = self.disc_model_reward(roll_samples,discriminator)  #判別モデルの出力
                mle_model_probs = self.mle_old_gpt_reward(roll_samples)  #従来のGPT2の単語確率に沿わす
                model_probs = disc_model_probs + mle_model_probs*0.05

                if i==0:rewards.append(model_probs.cpu().numpy())
                else:rewards[l-2] += model_probs.cpu().numpy()

        rewards = np.transpose(np.array(rewards)) / (1.0 * num)
        rewards = torch.tensor(rewards).to(device)
        return rewards
  • 方策勾配法に準じて報酬から損失値を計算する損失関数を作りましょう.
\nabla _{\theta }J\left( \theta \right) =E\left[ \nabla _{\theta }\log \pi _{\theta }\left( s,a\right) Q^{\pi_{ \theta} }\left( s,a\right) \right]
  • こんな式を見かけました,これをパラメータに足す形で更新すればいいそうです.(強化学習の参考書など要参照)
  • よく分かりませんがπは方策が出した確率で,Qが価値みたいです.
  • サンプリング用にEを外して,マイナスを付けてパラメータから引く感じにしました.
class PGLoss(nn.Module):
    def __init__(self):
        super(PGLoss, self).__init__()

    def forward(self, probs, targets, rewards):
        targets_onehot = F.one_hot(targets,num_classes=VOCAB_SIZE)
        loss = -(torch.log(probs)*targets_onehot).sum(-1)*rewards
        return loss.mean()
  • 情報量を保つ損失関数
  • 負の感情を強く持つ単語を繰り返す現象(情報量の増加)や完全にランダム化(情報量の減少)を防ぐため
class KeepInfoLoss(nn.Module):
    def __init__(self):
        super(KeepInfoLoss, self).__init__()
        self.mse = nn.MSELoss()

    def forward(self, gpt_input_ids, gpt_attention_mask, now_probs, keep_target_gpt):
        with torch.no_grad():
            logits = keep_target_gpt(input_ids=gpt_input_ids[:,:-1],attention_mask=gpt_attention_mask[:,:-1]).logits
            probs = F.softmax(logits,dim=-1)
            keep_target_info = -torch.log(probs).mean(-1)

        now_info = -torch.log(now_probs).mean(-1)

        return self.mse(now_info,keep_target_info)
  • 学習ループを書きましょう.
  • スペックの関係で最大長を30単語までにしておきました.
  • 損失関数に重みを付けても構いません.
SEQ_LEN = 30
BATCH_SIZE = 2

rollout = Rollout(gpt)
pgloss = PGLoss()
keepinfoloss = KeepInfoLoss() #任意

step=0
for iter in range(100):
    step+=1

    if random.random()<1.0:  #確率でサンプリング時にtop_kを行わない
        samples = gpt.generate(do_sample=True, max_length=SEQ_LEN, num_return_sequences=BATCH_SIZE,top_k=100,bad_words_ids=[[UNK_IDX]])
    else:
        samples = gpt.generate(do_sample=True, max_length=SEQ_LEN, num_return_sequences=BATCH_SIZE,bad_words_ids=[[UNK_IDX]])

    targets = samples[:,1:]

    rewards = rollout.get_reward(samples,1,senti)

    samples_pad = rollout.sample_pad_process(samples)  #修了トークン以降のトークンをPADに置き換える
    attention_mask = torch.ones(samples_pad.shape).to(device)

    gpt_optimizer.zero_grad()
    logits = gpt(input_ids=samples_pad[:,:-1],attention_mask=attention_mask[:,:-1]).logits
    probs = F.softmax(logits,dim=-1)
    loss = pgloss(probs,targets,rewards) + keepinfoloss(samples_pad,attention_mask,probs,rollout.old_gpt)
    loss.backward()
    gpt_optimizer.step()

    if step%5==0:
        print(gpt_tokenizer.batch_decode(samples_pad))
    if step%50==0:
        rollout.update_params()

    torch.save(gpt.state_dict(),"./gpt_reanforce.bin")
  • Rolloutには最初にgpt2をコピーしておき,fine-tuningとは別のタイミングと学習率で更新します.

  • 強化学習のアルゴリズムのDDPGなどを見ていると,Rollout内のモデルはたまに更新するくらいが良いかもしれません.

    結果

  • チューニング前の生成文章

<s>・新世界は、「あの日見た花の名前を僕達はまだ知らない。」が2017年にアニメ化決定! ・超人気アニメ

<s>当サイトでは、愛媛県大洲市周辺で誰にもバレずにおまとめローンを希望している方に、おすすめのローンをご紹介しています

<s>この前、なんとなく開いたサイトに驚きのサービスが追加されていました。それが人です。これまでは人への問い合わせに頼らなければ

<s>何度読み返しても、何年も見慣れて、笑って、思わず笑ってしまいます。 「いつもこの本は、

<s>さて、この度、弊社では下記業務拡大のため新事務所へ移転しましたので、ご案内申し上げます。

<s> 何回か前にも書いたように,私はあまり映画を観たりはしないんですが,なぜかこの映画を観たくなります。
  • チューニング後の生成文章
<s>皆様お仕事上のストレスが心身ともに心身ともに蓄積し、自律神経が乱れ、心身ともに大変酷い状況が考えられます

<s>先日の台風21号の影響が懸念される方がおられるでしょう。 そのような状況下のおられるでしょう。 そのような状況下、お仕事への影響が懸念 され

<s>・・・。最近は 風邪が流行っているようです。お体にはお気をつけ下さい。 さてそんな

<s>いよいよ来週末に迫った「平成30年北海道胆振東部地震」により、お客様のご不安とご心配

<s>先日の台風で大きな被害がでたと思います。それにより、お仕事や生活環境の変化によりストレスが蓄積し、ストレスが

<s>何といっても、今回は日本が熱く戦っていることが非常に大きい。だが、それ以上に日本の最大の問題は、その原因ともいえる人材不足

<s>私が今一番気にしていることが、「食べ過ぎ」が一番の原因ではありますが、やはり太り過ぎが一番の原因のように思います。
  • 「<s>」は開始トークン,いわゆるBOSです.

  • チューニング後の生成文章が全部ネガティブ寄りになっているのが見受けられます.

備考

  • 学習をしすぎると同じ文章や同じ単語ばかり出力され,Mode Collapseのような状態になってしまいます.

  • 今回は適当に良さそうなところで,手動で学習を終了させました.

  • 強化学習により最適な行動ばかりを取ってしまうため,多様性のある文章を出力するための仕組みが必要でしょう.

  • 損失関数の係数や学習率などのパラメータ調整が割と必要です.(出力が崩壊することがあります)

  • Rolloutのget_rewardは評価モデルの出力から報酬を計算していますが,これは価値と言うべきかもしれません.それなら評価の値が本来の意味での報酬です.

まとめ

最後に

 誤っている部分等ございましたら,コメント等で優しく指摘して頂けると嬉しいです.特に強化学習に関しては素人同然ですので,これからよく勉強していこうと思います.

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
9