はじめに
最近、rinna社のGPT-2の日本語学習済みモデルの使い方を勉強したので、アウトプットとしてこの記事を書いてみました。
実行環境
実行環境はGoogle Colaboratoryを使用します。ランタイムの設定でGPUを選択しておいて下さい。
コード
青空文庫から訓練用テキストを取得
import urllib.request
import zipfile
import re
# 青空文庫のURLからテキストを取得する関数
def get_text_from_aozora_url(url):
# zipファイルのダウンロード
zip = re.split(r'/', url)[-1]
urllib.request.urlretrieve(url, zip)
# zipファイルの解凍
with zipfile.ZipFile(zip, 'r') as myzip:
myzip.extractall()
for myfile in myzip.infolist():
filename = myfile.filename
with open(filename, encoding='sjis') as file:
text_org = file.read()
# テキスト整形
text = re.split('\-{5,}',text_org)[2]
text = re.split('底本:',text)[0]
text = re.sub('《.+?》', '', text)
text = re.sub('[#.+?]', '',text)
text = re.sub('\n\n', '\n', text)
text = re.sub('\r', '', text)
text = re.sub("[| ]", "", text)
return text
今回は学習用データに「三四郎」「こころ」「坊っちゃん」を使います。
import time
# ダウンロードURLリスト
url_list = ['https://www.aozora.gr.jp/cards/000148/files/794_ruby_4237.zip',
'https://www.aozora.gr.jp/cards/000148/files/773_ruby_5968.zip',
'https://www.aozora.gr.jp/cards/000148/files/752_ruby_2438.zip'
]
merge_text = ""
for url in url_list:
# テキスト取得
text = get_text_from_aozora_url(url)
# サーバーに負荷をかけないよう1秒空ける
time.sleep(1)
# ダウンロードしたテキストのマージ
merge_text += text + ("\n")
# print(merge_text)
学習用データの保存
file_path = "train.txt"
with open(file_path, mode="w") as f:
f.write(merge_text)
ライブラリのインポート
!pip install transformers
!pip install sentencepiece
トークナイザー(テキストを深層学習モデルの入力データに変換する処理)と事前学習済みモデルの準備です。Colab Proを使っているのなら"rinna/japanese-gpt2-medium"の代わりにより大きいモデルの"rinna/japanese-gpt-1b"を使ってみてもいいと思います。
from transformers import T5Tokenizer, AutoModelForCausalLM
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
訓練用データ設定とモデルの訓練
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead
# 訓練用データセットの設定
train_dataset = TextDataset(
tokenizer=tokenizer,
file_path=file_path,
block_size=128
)
# データ入力の設定
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# 訓練の設定
training_args = TrainingArguments(
output_dir="./gpt2_soseki",
overwrite_output_dir=True,
num_train_epochs=3,
per_device_train_batch_size=8
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset
)
# 訓練の実施
trainer.train()
# 訓練済みモデルの保存
trainer.save_model("./gpt2_soseki_model")
文章生成関数
def getarate_sentence(seed_sentence):
x = tokenizer.encode(seed_sentence, return_tensors="pt", add_special_tokens=False)
x = x.cuda()
y = model.generate(input_ids=x,
min_length=256,
max_length=512,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
generated_sentence = tokenizer.batch_decode(y, skip_special_tokens=True)
return generated_sentence
文章生成の実行
seed_sentence = "メロスは激怒した。"
generated_sentence = getarate_sentence(seed_sentence)
print(generated_sentence)
出力結果
['メロスは激怒した。そのうえ、おれが生徒を連れてくるとは予想だにせんし、こんな無法野郎を生け捕りにして、どこまでも放免させるつもりかと、怒り狂った。 「おれは放免する気はない、放免の方法さえ教えてやれば放免なのだが、おれの所為で、三年ばかり、四十五度も通って、生徒が困るかと思うがいいか」 ...']
おわりに
出力結果をみると、やや不自然さはありますが、夏目漱石っぽい文章でメロス先生(?)が怒り狂ってますね。内容的に「坊っちゃん」の影響が強く出ているように見受けられます。