参考になれば幸い。
GPT-3が流行ってるのにGPT-2やるの時代遅れとか言わない
ファインチューニング
友達がたくさん入ってるグループがあって、そのグループっぽい会話を生成したかったのでファインチューニングが必要だった。
以下のコードでやった。
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-xsmall",use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-xsmall")
モデルがxsmallなのは、これより大きいと後でpythonanywhereにアップロードする際、ストレージ容量が足りないため。
import torch
from transformers import LineByLineTextDataset,DataCollatorForLanguageModeling
dataset = LineByLineTextDataset(
tokenizer=tokenizer,
block_size=256,
file_path="traindata.txt"
)
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir='./results', # 出力先のディレクトリ
num_train_epochs=15, # エポック数
per_device_train_batch_size=4, # バッチサイズ
save_steps=10000, # モデルを保存するステップ数
save_total_limit=1, # 保存するチェックポイント数
learning_rate=2e-5, # 学習率
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
pad_to_multiple_of=8
)
from transformers import Trainer
device = torch.device("cpu")
model = model.to("cpu")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=data_collator
)
trainer.train()
メモリがGPUでは足りなかったので、CPUでの計算。いいグラボが欲しい。
torch.save(model.state_dict(), "gpt-2-xsmall.pt")
これで生成したモデルはGoogleDriveにアップしておきます。
Pythonanywhere でFlaskのWebアプリを建てる
以下はLineSDKAPIforPythonのデモを少し改変したもの。これで事足りる。
from flask import Flask, request, abort
from linebot import (
LineBotApi, WebhookHandler
)
from linebot.exceptions import (
InvalidSignatureError
)
from linebot.models import (
MessageEvent, TextMessage, TextSendMessage,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import urllib.request
import io
import os
os.environ['NUMEXPR_MAX_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
import torch
torch.set_num_threads(1)
map_location=torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-xsmall",use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-xsmall")
url = 'https://drive.google.com/uc?id=file_id'
with urllib.request.urlopen(url) as url_response:
buffer = io.BytesIO(url_response.read())
model.load_state_dict(torch.load(buffer,map_location=torch.device('cpu')))
app = Flask(__name__)
line_bot_api = LineBotApi('token')
handler = WebhookHandler('token')
@app.route("/callback", methods=['POST'])
def callback():
# get X-Line-Signature header value
signature = request.headers['X-Line-Signature']
# get request body as text
body = request.get_data(as_text=True)
app.logger.info("Request body: " + body)
# handle webhook body
try:
handler.handle(body, signature)
except InvalidSignatureError:
print("Invalid signature. Please check your channel access token/channel secret.")
abort(400)
return 'OK'
@handler.add(MessageEvent, message=TextMessage)
def handle_message(event):
global tokenizer,model
input_ids = tokenizer.encode(event.message.text, return_tensors='pt')
output = model.generate(input_ids,max_length=len(input_ids[0])+32, no_repeat_ngram_size=1,num_beams=10, early_stopping=True)
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
line_bot_api.reply_message(
event.reply_token,
TextSendMessage(output_text[len(event.message.text):]))
if __name__ == "__main__":
app.run()
PythonAnywhereではスレッド数が1までしか許可されていないので、その設定が必要。
以下を参照
注意点
Driveに置いたファイルのURLは以下のように取得してください。
リンクをコピーをクリックし、以下のURLにアクセスする
https://drive.google.com/uc?id=file_id
」
すると、以下の画面が出るので、ダウンロードボタンを右クリックして「検証」
そうしたら、Inputを囲っているform要素のaction属性がダウンロードできるURLとなる。
この場合は、https://drive.google.com/uc?id=1xLAIhxQDcZLcOoScCppNSC28un8lhZyn&confirm=t&uuid=7dad8840-0b09-487f-89d6-fa19329e63cd&at=ANzk5s4LokXQsm1kQxgSsC12SfFA:1681790061150
。
これを、pythonanywhereのコードのurlに指定する。
LineBotの設定
WebhookURLに、以下を指定する
https://<your_project_id>.pythonanywhere.com/callback
アイコンの設定などは割愛、他の記事に任せる
完成
以下に作成したBotを貼っておく。
まとめ
以下の点が大変だった。
- 学習(ファインチューニング)
12時間ほどかかった。いいパソコンほしい。 - ストレージ制限
pythonanywhereの無料版には512MBの制限があるため、それ以下のモデルを探すのは思ったより大変でした。名前がわからないと検索引っかからないもん。 - メモリー制限
最初はVercelでやろうと思ったが、メモリが1GBしかないのではやりようがなかった。
pythonanywhereは4GBあるようなので、よかった。