はじめに
Bumblebee の公式サンプルに従ってモデルを実行してみるシリーズです
今回は RoBERTa という自然言語処理 AI モデルを利用して、 AI に質疑応答してもらいます
このシリーズの記事
- 画像分類: ResNet50
- 画像生成: Stable Diffusion
- 文章の穴埋め: BERT
- 文章の判別: BERTweet
- 文章の生成: GPT2
- 質疑応答: RoBERTa (ここ)
- 固有名詞の抽出: bert-base-NER
Bumblebee の公式サンプル
実装の全文はこちら
実行環境
- MacBook Pro 13 inchi
- 2.4 GHz クアッドコアIntel Core i5
- 16 GB 2133 MHz LPDDR3
- macOS Ventura 13.0.1
- Rancher Desktop 1.6.2
- メモリ割り当て 12 GB
- CPU 割り当て 6 コア
Livebook 0.8.0 の Docker イメージを元にしたコンテナで動かしました
コンテナ定義はこちらを参照
セットアップ
必要なモジュールをインストールして EXLA.Backend で Nx が動くようにします
Mix.install(
[
{:bumblebee, "~> 0.1"},
{:nx, "~> 0.4"},
{:exla, "~> 0.4"},
{:kino, "~> 0.8"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
コンテナで動かしている場合、キャッシュディレクトリーを指定した方が都合がいいです
※詳細はこの記事を見てください
cache_dir = "/tmp/bumblebee_cache"
モデルのダウンロード
モデルファイルを Haggin Face からダウンロードしてきて読み込みます
必要な場合は cache_dir
を指定します
{:ok, roberta} =
Bumblebee.load_model({
:hf,
"deepset/roberta-base-squad2",
cache_dir: cache_dir
})
{:ok, tokenizer} =
Bumblebee.load_tokenizer({
:hf,
"roberta-base",
cache_dir: cache_dir
})
文章の準備
質問文と、質問の対象になる文章を入力します
question_input =
Kino.Input.text("QUESTION",
default: "What industries does Elixir help?"
)
context_input =
Kino.Input.textarea("CONTEXT",
default:
~s/Elixir is a dynamic, functional language for building scalable and maintainable applications. Elixir runs on the Erlang VM, known for creating low-latency, distributed, and fault-tolerant systems. These capabilities and Elixir tooling allow developers to be productive in several domains, such as web development, embedded software, data pipelines, and multimedia processing, across a wide range of industries./
)
入力された文章を取得します
question = Kino.Input.read(question_input)
context = Kino.Input.read(context_input)
推論の実行
推論して結果を表示します
この機能は Nx.Serving 用の関数がまだ用意されていないので、前処理、後処理を自分で書きます
inputs = Bumblebee.apply_tokenizer(tokenizer, {question, context})
outputs = Axon.predict(roberta.model, roberta.params, inputs)
answer_start_index =
outputs.start_logits
|> Nx.argmax()
|> Nx.to_number()
answer_end_index =
outputs.end_logits
|> Nx.argmax()
|> Nx.to_number()
answer_tokens =
inputs["input_ids"][[0, answer_start_index..answer_end_index]]
|> Nx.to_flat_list()
Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)
ちゃんと質問に答えました
まとめ
これからの Bumblebee に期待が持てますね