はじめに
初投稿記事です。よろしくお願いします。
ディープラーニングで何か面白いものを作りたい!とずっと考えていたのですが、ようやくいいネタが思いついて実装したのでQiitaで記事書くことにしました。
タイトルにもある通り、言語モデルのGPT-2を使ってポケモン図鑑の説明文生成器を作成するのが目標です。といってもゼロからではなく、rinna社が公開している学習済みの日本語モデルをファインチューニングするだけでできちゃいます。めちゃ便利ですね。
流れ
大まかな流れは以下のようになっています。
- 学習データの準備
- 学習データの加工
- 学習済みモデルのファインチューニング
- ポケモン名から図鑑説明文の生成
1. 学習データの準備
まずは、ポケモンWikiから図鑑説明文の掲載がある905匹分のデータを取得します。複数作品に登場するポケモンは、作品ごとに図鑑説明が異なっていることもありますが、それらも全部学習データに入れます。手作業で一つ一つコピペしていくのは気が遠くなるので、各ページをスクレイピングして取得しました。コードは省略。
ちなみに、ポケモンWikiでは以下のような形で説明文が記載されています。
赤・緑、ファイアレッド、Y
ほっぺたの りょうがわに ちいさい でんきぶくろを もつ。ピンチのときに ほうでんする。
(漢字) ほっぺたの 両側に 小さい 電気袋を 持つ。ピンチのときに 放電する。
青、リーフグリーン
なんびきかが あつまっていると そこに もうれつな でんきが たまり いなずまが おちることがあるという。
-ピカチュウ - ポケモンWiki より
漢字verの説明文もありますが、気にせず以下のような形式のjsonファイルとして保存しておきました。
{
// :
// [中略]
// :
"ピカチュウ": [
"ほっぺたの りょうがわに ちいさい でんきぶくろを もつ。ピンチのときに ほうでんする。",
"ほっぺたの 両側に 小さい 電気袋を 持つ。ピンチのときに 放電する。",
"なんびきかが あつまっていると そこに もうれつな でんきが たまり いなずまが おちることがあるという。"
],
// :
// [中略]
// :
}
2. 学習データの加工
今回は、ポケモン名から図鑑説明文を生成するというタスクを想定しています。学習済みの日本語モデルをこのタスクに合わせてファインチューニングしていくのですが、そのために先ほど取得したデータを次のように加工します。
import json
# jsonファイルの読み込み
with open('pokedata.json', 'r') as f:
pokedata = json.load(f)
# 読み込んだjsonからデータを加工する
with open('train.txt', 'w') as f:
for name, descs in pokedata.items():
for desc in descs:
f.write(f'<s>{name}[SEP]{desc}</s>\n')
出力結果↓
:
:
<s>ピカチュウ[SEP]ほっぺたの りょうがわに ちいさい でんきぶくろを もつ。ピンチのときに ほうでんする。</s>
<s>ピカチュウ[SEP]ほっぺたの 両側に 小さい 電気袋を 持つ。ピンチのときに 放電する。</s>
<s>ピカチュウ[SEP]なんびきかが あつまっていると そこに もうれつな でんきが たまり いなずまが おちることがあるという。</s>
:
:
<s>
,</s>
,[SEP]
はスペシャルトークンで、それぞれ文頭、文末、区切りを表す文字列です。使用するTokenizerによってスペシャルトークンの文字列は変わってきますが、rinna社の学習済みモデルのものを使用する場合はこれでいいはず(多分)。
結果として、<s>ポケモン名[SEP]図鑑説明文</s>
という形式のデータが13505件のテキストファイルが用意できました。
3. 学習済みモデルのファインチューニング
続いて学習済みモデルをタスクが解けるようにファインチューニングしていきます。一見、とても大変そうですが、実はデータさえ用意できていればコードを書く必要はほとんどないです。
GoogleColabを使って学習していきます。ProだとハイメモリのGPUが使えるらしいですが、貧乏学生なので無料版を使っています。
Google Driveのマウント
モデルを保存したり、学習データを利用したりするため、GoogleDriveをマウントします。
from google.colab import drive
drive.mount('/content/drive')
%cd ./drive/MyDrive/pokedex
今回はマイドライブ直下にpokedex
フォルダを作りその中で作業していきます。つまりファイル構成は次のようなものを想定しています。
./drive
└── MyDrive
└── pokedex
├── output
├── pokedata.json
├── train.txt
└── transformers
ライブラリのダウンロードなど
必要なパッケージはtransformers
, sentencepiece
, datasets
です。pip
でインストールしておきましょう。バージョンは最新のでいいと思います。
# パッケージのインストール
!pip install transformers sentencepiece datasets
執筆時点では
transformers==4.19.0
sentencepiece==0.1.96
datasets==2.2.1
が入りました。
また、ファインチューニング用のスクリプトが提供されているので、これを使うためにGitHubのtransformersのリポジトリをクローンします。pip
でインストールしたバージョンと同じものを指定しておくといいです。ちなみに、rinna社の学習済みモデルを使う際には、TokenizerとしてAutoTokenizer
ではなく、T5Tokenizer
を明示的に指定する必要があるらしいので、sed
コマンドで置換しています。
# ファインチューニング用のスクリプトを使えるようにする
%cd ./drive/MyDrive/pokedex
!git clone https://github.com/huggingface/transformers -b v4.19.0
!sed -i 's/AutoTokenizer/T5Tokenizer/' ./transformers/examples/pytorch/language-modeling/run_clm.py
ファインチューニング
先ほどクローンしたリポジトリにあるスクリプトを実行します。rinna社の学習済みモデルにはrinna/japanese-gpt2-small
やrinna/japanese-gpt2-medium
などがありますが、Colabの無料枠だとメモリ的に前者、Proであれば後者を選択するといいと思います。適宜読み替えてください。
以下で実行するとファインチューニングが始まります。GPUガチャやモデルの大きさによってかかる時間は違うと思いますが、学習は20~60分ほどで終了します。最終的なモデルは./drive/MyDrive/pokedex/output
に保存されています。
!python ./transformers/examples/pytorch/language-modeling/run_clm.py \
--model_name_or_path=rinna/japanese-gpt2-small \
--train_file=train.txt \
--validation_file=train.txt \
--do_train \
--do_eval \
--num_train_epochs=10 \
--save_steps=10000 \
--save_total_limit=3 \
--per_device_train_batch_size=1 \
--per_device_eval_batch_size=1 \
--output_dir=output/ \
--overwrite_output_dir=true \
--use_fast_tokenizer=False
オプションについては
!python ./transformers/examples/pytorch/language-modeling/run_clm.py -h
で確認できるはずです。
4. ポケモン名から図鑑説明文の生成
実際に図鑑説明文の生成をしてみましょう!
モデルの準備
まずは先ほど学習したモデルを読み込みます。
import torch
from transformers import T5Tokenizer, AutoModelForCausalLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
tokenizer.do_lower_case = True
model = AutoModelForCausalLM.from_pretrained("output/")
model.to(device)
model.eval()
学習に用いたデータは以下のような形式でした。
<s>ポケモン名[SEP]図鑑説明文</s>
GPT-2は入力したテキストに続く文章を生成するので、ポケモン名から図鑑説明文を生成するには
<s>ポケモン名[SEP]
の形式を入力とすればよさそうです。入力を整形して結果をprintする関数を書いて完成!
def generate_description(pokemon: str, num=10) -> None:
input_text = '<s>'+pokemon+'[SEP]'
input_ids = tokenizer.encode(input_text, return_tensors='pt', add_special_tokens=False).to(device)
out = model.generate(
input_ids,
do_sample=True,
top_k=40,
max_length=100,
num_return_sequences=num,
bad_words_ids=[[tokenizer.bos_token_id], [tokenizer.sep_token_id], [tokenizer.unk_token_id]],
repetition_penalty=1.2,
)
print(f'ポケモン名: {pokemon}')
print('='*40)
for s in tokenizer.batch_decode(out):
s = s.replace('<s>', '')
s = s.split('[SEP]')[1]
s = s.replace('</s>', '').strip()
print(s)
いざ生成!
それでは生成していきましょう。先ほど作った関数にポケモン名を与えるだけです。
生成例1
generate_description(pokemon='ピカチュウ')
生成結果↓
ポケモン名: ピカチュウ
========================================
おとなしい イッシュちほうのポケモン。おそろしく ふかふかさで とても にくきゅうちゅうしつである。
なんでも みせようとするが おそろしい おなかの おどろく あかい ハニーを もっている。
10キロさきで おそわれると 3つの ツメをつかって エサをさがす。
ほかのポケモンが エサとして あるくことを まちがえて 食べ物を よこまわす。
うみを およぎまわり 100匹の アンパンマンから うまれた。1ぴきだけ おおけがをすることも あるらしい。
2つの ツメで 3ぼんの マシュマロを くしして なかみを つらぬいている。
ひろい 森や 空に 姿を 見せては その鳴き声で 仲間内の 気持ちを察知するよ。
うちゅうに くらしはじめたころは ちいさなポケモン。ほけんしょくの イワークから おどってきたときは なかまおもくなって ハネになるらしい。
あいての たましいを うばうことが 主な しゅぎょうだが いちげきでは キズもつかない。
サナギが生える エサの さがし方や むれの すそいなどを ひとに よって エサを あたえている。
図鑑説明文の分かち書きまで再現されていますね。すごい。データセットにひらがなver/漢字verが混ざっていたので、一つの説明文で表記が統一されていないものもありますが、それっぽい文章が生成されています。100匹のアンパンマンから生まれた...?
生成例2
別の例でも試してみましょう。
generate_description(pokemon='イーブイ')
生成結果↓
ポケモン名: イーブイ
========================================
4つのツメを ふさいで ねむっている。よるになると おなかから どくほうしをだす。
2つの体でできた ながいハリに エサを とりこみ たべのこしを つくる。
サンゴの発光で おしりに つきさして およぐ。 しんかするごとにちからが たかまっていくため イーブイはこうぼうな ポケモンだ。
かせきから コピーされた ポケモンのそっくりな ポケモン。
おおぞらを ただよう。 もくげきしては おそってきたりてきに ぶちかえるためだ。
5本の足で地面を駆ける。体よりも高いところの岩場などへジャンプするため おそろしい。
3~4匹で木の幹に 頭を突っ込んで 寝そべったまま 過ごすらしい。
2つの 両目が光源となって さまざまなものを みているようだ。
2匹で暮らすよ。リーダーからは 3つ子であると 教えられているんだ。
おおきな おはかで えものを ゆうかんに みとおす。むれを つくらないと しないので つよいのだ。
イーブイは足が5本あって体が2つあるポケモンで複数匹で暮らすらしい...
所々おかしな文はありますが、概ね成功と言えるでしょう。rinna/japanese-gpt2-medium
を使ったり、generate_description
関数の中のmodel.generate
の引数を色々弄ることで、より質の高い文章が生成できるかもしれません。それぞれの引数の説明については以下を参照してください。
生成例3
最後に、まだ図鑑説明のない新ポケモンでも試してみます。
generate_description(pokemon='ホゲータ')
生成結果↓
ポケモン名: ホゲータ
========================================
ひとびとの 願いを かなえるため おどるような なきごえをだす。おぼつかないので だれとも おたがいとは あらそわないぞ。
うごくもの すべてに はんのうする。ひとが いやがるときも おそろしいほど そのすがたらしい。
頭突きで おどりまくる。スピードが 100ノットをゆうに超えることもあるぞ。
3つのツメで 相手を おなかに いれて まるのみ。 じめんを このむのは ストレスたまると いうよ。
うちゅうの ポケモンを モデルにした マスコットのような かわいらしい めだまが いる。
空の 上からでも 炎の吐くような 鳴き声を 出せる。 本体とは ハネを 使い分けるぞ。
ほのおを あやつる。あいての うごきを よそくして ただかれないうえに ほのおは おれても からだを そりかえすぞ。
サニーゴやゴダイゴを うむ。じめじめした日差しに 強いが 寒いので 水泡には弱いのだ。
4つのツノで おたがいの きもちを みつめあう。うれしくてたまらないのだ。
4つの ユニットは 2つの 2ばいの 弾みで あることができる。
偶然かはわからないですが、ほのおを操るなんて説明文が出てきました。面白いですね。
おわりに
今回は、ポケモン名を入力するとそれっぽい図鑑説明文が生成される、というものを作ってみました。ポケモン名が生成に影響しているのかどうかは若干微妙ですが、満足いくものができたと思っています。
コードはGoogleColabにあるので、パラメータなど好きなように弄って遊んでください。
参考