124
90

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

GPT-2をファインチューニングしてニュース記事のタイトルを条件付きで生成してみた。

Last updated at Posted at 2021-08-21

はじめに

GPT-2のファインチューニングの仕方がなんとなくわかってきたので、メモとして残しておきます。

事前学習モデルはrinna社が提供してくれている以下の日本語モデルを使用します。huggingface/transformersから簡単にロードして使うことができます。

こちらのモデルはmediumとあるので、TransformerのDecoderブロックが24層重なったやつですね。

今回紹介する内容はGPT-2条件付き文章生成です。
GPT-2のチュートリアル的な記事でよく見るのが、与えた文章の続きにくる文章を生成するようなものが多いかと思いますが、出力の形式等を入力の段階から制御するようなことをしてみようと思います。

GPT-2自体の理解や、使い方、ファインチューニングの仕方、生成文章を入力の段階で制御する方法などは以下の記事がとても参考になりました、というか以下の記事をただパクってるだけです。

  1. The Illustrated GPT-2 (Visualizing Transformer Language Models)
  2. Huggingface Transformers 入門 (28) - rinnaの日本語GPT-2モデルのファインチューニング
  3. Conditional Text Generation by Fine Tuning GPT-2

扱うデータはlivedoorニュースコーパスで、ニュース記事のカテゴリーと本文を与えて、ニュースのタイトルを生成するようにファインチューニングします。入力するカテゴリーらしいタイトルを生成してくれることを狙います。

実装

Google Colab Proを使っています。GPT-2はかなり巨大なモデルのため、colabの無料版ではおそらくGPUのメモリが足りない可能性があります。GPT-2を動かすのを機にcolab proに課金するのはいかがでしょうか。

colabにgoogle driveをマウントしておきます。

from google.colab import drive
drive.mount('/content/drive')

ライブラリのインストール

GPT-2のファインチューニングにはhuggingfaceが提供しているスクリプトファイルを使うととても便利なので、今回もそれを使いますが、そのスクリプトファイルを使うにはtransformersをソースコードからインストールする必要があるので、必要なライブラリを以下のようにしてcolabにインストールします。

# ソースコードから直接transformersをインストール
!pip install git+https://github.com/huggingface/transformers
# rinna/japanese-gpt2-mediumのtokenizerはsentencepieceなのでsentencepieceもインストールする必要があります。
!pip install sentencepiece
!pip install datasets

ファインチューニングで使うスクリプトファイルを得るために、transformersのリポジトリもcloneしておきます。

!git clone https://github.com/huggingface/transformers

# run_clm.pyを使います
!ls ./transformers/examples/pytorch/language-modeling/
# README.md	  run_clm_no_trainer.py  run_mlm_no_trainer.py	run_plm.py
# requirements.txt  run_clm.py		 run_mlm.py

rinna社の学習済GPT-2は以下のように簡単にロードできます。

from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")

学習データの準備

手元に以下のようにカテゴリー、タイトル、本文で分けられたlivedoorニュースコーパスのデータを用意します。
私はすでにgoogle drive上にpickle形式で保存しているので、それを以下でロードしています。

import pandas as pd
import pickle
from tqdm import tqdm

livedoor_dir = "drive/MyDrive/ColabNotebooks/livedoor_data/"
with open(livedoor_dir + 'livedoor_data.pickle', 'rb') as r:
    livedoor_df = pickle.load(r)

display(livedoor_df.sample(5))

スクリーンショット 2021-08-20 22.46.11.png

次に上で用意したlivedoorデータを、カテゴリーと本文を与えてタイトルを生成するタスクを解けるようにファインチューニングさせたいわけですが、学習データの準備がとても大事です。
参考記事3.で挙げたように、以下のような形式でデータを整形する必要があります。

SPECIAL_TOKENS['bos_token'] 
+ ニュース記事のカテゴリー + SPECIAL_TOKEN['sep_token'] + ニュース記事の本文 
+ SPECIAL_TOKEN['sep_token'] + ニュース記事のタイトル + SPECIAL_TOKEN['eos_token']

