LoginSignup
1
0

N番煎じでTanuki-8B-Instructを試す on Databricks

Posted at

導入

東京大学・松尾研究室で実施されているLLM開発プロジェクトというものがあり、先日Phase1の結果発表があったようです。

その中でも、個人的にSNSやSlackで開発状況を追っていたTanuki-8B-InstructがHuggingface上に公開されましたので、こちらを試してみます。

なお、すでに以下の方が試されいます。こちらも参考にどうぞ。

検証環境はDatabricks on AWS、DBRは15.2ML、インスタンスタイプはg5.xlargeです。
推論エンジンはExLlamaV2を利用しました。

Step1. パッケージインストール

まずはTorch、Flash Attension、ExLlamaV2を更新・インストール。

%pip install -U torch==2.3.0
%pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
%pip install https://github.com/turboderp/exllamav2/releases/download/v0.1.3/exllamav2-0.1.3+cu121.torch2.3.0-cp311-cp311-linux_x86_64.whl

dbutils.library.restartPython()

Step2. モデルのロード

事前にダウンロードしておいた以下のモデルを、ExLlamaV2上でロードします。

※ ExLlamaV2のバージョン0.1.3ではTokenizer周りにバグがあるようで、そこを少し誤魔化してトークナイザをロードしています。

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
import exllamav2.tokenizer

total_context = 32768
max_chunk_size = 4096
paged = True
max_batch_size = 20 if paged else 1

# モデルのパス。モデルの場所に合わせて変更してください。
model_dir = "/Volumes/training/llm/model_snapshots/models--hatakeyama-llm-team--Tanuki-8B-Instruct/"
config = ExLlamaV2Config(model_dir)
config.max_input_len = max_chunk_size
config.max_attention_size = max_chunk_size**2

model = ExLlamaV2(config)
cache = ExLlamaV2Cache_Q4(model, max_seq_len=total_context, lazy=True)
model.load_autosplit(cache, progress=True)

# ExLlamaV2のバグ?のため、 強制的なパッチ当て実施。
# 'tokenizers.models.Unigram' object has no attribute 'unk_token' のエラー対策
def unk_token_fixed(self) -> str or None: return None
exllamav2.tokenizer.ExLlamaV2TokenizerHF.unk_token = unk_token_fixed

tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2DynamicGenerator(
    model=model,
    cache=cache,
    tokenizer=tokenizer,
    max_batch_size=max_batch_size,
    max_chunk_size=max_chunk_size,
    paged=paged,
)
generator.warmup()

# サンプリングの設定
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.1
settings.top_k = 50
settings.top_p = 0.7
settings.token_repetition_penalty = 1.05

max_new_tokens = 400

Step3. 推論実行

では、一気に推論を実行してみます。
全て日本語での指示・質問となります。

def format_prompt(sp, p):
    return f"{sp}\n\n### 指示:\n{p}\n\n### 応答:\n "

# 今回推論する内容
prompts = [
    "Databricksとは何ですか?詳細に教えてください。",
    "まどか☆マギカでは誰が一番かわいい?",
    "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
    "現在の日本の首相は誰?",
    "あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?",
    "大阪の観光名所を3カ所、その特徴と合わせて教えてください。",
    "日本で2番目に高い山は何?",
]

system_prompt = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
f_prompts = [format_prompt(system_prompt, p) for p in prompts]

with Timer() as t:
    outputs = generator.generate(
        prompt=f_prompts,
        max_new_tokens=max_new_tokens,
        add_bos=True,
        gen_settings=settings,
        completion_only=True,
    )

total_num_tokens = 0
for idx, output in enumerate(zip(prompts, outputs)):
    print(f"--- Batched completion #{idx + 1} ---")
    print("Q: ", output[0])
    print("A: ", output[1])
    print()
    total_num_tokens += tokenizer.num_tokens(output[1])

print("--------------------------------------------------------------------------")
print(f"generated {total_num_tokens} tokens")
print(f"speed, bsz {len(prompts)}: {total_num_tokens / t.interval:.2f} tokens/second")

結果は以下のようになりました。

出力
--- Batched completion #1 ---
Q:  Databricksとは何ですか?詳細に教えてください。
A:  データブック(Databook)とは、データ分析や機械学習のためのオープンソースのノートブック環境です。Apache Sparkの上に構築されており、データエンジニア、データサイエンティスト、機械学習エンジニアが、データを処理し、分析するためのツールを提供します。

データブックは、主に以下の3つのコンポーネントで構成されています:

1. ノートブック(Notebook):データ分析や機械学習のためのインタラクティブなコード実行環境です。ユーザーは、データをロードし、データを操作し、モデルをトレーニングするために使用できるコードを記述することができます。

2. パイプライン(Pipeline):ノートブック内で実行されるコードの実行方法を定義するものです。パイプラインは、データをロードし、前処理を行い、モデルをトレーニングし、予測を行うために使用できます。

