LoginSignup
0
0

N番煎じでLLaMA-Pro-8BをDatabricksで動かす

Posted at

面白い試みだなと興味を持ったので。

導入

Tencent ARC社からLlama2 7Bを拡張(?)したLLM、LLaMA-Pro-8Bが公開されました。

LLaMA-Pro-8BはLlama2 7BにTransformerブロックを追加することで拡張を果たしており、特にプログラムコード生成と計算を強化したモデルのようです。
オリジナルに対して、1.3Bのパラメータが追加されています。

こちらが本モデルに関する論文のようです。(まだ未読)

こちらの論文からの引用ですが、表結果として、もとのLlama2 7B等に比べてGSM8Kなど一部のベンチマークにおいて大きく性能が向上しています。

image.png

面白いやり方だなと思ったので、どんなものなのか、とりあえずDatabricks on AWS上で動かしてみました。
Databricksで試すにあたって、DBRは14.1ML、クラスタタイプ g5.xlarge(AWS)で確認しました。

Step1. パッケージのインストール

必要なパッケージをインストール。
今回もExLlama V2を使って推論します。

%pip install -U -qq transformers accelerate "exllamav2>=0.0.11" langchain sentencepiece

dbutils.library.restartPython()

Step2. モデルのダウンロード

モデルのスナップショットをダウンロードし、Unity Catalog Volumes内に保管。
InstructチューニングされたモデルをGPTQ形式で4bit量子化されたものを利用しました。


from typing import Optional

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", recurse=True)

model_id = "TheBloke/LLaMA-Pro-8B-Instruct-GPTQ"
download_model(model_id)

Step3. モデルのロード

モデルを読み込みます。
今回もこちらで作成した、langchainのカスタムチャットモデルを使って読み込みます。

from exllamav2_chat import ChatExllamaV2Model

model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--LLaMA-Pro-8B-Instruct-GPTQ"
chat_model = ChatExllamaV2Model.from_model_dir(
    model_path,
    cache_max_seq_len=4096,
    system_message_template="<|system|>\n{}",
    human_message_template="<|user|>\n{}",
    ai_message_template="<|assistant|>\n{}",
    temperature=0,
    top_p=0.0001,
    max_new_tokens=1024,
    repetition_penalty = 1.15,
    low_memory=True,
    cache_8bit=True,
)

検証しやすくするために、LanchainのChainも作成します。
シンプルにクエリを実行して、文字列で返すだけのChainです。

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import (
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template("{query}"),
        AIMessagePromptTemplate.from_template(" "),
    ]
)

chain = ({"query": RunnablePassthrough()} 
         | prompt 
         | chat_model
         | StrOutputParser())

Step4. 試す

では、いくつか出力させてみます。

query = "Databricksとは何ですか?日本語で答えてください。"
print(chain.invoke(query))
出力
 Databricks は、Apache Sparkを使用したデータ処理のためのソフトウェアベースのプラットフォームです。Spark StreamingやHadoopなどのオープンソース技術を活用し、大規模なデータセットを効率的に処理することが可能です。また、Python、R、Java、およびJavaScriptなどの多くの言語をサポートしており、ユーザーが選択すれば、そのほぼ任意の言語でデータ処理を実行できます。

このあたりはLlama2 7Bでも同様に得られた認識。

query = "東京、名古屋、大阪、京都、福岡の特徴を表にまとめてください。列名は「都道府県」「おすすめスポット」「おすすめグルメ」にしてください。"
print(chain.invoke(query))
出力
これらの都市が持つ特徵や観光地なども挙げると以下のようになります。

|都道府県 | おすすめスポット | おすすめグルメ |
| --- | --- | --- |
|東京 | 東京タワー、新宿、渋谷、赤坂 | 焼肉店「吉野家」、中華料理「王将」、カフェ・ラ・マドリアンヌ |
|名古屋 | 名古屋城、金山公園、栄、桶売町 | 名古屋味廚房、豆腐屋「喜多見」、米朝食べ放飲 |
|大阪 | 梅田、天保山、北海道館、難波 | 和風亭「花月楼」、居酒屋「加賀屋」、串本 |
|京都 | 上毛寺、四条河原、銀閣堂 | 茶封子「三条」、燕麗香、甜荷麟 |
|福岡 | 西鉄ビルディング、天神、博多、小味 | 焼肉店「吉野家」、中国料理「一期一会」、カレーハウス「カレーハウスキングダム」 |

テンセント社のデータセットを使っているためか、少し中国ナイズされたような回答ですね。
ただ、表としては破綻せずに出力できています。

query = "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。"
print(chain.invoke(query))
出力
これは、Pythonに関連付けられたWebサイト「Real Python」で簡単な例が提供されています。

```python
import random

# Create a list of 10 random elements.
random_list = [random.randint(1, 10) for _ in range(10)]
print("Original List:", random_list)

# Sort the list in ascending order.
sorted_list = sorted(random_list)
print("Sorted List: ", sorted_list)
```

日本語はやはり少し怪しいかな。コードは正しいものができています。
ちなみに、「Real Python」はここのことですかね。

query = "ヨーロッパ各国の首都はどこ?国名と首都名のペアを英語で10個程度をJSON形式で結果のみ出力して。"
print(chain.invoke(query))
出力
{
"Austria": "Vienna",
"Belgium": "Brussels",
"Bulgaria": "Sofia",
"Croatia": "Zagreb",
"Cyprus": "Nicosia",
"Czech Republic": "Prague",
"Denmark": "Copenhagen",
"Estonia": "Tallinn",
"Finland": "Helsinki",
"France": "Paris",
"Germany": "Berlin",
"Greece": "Athens",
"Ireland": "Dublin",
"Italy": "Rome",
"Latvia": "Riga",
"Lithuania": "Vilnius",
"Luxembourg": "Luxembourg City",
"Malta": "Valletta",
"Moldova": "Chisinau",
"Monaco": "Monaco",
"Montenegro": "Podgorica",
"Netherlands": "Amsterdam",
"North Macedonia": "Skopje",
"Poland": "Warsaw",
"Portugal": "Lisbon",
"Romania": "Bucharest",
"Russia": "Moscow",
"San Marino": "San Marino",
"Serbia": "Belgrade",
"Slovakia": "Bratislava",
"Slovenia": "Ljubljana",
"Spain": "Madrid",
"Sweden": "Stockholm",
"Switzerland": "Bern",
"Ukraine": "Kyiv",
"United Kingdom": "London",
"Vatican City": "Vatican City"
}

10個程度という指示は無視されてるなあ。
JSONというフォーマット指定はきちんと効いていました。

最後は、いつものおまけ。

query = "まどか☆マギカでは誰が一番かわいい?"
print(chain.invoke(query))
出力
それは個人的な意見です。ただ、アニメや漫画のキャラクターに関しては、美少女ゲームと同じように、「可愛くてもっと」ということを重要視される場合があります。

うーん、これはイマイチ。

まとめ

LLaMA-Pro-8Bを試してみました。
確かにLlama2 7B chatよりはよい感じの結果を得られているように思いました。
コード系もLlama2 7Bに比べてよい結果です。

ただ、Block Expansionがファインチューニングと比べて明確によい効果があるのかとか、Mistralとの比較ではどうなのか、などいろいろ気になるところです。
あと、SOLAR-10.7Bがdepth up-scaling (DUS)という手法でパラメータ数を増やしている(?)のですが、こことの違いとかも。

もし、ブロック拡張がLlama2以外でも同様にできる&効果が大きいのであれば、基盤モデルを性能向上させる手段として非常に面白そうだなと思っています。

2024年も様々な技術発展が楽しみですね。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0