SPECIAL_TOKENはtokenizerに依存する話ですが、tokenizer.special_tokens_mapで確認できます。
<s></s><unk>になってるのはtokenizerがsentencepieceで学習されているからですね。)

tokenizer.special_tokens_map
# {'bos_token': '<s>',
#  'cls_token': '[CLS]',
#  'eos_token': '</s>',
#  'mask_token': '[MASK]',
#  'pad_token': '[PAD]',
#  'sep_token': '[SEP]',
#  'unk_token': '<unk>'}

学習データを上のようにSPECIAL_TOKENで分割した状態でモデルに与えて、GPT-2の言語モデルをファインチューニングします。
GPT-2の言語モデルは次単語予測なわけですから、上のような学習データでファインチューニングすることで、推論時にSPECIAL_TOKENS['bos_token'] + ニュース記事のカテゴリー + SPECIAL_TOKEN['sep_token'] + ニュース記事の本文 + SPECIAL_TOKEN['sep_token'] を与えてやれば、その次に来るべき単語は(入力として与えたカテゴリーと本文を元にした)ニュース記事のタイトルを表すような単語が生成されるだろう、という考え方です。

上のような形式の学習データを1つのテキストファイルとして取りまとめます。

# 学習データの保存先
save_dir = "drive/MyDrive/ColabNotebooks/GPT-2/"

with open(save_dir + 'gpt2_train_data.txt', 'w') as output_file:
    for row in tqdm(livedoor_df.itertuples(), total=livedoor_df.shape[0]):
        category = row.category
        title = row.title
        body = row.body

        # 本文のトークン数が256までで制限しています。これは単純にメモリの節約で、これ以上だとcolab proでも
        # ファインチューニング時にGPUメモリが足りなくなる可能性があります。
        tokens = tokenizer.tokenize(body)[:256]
        body = "".join(tokens).replace('', '')
        text = '<s>' + category + '[SEP]' + body + '[SEP]' + title + '</s>'
        output_file.write(text + '\n')

こんな感じの学習データが出来上がりました。あとはこれをファインチューニング用のスクリプトファイルrun_clm.pyに渡してやればOKです。
スクリーンショット 2021-08-20 23.14.09.png

ファインチューニング

参考記事2.と同じように以下のようにスクリプトを実行すればOKです。--model_name_or_pathにrinna社の事前学習済モデルを指定してやりましょう。
BERT(base)と違って、Transformerブロックは24層と倍の数だし、隠れ層のサイズも1024なので、バッチサイズは要注意です。簡単にメモリーオーバーしちゃいます。colabでは1が限界です。

!python ./transformers/examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path=rinna/japanese-gpt2-medium \
    --train_file=drive/MyDrive/ColabNotebooks/GPT-2/gpt2_train_data.txt \
    --validation_file=drive/MyDrive/ColabNotebooks/GPT-2/gpt2_train_data.txt \
    --do_train \
    --do_eval \
    --num_train_epochs=10 \
    --save_steps=10000 \ # driveの容量が足りなくなるかもなので、保存するstepの間隔は要注意!
    --save_total_limit=3 \
    --per_device_train_batch_size=1 \ # バッチサイズ1で代替14,5GBほどGPUメモリ使います。colabではバッチサイズは1が限界です。
    --per_device_eval_batch_size=1 \
    --output_dir=drive/MyDrive/ColabNotebooks/GPT-2/output/ \
    --use_fast_tokenizer=False

今回のファインチューニングはだいたい3時間ほどかかりました(学習データのテキストサイズ約10MB、エポック数10、バッチサイズ1)

もっとたくさんのデータでファインチューニングしたい、となると相当のメモリと学習時間が必要そうだなぁと痛感しました。

ちなみに学習曲線はこんな感じになりました。きれいに収束するもんだなぁと関心しました。

image.png

学習されたモデルは--output_dirで指定した先に保存されていると思います。ファインチューニング後のモデルをロードするには以下のように--output_dirで指定したディレクトリをそのまま指定すればOKです。

model = AutoModelForCausalLM.from_pretrained("drive/MyDrive/ColabNotebooks/GPT-2/output/")

文章生成をやってみる

