LoginSignup
2
0

More than 1 year has passed since last update.

PyTorch 2.0 Nightly版でGPT日本語文章生成してみた

Last updated at Posted at 2023-02-16

2023年3月上旬にPyTorch 2.0がリリースされる予定です。

このPyTorch 2.0の目玉機能にtorch.compile()というものがあります。これを使ってTransformersの日本語文章生成を試してみました。

環境

  • Google Colaboratory
    • GPU有効(Nvidia Tesla T4)
  • PyTorch 2.0.0.dev20230213+cu117

環境構築

Google Colaboratoryを利用します。

2023年2月16日現在、PyTorch 2.0 Nightlyビルドは日々開発が行われているせいかインストールするたびに挙動が不安定になる事があります。なので、バージョンを指定してインストールします。

!pip install numpy --pre https://download.pytorch.org/whl/nightly/cu117/torch-2.0.0.dev20230212%2Bcu117-cp38-cp38-linux_x86_64.whl --force-reinstall

またTransformersもインストールしておきます

!pip install transformers sentencepiece

文章生成

rinnaの公開しているGPTモデルを利用します。好みのモデル名を設定します。

MODEL_NAME = 'rinna/japanese-gpt2-xsmall'
# 好みで選択
# MODEL_NAME = 'rinna/japanese-gpt2-medium'
# MODEL_NAME = 'rinna/japanese-gpt-1b'

指定したモデルを読み込みます。

import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

if torch.cuda.is_available():
    model = model.to("cuda")

次がPyTorch 2.0で異なるところです。torch.compile()の1行を追加します。

model_optimized = torch.compile(model)

文章を生成する実装をします。

input_ids = tokenizer.encode(
    "昔々あるところに",
    add_special_tokens=False,
    return_tensors="pt"
)

def generate_by(model):
    return model.generate(
        input_ids.to(model.device),
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
    )[0]

print(tokenizer.decode(generate_by(model_optimized)))
出力※毎回異なります。
昔々あるところに、お世継ぎを残すために産まれたばかりの王子様と魔女と4人の魔女がいました。
(以下略)

以上です。torch.compile()以外は今までと同じであり、この1行を通っていないモデルでも普通に動きます。PyTorch 2.0と1.x系は互換性が高いようです。

PyTorch 2.0での変化

公式ブログによると、PyTorch 2.0でtorch.compile()を噛ますことでTransformersの多くのモデルにおいて処理が高速化したようです。しかし今回利用したGPUとモデルでは、色々試してみたものの速度の変化が全く見られませんでした。

# 生成速度は変わらず
generate_by(model)
generate_by(model_optimized)

torch.compile()では様々なモードを設定することができるようです。しかし、これらによる変化も今回のコードでは分かりませんでした。

model_optimized = torch.compile(model)
# フレームワークのオーバーヘッドを減らすため最適化するが、メモリが多く消費される。小さなモデルで効果的
model_optimized = torch.compile(model, mode="reduce-overhead")
# 最速のモデルを生成するようになるが、コンパイルに時間がかかる
model_optimized = torch.compile(model, mode="max-autotune")

参考

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0