1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

無料サービスだけ使ってPythonを使ってGPT-2で回答を生成するLinebotを作る

Posted at

参考になれば幸い。
GPT-3が流行ってるのにGPT-2やるの時代遅れとか言わない

ファインチューニング

友達がたくさん入ってるグループがあって、そのグループっぽい会話を生成したかったのでファインチューニングが必要だった。
以下のコードでやった。

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-xsmall",use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-xsmall")

モデルがxsmallなのは、これより大きいと後でpythonanywhereにアップロードする際、ストレージ容量が足りないため。

import torch
from transformers import LineByLineTextDataset,DataCollatorForLanguageModeling

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    block_size=256,
    file_path="traindata.txt"
)

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',   # 出力先のディレクトリ
    num_train_epochs=15,      # エポック数
    per_device_train_batch_size=4, # バッチサイズ
    save_steps=10000,        # モデルを保存するステップ数
    save_total_limit=1,      # 保存するチェックポイント数
    learning_rate=2e-5,      # 学習率
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    pad_to_multiple_of=8
)
from transformers import Trainer
device = torch.device("cpu")
model = model.to("cpu")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator
)

trainer.train()

メモリがGPUでは足りなかったので、CPUでの計算。いいグラボが欲しい。

torch.save(model.state_dict(), "gpt-2-xsmall.pt")

これで生成したモデルはGoogleDriveにアップしておきます。

Pythonanywhere でFlaskのWebアプリを建てる

以下はLineSDKAPIforPythonのデモを少し改変したもの。これで事足りる。

from flask import Flask, request, abort
from linebot import (
    LineBotApi, WebhookHandler
)
from linebot.exceptions import (
    InvalidSignatureError
)
from linebot.models import (
    MessageEvent, TextMessage, TextSendMessage,
)

from transformers import AutoTokenizer, AutoModelForCausalLM
import urllib.request
import io
import os
os.environ['NUMEXPR_MAX_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
import torch
torch.set_num_threads(1)
map_location=torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-xsmall",use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-xsmall")
url = 'https://drive.google.com/uc?id=file_id'

with urllib.request.urlopen(url) as url_response:
    buffer = io.BytesIO(url_response.read())

model.load_state_dict(torch.load(buffer,map_location=torch.device('cpu')))
app = Flask(__name__)

line_bot_api = LineBotApi('token')
handler = WebhookHandler('token')

@app.route("/callback", methods=['POST'])
def callback():
    # get X-Line-Signature header value
    signature = request.headers['X-Line-Signature']

    # get request body as text
    body = request.get_data(as_text=True)
    app.logger.info("Request body: " + body)

    # handle webhook body
    try:
        handler.handle(body, signature)
    except InvalidSignatureError:
        print("Invalid signature. Please check your channel access token/channel secret.")
        abort(400)

    return 'OK'


@handler.add(MessageEvent, message=TextMessage)
def handle_message(event):
    global tokenizer,model
    input_ids = tokenizer.encode(event.message.text, return_tensors='pt')
    output = model.generate(input_ids,max_length=len(input_ids[0])+32, no_repeat_ngram_size=1,num_beams=10, early_stopping=True)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    line_bot_api.reply_message(
        event.reply_token,
        TextSendMessage(output_text[len(event.message.text):]))


if __name__ == "__main__":
    app.run()

PythonAnywhereではスレッド数が1までしか許可されていないので、その設定が必要。
以下を参照

注意点

Driveに置いたファイルのURLは以下のように取得してください。
image.png

リンクをコピーをクリックし、以下のURLにアクセスする

https://drive.google.com/uc?id=file_id

すると、以下の画面が出るので、ダウンロードボタンを右クリックして「検証」
image.png
そうしたら、Inputを囲っているform要素のaction属性がダウンロードできるURLとなる。
image.png
この場合は、https://drive.google.com/uc?id=1xLAIhxQDcZLcOoScCppNSC28un8lhZyn&confirm=t&uuid=7dad8840-0b09-487f-89d6-fa19329e63cd&at=ANzk5s4LokXQsm1kQxgSsC12SfFA:1681790061150

これを、pythonanywhereのコードのurlに指定する。

LineBotの設定

WebhookURLに、以下を指定する

https://<your_project_id>.pythonanywhere.com/callback

アイコンの設定などは割愛、他の記事に任せる

完成

以下に作成したBotを貼っておく。

まとめ

以下の点が大変だった。

  • 学習(ファインチューニング)
    12時間ほどかかった。いいパソコンほしい。
  • ストレージ制限
    pythonanywhereの無料版には512MBの制限があるため、それ以下のモデルを探すのは思ったより大変でした。名前がわからないと検索引っかからないもん。
  • メモリー制限
    最初はVercelでやろうと思ったが、メモリが1GBしかないのではやりようがなかった。
    pythonanywhereは4GBあるようなので、よかった。
1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?