では、上でファインチューニングしたGPT-2にニュース記事のカテゴリーと本文を与えて、タイトルを生成してもらいましょう。

改めてtokenizerとファインチューニング済モデルをロードします。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True
model = AutoModelForCausalLM.from_pretrained(save_dir + 'output/')
model.to(device)
model.eval()

こんな感じでカテゴリーと本文を与えたら、10個のタイトル文を生成するような関数を用意します。
モデルに入力する際、学習データと同じように先頭に<s>、カテゴリーと本文のあとに[SEP]を挿入しています。

generateには引数が盛りだくさんです。リファレンスを覗いてみるといいかもしれません。

ポイント
model.generate()の引数bad_words_idsが重要です。これは指定したIDのトークンを生成させないようにする役割を持ちます。以下では15、つまり<s>[SEP]を生成しないように制御しています。この指定をしないと、生成される文章が
<s>category[SEP]body[SEP]title</s>のように、カテゴリーから生成してしまうケースが多々発生してしまいます。学習データ量の問題なんですかねぇ。。。

def generate_news_title(category, body, num_gen=10):
    input_text = '<s>'+category+'[SEP]'+body+'[SEP]'
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    out = model.generate(input_ids, do_sample=True, top_p=0.95, top_k=40, 
                         num_return_sequences=num_gen, max_length=256, bad_words_ids=[[1], [5]])
    print('='*5,'本文', '='*5)
    print(body)
    print('-'*5, '生成タイトル', '-'*5)
    for sent in tokenizer.batch_decode(out):
        sent = sent.split('[SEP]</s>')[1]
        sent = sent.replace('</s>', '')
        print(sent)

生成例を見てみる

サンプルのデータを日経から持ってこようと日経を開いたら、みずほフィナンシャルグループがまたやらかしてしまったというニュースが目に入ったので、こちらのニュースの冒頭を拝借してGPT-2にタイトルを生成してもらいましょう。カテゴリーはとりあえずtopic-newsにしておきます。

実際のタイトルは「みずほ、再発防止つまずき 5度目障害も顧客へ周知遅く」ですね。

category = 'topic-news'
body = '''
みずほフィナンシャルグループ(FG)が再びつまずいた。
システム障害により20日の午前に全国のおよそ450ある店舗で窓口での入金や振り込みなどの取引が一時できなくなった。
2021年に入ってからすでに4件の障害を起こし原因究明と再発防止に取り組むさなかに障害を起こした。
システムの安定稼働とみずほの再生に向けた道のりは険しい。
'''
generate_news_title(category, body)

以下の生成文章はGPT-2による自動生成なので、事実とは異なる可能性があります。また、明らかにヤバめな文章がたまに表示されるので、そういうのはふせさせていただきます。(削除しました、と表記)

カテゴリー topic-news
===== 本文 =====

みずほフィナンシャルグループ(FG)が再びつまずいた。
システム障害により20日の午前に全国のおよそ450ある店舗で窓口での入金や振り込みなどの取引が一時できなくなった。
2021年に入ってからすでに4件の障害を起こし原因究明と再発防止に取り組むさなかに障害を起こした。
システムの安定稼働とみずほの再生に向けた道のりは険しい。

----- 生成タイトル -----
<unk>がつまずいた 「みずほ」でトラブル続々
再びつまずいたみずほ銀行
復旧のめどが立たず、店舗で現金の引き出しなどの一部サービスが利用できない
(削除しました)
再び障害。みずほフィナンシャルグループのつまずき
障害で一部店舗で入金や振り込みができない【ニュース】
再びつまずいたみずほフィナンシャルグループの障害、原因は“みずほ銀行”
再びつまずいたみずほフィナンシャルグループ
<unk>「みずほフィナンシャルグループ」がつまずいた...障害、預金口座の現金化などの被害相次ぐ
再びつまずいたみずほ。今度は本店でトラブル発生

<unk>が出てしまうのは、tokenizerは学習済のものをそのまま使っているので致し方なしとして、本文の内容を表したタイトルっぽいものが生成されている気がします。

今度は、同じ内容の本文でカテゴリーをpeachyに変えてみます。peachyは若い女性向けのニュースを集めています。

