Why not login to Qiita and try out its useful features?

We'll deliver articles that match you.

You can read useful information later.

17
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Ateam LifeDesignAdvent Calendar 2022

Day 4

TransformersでGPTの文章生成する方法いろいろ

Last updated at Posted at 2022-12-03

Transformersを使うと、GPTの事前学習モデルを使って簡単に文章生成ができます。モデル自体は同じでも色々なメソッドが用意されていて、用途に応じて適切なインターフェースを選ぶことでより便利に使えます。

環境

  • Google Colaboratory
  • transformers: 4.25.1

!pip install transformers sentencepiece等でtransformersをインストールしておきます。

事前準備

MODEL_NAME定数に好きなモデルを指定してください。ここではrinna/japanese-gpt2-xsmallを使います。

MODEL_NAME = 'rinna/japanese-gpt2-xsmall'

pipeline

pipeline()を使うのが、最も簡単な方法だと思います。
pipeline()の第一引数にタスクを指定することで、タスクを実行する簡単なパイプラインを生成できます。。ここではテキスト生成なのでtext-generationを用いますが、他に使えるタスクは公式ドキュメントに説明があります。

from transformers import pipeline

text_pipe = pipeline('text-generation', model=MODEL_NAME)
output = text_pipe("昔々あるところに")

output[0]['generated_text']

昔々あるところに、たくさん人が来て、たくさんの人がいて、みんなで楽しめる、それがなんだろう...と思ったことがあったんですw 僕には「この人じゃなかったらやらなかったんだねー」と思いました。 ・さん:

※この出力は毎回変わります

generate()メソッド

おそらく最も一般的なのはgenerate()メソッドを使った方法ではないでしょうか。rinnaモデルのREADMEで紹介されている方法だったりします。

import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

with torch.no_grad():
    input_ids = tokenizer.encode("昔々あるところに", add_special_tokens=False, return_tensors="pt")
    output_ids = model.generate(input_ids.to(model.device), max_length=20, pad_token_id=tokenizer.pad_token_id)

tokenizer.decode(output_ids.tolist()[0])

昔々あるところに、あるお店があります。 店内は、カウンター席とテーブル

generate()に引数としてtop_knum_beamsなどを付与すると、beam searchやsamplingなどのdecodingアルゴリズムを制御できます。decodingアルゴリズムについて詳細は下記公式ブログを参照してください。

greedy searchを自前で実装する

decodingの中でも最も単純なgreedy search(貪欲法)つまり、もっとも確率の高いトークンを選び続ける手法を、実際に実装してみます。
なお、これは機械学習エンジニアのためのTransformersにて紹介されていたソースコードを利用しています。

n_steps = 10

input_ids = tokenizer.encode("昔々あるところに", return_tensors="pt")

with torch.no_grad():
    for _ in range(n_steps):
        output = model(input_ids)

        next_token_logits = output.logits[0, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        sorted_ids = torch.argsort(next_token_probs, descending=True, dim=-1)

        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)

tokenizer.decode(input_ids[0])
出力
昔々あるところに</s><unk> <unk> <unk> <unk> <unk>

unknown token <unk> ばかりになってしまいました。実装を間違えてしまったのでしょうか。実際に、モデルの出力がどうなっているのかを確認してみましょう。トークンを確率が高い順に並び替えた結果であるsorted_idsの中身を出力してみます。

n_steps = 10

input_ids = tokenizer.encode("昔々あるところに", return_tensors="pt")

with torch.no_grad():
    for _ in range(n_steps):
        output = model(input_ids)

        next_token_logits = output.logits[0, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        sorted_ids = torch.argsort(next_token_probs, descending=True, dim=-1)

        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)

+       print(",".join(tokenizer.decode(sorted_ids[i]) for i in range(5)))

