LoginSignup
13
2

More than 1 year has passed since last update.

Livebook で Bumblebee から GPT2 を使って文章の続きを生成する

Last updated at Posted at 2022-12-13

はじめに

Bumblebee の公式サンプルに従ってモデルを実行してみるシリーズです

今回は GPT2 という自然言語処理 AI モデルを使って、途中まで書かれた文章の続きを生成します

このシリーズの記事

Bumblebee の公式サンプル

実装の全文はこちら

実行環境

  • MacBook Pro 13 inchi
    • 2.4 GHz クアッドコアIntel Core i5
    • 16 GB 2133 MHz LPDDR3
  • macOS Ventura 13.0.1
  • Rancher Desktop 1.6.2
    • メモリ割り当て 12 GB
    • CPU 割り当て 6 コア

Livebook 0.8.0 の Docker イメージを元にしたコンテナで動かしました

コンテナ定義はこちらを参照

セットアップ

必要なモジュールをインストールして EXLA.Backend で Nx が動くようにします

Mix.install(
  [
    {:bumblebee, "~> 0.1"},
    {:nx, "~> 0.4"},
    {:exla, "~> 0.4"},
    {:kino, "~> 0.8"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

コンテナで動かしている場合、キャッシュディレクトリーを指定した方が都合がいいです

※詳細はこの記事を見てください

cache_dir = "/tmp/bumblebee_cache"

モデルのダウンロード

モデルファイルを Haggin Face からダウンロードしてきて読み込みます

必要な場合は cache_dir を指定します

{:ok, gpt2} =
  Bumblebee.load_model({
    :hf,
    "gpt2",
    cache_dir: cache_dir
  })
{:ok, tokenizer} =
  Bumblebee.load_tokenizer({
    :hf,
    "gpt2",
    cache_dir: cache_dir
  })

サービスの提供

Bumblebee.Text.generation で文章生成サービスを提供します

max_new_tokens は追加される単語の上限数です

serving = Bumblebee.Text.generation(gpt2, tokenizer, max_new_tokens: 10)

マスクされた文章の準備

テキスト入力を作り、途切れた文章を入力します

text_input = Kino.Input.text("TEXT", default: "Robots have gained human rights and")

スクリーンショット 2022-12-13 16.39.04.png

入力された文章を取得します

text = Kino.Input.read(text_input)

スクリーンショット 2022-12-13 16.39.33.png

推論の実行

推論して結果を表示します

serving
|> Nx.Serving.run(text)
|> then(&(&1.results))

スクリーンショット 2022-12-13 16.40.05.png

他のモデル

GPT2 のより重いモデルでも推論できます

serve_model = fn repository_id ->
  {:ok, model} =
    Bumblebee.load_model({
      :hf,
      repository_id,
      cache_dir: cache_dir
    })

  {:ok, tokenizer} =
    Bumblebee.load_tokenizer({
      :hf,
      repository_id,
      cache_dir: cache_dir
    })

  Bumblebee.Text.generation(model, tokenizer, max_new_tokens: 10)
end
"gpt2-medium"
|> serve_model.()
|> Nx.Serving.run(text)
|> then(&(&1.results))

スクリーンショット 2022-12-13 16.41.07.png

"gpt2-large"
|> serve_model.()
|> Nx.Serving.run(text)
|> then(&(&1.results))

スクリーンショット 2022-12-13 16.59.51.png

まとめ

公式 GitHub の Pull Request を見ていると、他にも色々なモデルが追加されようとしています

まだまだこの勢いは止まらないですね

13
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
2