カテゴリー peachy
===== 本文 =====

みずほフィナンシャルグループ(FG)が再びつまずいた。
システム障害により20日の午前に全国のおよそ450ある店舗で窓口での入金や振り込みなどの取引が一時できなくなった。
2021年に入ってからすでに4件の障害を起こし原因究明と再発防止に取り組むさなかに障害を起こした。
システムの安定稼働とみずほの再生に向けた道のりは険しい。

----- 生成タイトル -----
<unk>のシステム障害、顧客約400万人に影響
250万台の携帯電話がダウン、みずほ銀行本店は20時から営業開始
<unk>みずほ、店舗での入金や振込みの再開に影響
<unk>障害、みずほフィナンシャルグループ本店が一時休業
<unk>みずほ、システム障害で窓口での入金や振り込みなどが一時できない
<unk>にまたまた障害発生!みずほ銀行が全国の店舗で窓口での入金や振り込みなどの取引ができない状況
見ているだけで幸せになれる!世界にひとつしかない自分だけのケーキ<unk>
<unk>障害によりみずほ銀行が全国の店舗で入出金などの一部サービスを停止した
<unk>みずほフィナンシャルグループがまたもやつまずいた
<unk>障害、みずほで再び大ピンチ

最後の大ピンチとか、ちょっとだけpeachyっぽいタイトルが生成されていたり、ケーキとか全く関係ないタイトルが出てきてしまっていますが、カテゴリーを変えることでいくらか生成文を制御できているような気がします。
もう少し、ポップな文章が生成されることを期待していたのですが、そこまでうまくいきませんでした。

他にもいくつか生成例を見てみます。

カテゴリー smax
===== 本文 =====

スマートフォンを使う上で「どれだけバッテリーが長持ちするのか」は重要なポイントだ。バッテリーの持ちというと、「充電せずに連続でどれだけ使えるのか」に目が行きがちだが、今回着目するのは「バッテリーの寿命」。
つまり、バッテリーを交換することなく、1台のスマホをどれだけ長く利用できるか、ということ。スマートフォンの機能が成熟し、買い換えサイクルが伸びつつある中、1台のスマートフォンはより長く使えることが望ましい。
バッテリーが3年長持ちするというXperia
 言うまでもなく、スマートフォンは繰り返し充電をしながら使うものだが、充電のタイミングや方法によってはバッテリーを劣化させる恐れがある。
また充電をしないときでも、スマートフォンを使う環境によってはバッテリーに悪影響を及ぼす可能性もある。

----- 生成タイトル -----
<unk>ac必須アプリがサクサク動く「<unk>egra3」 <unk>pp<unk>toreからダウンロード
<unk>i-<unk>i通信をより便利に!国内大手のキャリア3社が協業でスマートフォンの新サービスを開始
知っ得!虎の巻 vol.16スマートフォンのバッテリーを長持ちさせる充電の方法
知っ得!虎の巻スマホ活用術スマホを長持ちさせる秘訣!
知っ得!虎の巻アプリでスマートフォンの長寿命化を目指そう!
<unk>ote <unk>C-06<unk>でさらに長持ちさせる方法! <unk>ndroid 4.0でさらに便利になった「<unk>ote」の活用法
<unk>ppleがバッテリー関連の問題を修正するためにソフトウェアアップデート「<unk>alaxy <unk>ote 4.3」を提供開始!<unk>ドコモも対応
知っておきたいスマホの「お手入れ」と「長持ち」の知識【スマホ快適術】
知らなきゃソン!スマートフォンのバッテリーは「充電しない」方が長持ち
知っ得!虎の巻!スマートフォンの活用テクニック【知っ得!虎の巻】

「知っ得!虎の巻」ってのがおそらくlivedoorニュースで連載的に掲載されていたのかな、と思われます。ちょっと学習データに引きずられている感じが色濃く出てしまいました。

カテゴリー sports-watch
===== 本文 =====