出力
<unk>,あ,昔,こんな,は
,C,の,o,oo
<unk>,が,の,という,で
,C,が,の,という
<unk>,が,の,という,C
,C,が,の,という
<unk>,が,という,の,C
,C,が,の,(
<unk>,が,という,の,C
,C,が,(,の

第1位に<unk>が来ているのですが、第2位以下は普通の単語が並んでいる事がわかります。

普通こんな事はしませんが、第1位ではなく第2位のトークンを取り続けるように改修すると、文章を出力するようになりました。

n_steps = 10

input_ids = tokenizer.encode("昔々あるところに", return_tensors="pt")

with torch.no_grad():
    for _ in range(n_steps):
        output = model(input_ids)

        next_token_logits = output.logits[0, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        sorted_ids = torch.argsort(next_token_probs, descending=True, dim=-1)

-       input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
+       input_ids = torch.cat([input_ids, sorted_ids[None, 1, None]], dim=-1)

tokenizer.decode(input_ids[0])
出力
昔々あるところに</s> あそこは、この辺で ”<unk>

このように、自前で実装することで中身の挙動を調べたりカスタムが自由に行なえます。

greedy_search()メソッド

greedy searchを自前で実装しましたが、Transformers内部にもメソッドが用意されています。
このメソッドを使うことで処理自体はとても単純になるのですが、LogitsProcessorというものを用意して渡す必要があります。説明は置いておいて、とりあえず動かしてみます。

from transformers.generation import LogitsProcessorList

input_ids = tokenizer.encode("昔々あるところに", return_tensors="pt")

logits_processor = model._get_logits_processor(
    repetition_penalty=None,
    no_repeat_ngram_size=None,
    encoder_no_repeat_ngram_size=None,
    input_ids_seq_length=None,
    encoder_input_ids=None,
    bad_words_ids=None,
    min_length=None,
    max_length=20,
    eos_token_id=tokenizer.eos_token_id,
    forced_bos_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    prefix_allowed_tokens_fn=None,
    num_beams=None,
    num_beam_groups=None,
    diversity_penalty=None,
    remove_invalid_values=None,
    exponential_decay_length_penalty=None,
    logits_processor=LogitsProcessorList(),
    renormalize_logits=None,
)

outputs = model.greedy_search(
    input_ids=input_ids,
    logits_processor=logits_processor,
    pad_token_id=tokenizer.pad_token_id,
)

tokenizer.decode(outputs[0])
出力
昔々あるところに</s><unk> <unk> <unk> <unk> <unk> <unk> <unk></s>

greedy searchを実装した時と同様、<unk>が並んでしまいました。

LogitsProcessorにはbad_words_idsという設定があり、ここに<unk>を出力しないよう指定することができます。

logits_processor = model._get_logits_processor(
    repetition_penalty=None,
    no_repeat_ngram_size=None,
    encoder_no_repeat_ngram_size=None,
    input_ids_seq_length=None,
    encoder_input_ids=None,
-   bad_words_ids=None,
+   bad_words_ids=[[tokenizer.unk_token_id]],
    min_length=None,
    max_length=20,
    eos_token_id=tokenizer.eos_token_id,
    forced_bos_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    prefix_allowed_tokens_fn=None,
    num_beams=None,
    num_beam_groups=None,
    diversity_penalty=None,
    remove_invalid_values=None,
    exponential_decay_length_penalty=None,
    logits_processor=LogitsProcessorList(),
    renormalize_logits=None,
)
出力
昔々あるところに</s> ああ、あの頃は ああ、あの頃は</s>

このように、LogitsProcessorを使うことで必要な出力が得やすいようにdecodingできます。

自前の実装をカスタムする

LogitsProcessorを利用する

先ほど自前で実装したgreedy searchにLogitsProcessorを組み込むには、下記のようにモデルの出力に挟みます。

input_ids = tokenizer.encode("昔々あるところに", return_tensors="pt")
n_steps = 10

with torch.no_grad():
    for _ in range(n_steps):
        output = model(input_ids)
        next_token_logits = output.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)

        probs = torch.softmax(next_token_scores, dim=-1)

        next_tokens = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_tokens], dim=-1)

tokenizer.decode(input_ids[0])
出力
昔々あるところに</s>ああ、あの頃は ああ、

ちゃんと、先ほどと同じ出力が得られました。
このようにして、自前の実装とTransformersの便利なメソッドを組み合わせることで、必要な生成処理を比較的簡単に実装できるようになります。

samplingを実装する

これまではgreedy searchということで、torch.argmax()を用いて確率が最大のトークンを持ってきていました。ここを変更しtorch.multinomialで確率分布に基づいてトークンを選ぶようにすると、samplingになります。

input_ids = tokenizer.encode("昔々あるところに", return_tensors="pt")
n_steps = 10

with torch.no_grad():
    for _ in range(n_steps):
        output = model(input_ids)
        next_token_logits = output.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)

        probs = torch.softmax(next_token_scores, dim=-1)

-       next_tokens = torch.argmax(probs, dim=-1)[:, None]
+       next_tokens = torch.multinomial(probs, num_samples=1)[:, None, 0]

        input_ids = torch.cat([input_ids, next_tokens], dim=-1)

tokenizer.decode(input_ids[0])
出力 ※この出力は毎回変わります
昔々あるところに</s>いろんな種類校正された校正物

同様に、この処理を変更していけばbeam searchやtop-k samplingなども実装できます。

まとめ

TransfomrmersでGPTの文章生成する方法を複数紹介しました。手軽な方法を使えばすぐに文章生成を試すことができますし、自前で実装することで仕組みの理解を深めたり、調査やカスタマイズもできるようになります。Transfomrmersには色々な機能が用意されていて、目的に応じて適切な機能を使いこなす事が大切だと思います。

参考文献

17
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?