3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Microsoftから独立した元女子高生AI rinnaとcolabで対話する方法

Last updated at Posted at 2021-06-07

概要

今回はMicrosoftが開発し、Microsoftから独立して会社になってしまった元女子高生AIチャットボット「りんな」の事前学習モデルがhuggingfaceに公開されていたので、さっそく試す。

対象読者

#1. 36本目 ライブラリのセットアップ

  • 質問: Colabでrinna(rinna/japanese-gpt2-medium)と会話するために必要なモジュールをインストールせよ。

  • 回答: huggingfaceのtransformersとsentencepieceが必要。ver.は結構大事なのでつけておくことを推奨する。

!pip install transformers==4.6.1
!pip install sentencepiece

#2. 37本目 事前学習モデルのロード

  • 質問: rinnaの事前学習モデル"rinna/japanese-gpt2-medium"をtokenizer, modelにロードせよ。

  • 回答: AutoModelForCasualLMの方はAutoでもいいのだが、T5Tokenizerの方をAutoTokenizerにするとエラーがでる。
    この微妙な差が大事。

from transformers import T5Tokenizer, AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
model = AutoModelWithLMHead.from_pretrained("rinna/japanese-gpt2-medium")

#3. 38本目 トークナイズ

  • 質問: '新型コロナ収束。'を37本目のtokenizerでトークナイズせよ。

  • 回答: tokenizer()に引数で渡すだけ。input_ids, attention_maskが返却される。

inputs = tokenizer('新型コロナ収束。')
inputs

結果

{'input_ids': [9, 7127, 5049, 145, 14365, 8, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

#4. 39本目 推論

  • 質問: 38本目でトークナイズしたテキストを入力に、37本目でロードしたrinna事前学習モデルで推論を行え。

  • 回答: modelにはinput_idsのみをtorch.tensorでキャストして渡す。


import torch
outputs = model(torch.tensor(inputs['input_ids']))
print(outputs.keys())
print(type(outputs.logits))
print(type(outputs.past_key_values))
print(outputs.logits.shape)
print(outputs.past_key_values[0][0].shape)

結果


odict_keys(['logits', 'past_key_values'])
<class 'torch.Tensor'>
<class 'tuple'>
torch.Size([7, 32000])
torch.Size([1, 16, 7, 64])

outputsはlogitsとpast_key_valuesをキーとして持っていることがわかる。
logitsには生成されたテキストのID(語彙32000個のうちのどれかを示すもの)の列が格納されている(7トークン)。
past_key_valuesはもう一度同じ計算をする際に、隠れ層のキャッシュを再利用し高速化を図る目的で保持されている。

#5. 40本目 対話

  • 質問: 39本目の推論結果を解釈し、りんなが発した言葉をテキスト化せよ。

  • 回答: logits.argmax(0)で0軸の最大となる索引を返却できる。logits.argmax(-1)としているのは、軸が増えた場合には軸1を、一次元の場合は軸0のargmaxを取得するという意味である。


print(''.join([tokenizer.decode(logit.argmax(0).item()) for logit in outputs.logits]))

結果

、インフルエンザナの装置pic

天才りんなに聞いてみると、新型コロナ収束の鍵は、インフルエンザナの装置picにありそうだ…
インフルエンザと新型コロナは両方かかってしまうこともあるそうだが、両方に効く一挙両得な治療法が見つかることを切に願う。

たくさん話しかけてみると、拙い日本語だが関連するテキストが返却されるのでぜひお試しください。


#11. 参考文献

#12. 著者

ツイッターでPython/numpy/pandas/pytorch関連の有益なツイートを配信してます。

@keiji_dl

#13. 参考動画

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?