米大リーグ・エンゼルスの大谷翔平投手は19日(日本時間20日)、タイガース戦で「1番・指名打者」で先発し、3打数2安打1打点、2四球を選び計4度出塁。
チームの13-10での逆転勝利に貢献した。18日(同19日)の同戦では「1番・投手」で出場し、40号弾&8勝目を挙げた。
40号を放ち、三塁を回って生還する直前にはなんとバットボーイとハイタッチ。
これに米ファンは笑撃を受けていたが、日本のファンも「この試合のハイライト」「日本でもあれば面白いのに」などとテレビ越しに注目していた。

----- 生成タイトル -----
 ′′日本のファンも注目′′大谷がメジャー初本塁打&初の打点、米メディアも絶賛
 米大リーグ・レイズの大谷翔平に「日本のファンも笑ってるんだよ」
 「日本のファンも楽しみにしている」大谷、出塁とハイタッチで日本を挑発
 日本のファンも注目していた大逆転劇、米メジャーの凄い場面
 “日本じゃなくて、韓国版「日本の10倍は辛い」” の声も
 米大リーグ・ツインズの西岡剛内野手が「日本でがあれば面白いのに」
 イチローの“神スイング”に「日本のファンも嬉しそうだったな」
 日本人に人気の海外球種=メジャーの投手とは?
 米大リーグ・大谷翔平が「日本のファンも楽しみにしている」と話題
 楽天・大谷翔平に「日本のファンは笑わないでくれ」

本文に思いっきりバッドボーイとハイタッチとの記載はあるが、バッドボーイとハイタッチした、というタイトルは生成されなかったですね。全然関係ないタイトルもあるしこちらもうまくいかず。

カテゴリー movie-enter
===== 本文 =====

俳優の仲村トオルが主演する、万田邦敏監督の新作映画『愛のまなざしを』の公開日が、11月12日に決定。「救いを求め、堕ちてゆく」というコピーが添えられたポスタービジュアルが公開された。
【画像】映画『愛のまなざしを』場面写真
 本作は、第54回カンヌ国際映画祭の批評家週間に出品され、レイル・ドール賞とエキュメニック賞をダブル受賞した『UN LOVED』(2002年)。
小池栄子と豊川悦司が主演を務め、第23回高崎映画祭で最優秀作品賞に輝いた『接吻(せっぷん)』(08年)に続く“愛の三部作最終章”。「愛」の本質を見つめ、人間の性とエゴをあぶりだした愛憎サスペンス。

----- 生成タイトル -----
仲村トオルのニューヒロインが「救いを求め、堕ちてゆく」
女優の仲村トオルが主演、夫の小栗旬が製作・監督を務める映画『愛のまなざしを』のポスター公開
仲村トオル主演、万田邦敏監督の最新作『愛のまなざしを』ポスター公開
仲村トオルの映画は万田邦敏監督の新作が多い
仲村トオルが主演、万田邦敏監督の新作『愛のまなざしを』のポスター公開
仲村トオルが主演、萬田邦敏監督の最新作が公開
万田邦敏監督最新作、公開日が11月12日に決定
仲村トオルが主演、夫の浮気に苦しむ妻を優しく慰める
万田邦敏監督最新作『愛のまなざしを』のポスター公開
仲村トオルが主演する、万田邦敏監督の最新作『愛のまなざしを』のポスタービジュアル公開

仲村トオルが主演であること、万田邦敏氏が監督の最新作であること、作品名が愛のまなざしであることをちゃんと捉えられている気がします。まぁ入力文章に全部キレイに書いてあるからタイトル生成としてや容易だったのかもしれません。

おわりに

GPT-2で遊んでみましたが、こんな簡単につかえて、生成される文章は少なくとも日本語としてはとても流暢に出力されます。入力で出力を色々制御できそうな雰囲気も感じることができ、学習データをもっと増やせばもっと意図した出力になるのではと思います。
ニュース記事のタイトルってある意味その記事の要約文章と捉えることもできると思うんですよね。なのでこんな感じでGPT-2を使えば文章要約のタスクに挑戦することもできそうですね。その他の応用例として入力で質問文を与えて出力で回答文を生成する(質疑応答)なんてのもやってみたいと思いました。
色々使い方を妄想しよう。

おわり

124
90
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
124
90

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?