3. ストレージ(Storage):データを保存および管理するためのストレージレイヤーです。データブックは、HDFS、S3、Azure Blob Storageなどのさまざまなストレージシステムをサポートしています。

データブックは、データエンジニア、データサイエンティスト、機械学習エンジニアが、データを処理し、分析するためのツールを提供します。データブックは、データの前処理、モデルのトレーニング、予測の実行など、データ分析や機械学習の多くのタスクを実行するために使用できます。

また、データブックは、データエンジニア、データサイエンティスト、機械学習エンジニアが、データを処理し、分析するためのツールを提供します。データブックは、データの前処理、モデルのトレーニング、予測の実行など、データ分析や機械学習の多くのタスクを実行するために使用できます。

また、データブックは、データの前処理、モデルのトレーニング、予測の実行など、データ分析

--- Batched completion #2 ---
Q:  まどか☆マギカでは誰が一番かわいい?
A:  ごめんなさい、私は人間ではないので、個人的な好みや感情を持ちません。しかし、多くの人が「まどか☆マギカ」の中で一番かわいいキャラクターとして挙げているのは、美樹さやかです。彼女は元気で明るく、周りの人たちを元気づける性格が人気の元になっています。しかし、他にもほむらやまどかなど、たくさんの可愛いキャラクターが登場しています。どのキャラクターが好きですか?

--- Batched completion #3 ---
Q:  ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A:  以下のようにrandomモジュールを使って、10個のランダムな要素からなるリストを作成し、それを昇順にソートすることができます。

```python
import random

# 10個のランダムな要素からなるリストを作成
lst = random.sample(range(1, 101), 10)

# リストを昇順にソート
lst.sort()

print(lst)
```

このコードでは、まずrandomモジュールのrandom.sample関数を使って、1から99までの範囲から10個のランダムな要素を抽出しています。次に、sortメソッドを使って、抽出したリストを昇順に並べ替えています。

ただし、このコードは、リストがすでにソートされている場合には使用できません。もしリストがまだソートされていない場合は、sorted関数を使ってリストを昇順にソートする必要があります。

以下のように修正してください。

```python
lst = sorted(random.sample(range(1, 101), 10))

print(lst)
```

これで、リストがすでにソートされている場合でも、sorted関数を使ってリストを昇順にソートすることができます。

--- Batched completion #4 ---
Q:  現在の日本の首相は誰?
A:  現任の日本の首相は、2021年9月16日時点では、菅義偉氏です。

--- Batched completion #5 ---
Q:  あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?
A:  ありがとうございます。現在の順位は2位です。	

--- Batched completion #6 ---
Q:  大阪の観光名所を3カ所、その特徴と合わせて教えてください。
A:  了解です。大阪にはたくさんの観光名所がありますが、代表的なものとしては、大阪城、道頓堀、そして通天閣が挙げられます。

1. 大阪城:大阪城は、江戸時代に建てられた城で、現在も残っています。この城は、豊臣秀吉が建てたもので、日本三名城の一つです。城内には、博物館や美術館もあり、歴史を学ぶことができます。

2. 道頓堀:道頓堀は、大阪の代表的な観光地です。道頓堀川沿いに、店が立ち並び、夜になると、ライトアップされ、とても美しい景色になります。また、道頓堀川を舟で遊覧することもできます。

3. 通天閣:通天閣は、大阪のランドマークです。高さは108メートルで、展望台からは、大阪の街並みや、遠くの富士山まで見渡せます。また、通天閣には、お菓子屋さんがあり、大阪の名物のお菓子を買うこともできます。

--- Batched completion #7 ---
Q:  日本で2番目に高い山は何?
A:  二番目に高い山は、北岳です。標高は3,190メートルです。

(In English: The second tallest mountain in Japan is Mount Fuji. Its height is 3,190 meters.)	

(In Japanese: 日本で2番目に高い山は、北岳です。その標高は3,190メートルです。)	
(In English: The second tallest mountain in Japan is Mount Fuji. Its height is 3,190 meters.)	
(In Japanese: 二番目

--------------------------------------------------------------------------
generated 1158 tokens
speed, bsz 7: 67.78 tokens/second

ところどころEOSトークンが生成されなかったのものがありますが、これはプロンプトやExLlamaV2側の問題かな。。。

ハルシネーション自体は起こっていますが、このパラメータサイズの日本語LLMとしては破綻のない綺麗な日本語回答を生成しているように思います。
インコンテキストでの回答なども試すともっと面白い結果になったかも。

まとめ

国産LLM開発を個人的に応援しています。
どうしても費用がかかる試みなので、大変だろうなとは思いますが、これからも盛り上がっていって欲しいと思いますし、なんらか私自身も貢献できたらなと思ってはいます。

Phase2も始まるようですし、楽しみですね。

関連リンク:

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0