LoginSignup
2

posted at

updated at

Livebook で Bumblebee から GPT2 を使って文章の続きを生成する

はじめに

Bumblebee の公式サンプルに従ってモデルを実行してみるシリーズです

今回は GPT2 という自然言語処理 AI モデルを使って、途中まで書かれた文章の続きを生成します

このシリーズの記事

Bumblebee の公式サンプル

実装の全文はこちら

実行環境

  • MacBook Pro 13 inchi
    • 2.4 GHz クアッドコアIntel Core i5
    • 16 GB 2133 MHz LPDDR3
  • macOS Ventura 13.0.1
  • Rancher Desktop 1.6.2
    • メモリ割り当て 12 GB
    • CPU 割り当て 6 コア

Livebook 0.8.0 の Docker イメージを元にしたコンテナで動かしました

コンテナ定義はこちらを参照

セットアップ

必要なモジュールをインストールして EXLA.Backend で Nx が動くようにします

Mix.install(
  [
    {:bumblebee, "~> 0.1"},
    {:nx, "~> 0.4"},
    {:exla, "~> 0.4"},
    {:kino, "~> 0.8"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

コンテナで動かしている場合、キャッシュディレクトリーを指定した方が都合がいいです

※詳細はこの記事を見てください

cache_dir = "/tmp/bumblebee_cache"

モデルのダウンロード

モデルファイルを Haggin Face からダウンロードしてきて読み込みます

必要な場合は cache_dir を指定します

{:ok, gpt2} =
  Bumblebee.load_model({
    :hf,
    "gpt2",
    cache_dir: cache_dir
  })
{:ok, tokenizer} =
  Bumblebee.load_tokenizer({
    :hf,
    "gpt2",
    cache_dir: cache_dir
  })

サービスの提供

Bumblebee.Text.generation で文章生成サービスを提供します

max_new_tokens は追加される単語の上限数です

serving = Bumblebee.Text.generation(gpt2, tokenizer, max_new_tokens: 10)

マスクされた文章の準備

テキスト入力を作り、途切れた文章を入力します

text_input = Kino.Input.text("TEXT", default: "Robots have gained human rights and")

スクリーンショット 2022-12-13 16.39.04.png

入力された文章を取得します

text = Kino.Input.read(text_input)

スクリーンショット 2022-12-13 16.39.33.png

推論の実行

推論して結果を表示します

serving
|> Nx.Serving.run(text)
|> then(&(&1.results))

スクリーンショット 2022-12-13 16.40.05.png

他のモデル

GPT2 のより重いモデルでも推論できます

serve_model = fn repository_id ->
  {:ok, model} =
    Bumblebee.load_model({
      :hf,
      repository_id,
      cache_dir: cache_dir
    })

  {:ok, tokenizer} =
    Bumblebee.load_tokenizer({
      :hf,
      repository_id,
      cache_dir: cache_dir
    })

  Bumblebee.Text.generation(model, tokenizer, max_new_tokens: 10)
end
"gpt2-medium"
|> serve_model.()
|> Nx.Serving.run(text)
|> then(&(&1.results))

スクリーンショット 2022-12-13 16.41.07.png

"gpt2-large"
|> serve_model.()
|> Nx.Serving.run(text)
|> then(&(&1.results))

スクリーンショット 2022-12-13 16.59.51.png

まとめ

公式 GitHub の Pull Request を見ていると、他にも色々なモデルが追加されようとしています

まだまだこの勢いは止まらないですね

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
2