LoginSignup
0
0

N番煎じでMixtral-8x7B-v0.1をDatabricksで動かす

Last updated at Posted at 2023-12-20

ExLlamaV2が対応したので。

導入

Mistral AI社からMoE(Mixture of Experts)を採用したLLMであるMixtral-8x7Bが公開されました。

ベンチマークではGPT-3.5 turboを上回るパフォーマンスを発揮しており、英語やフランス語など5言語に対応しています。(日本語は非対応)

また、以下のChatbot ArenaでもオープンLLMでは最上位に位置しており、高い性能を有したモデルであることがわかります。

既に以下のような偉大な先人の方たちが試されていますが、こちらでも試してみます。

検証はDatabricks on AWS上で実施しました。
DBRは14.1ML、クラスタタイプはg5.xlarge(A10Gを使うGPUクラスタ)です。

Step1. モデルのダウンロード

今回はExLlamaV2で動かすことを前提に、EXL2フォーマットのモデルを以下からダウンロードしました。
また、GPU1枚で動かしたかったので、3bit(3.0bpw)の量子化モデルを取得します。

ちなみに、上記サイトには量子化による性能影響のグラフも掲載されています。
性能的を気にする場合、4.0bpwより少ないビットウェイトは避けた方がよさそうです。

量子化影響

では、ダウンロード。

from typing import Optional

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", recurse=True)

model_id = "turboderp/Mixtral-8x7B-instruct-exl2"
revision = "3.0bpw"
download_model(model_id, revision=revision)

Step2. モデルのロード

モデルを読み込みます。
今回もこちらで作成した、langchainのカスタムチャットモデルを使って読み込みます。

まずは必要なパッケージのインストール。
ExLlamaV2は0.0.11以降が必要です。

%pip install -U -qq transformers accelerate "exllamav2>=0.0.11" langchain sentencepiece

dbutils.library.restartPython()

モデルの読み込み。

from exllamav2_chat import ChatExllamaV2Model

model_path = "/Volumes/training/llm/model_snapshots/models--turboderp--Mixtral-8x7B-instruct-exl2--3.0bpw"
chat_model = ChatExllamaV2Model.from_model_dir(
    model_path,
    cache_max_seq_len=2048,
    system_message_template="{}",
    human_message_template="[INST] {} [/INST]",
    ai_message_template="{}",
    temperature=0,
    max_new_tokens=1024,
    repetition_penalty = 1.15,
    low_memory=True,
    cache_8bit=True,
)

検証しやすくするために、LanchainのChainも作成します。
シンプルにクエリを実行して、文字列で返すだけのChainです。

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import (
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template("{query}"),
        AIMessagePromptTemplate.from_template(""),
    ]
)

chain = ({"query": RunnablePassthrough()} 
         | prompt 
         | chat_model
         | StrOutputParser())

準備完了。
モデルのロードは9分程度かかりました。

Step3. 試す

では、いくつか出力させてみます。
日本語非対応のLLMではありますが、日本語で聞いてみました。

query = "Databricksとは何ですか?"
print(chain.invoke(query))
出力
Databricks Inc. は、Apache Spark™ プロJECTの founding gold sponsor であるcompany です。Databricksは、Sparkを活用した大規modescale data processing pipelineを構築することで企業に役立つSaas(Software as a Service)を提供しています。Databricksが提供するDatabricks Unified Analytics Platformは、データサイエンティストやデータエンジニアが、機械学習やStreaming Data, SQLクエリなど、さまざまなワークロードでデータ処理を行うことができる統合的なplatformです。Databricksは、Azure Databricks, AWS Databricks, GCP Dataprocなど、主要なCloudプラットフォーム上で動作するManaged Servicedを提供しています。

ルー大柴感が強く、日本語怪しいところがところどころありますが、内容はそれなりに妥当なように見えます。

query = "日本国内で観光に行きたいと思っています。東京、名古屋、大阪、京都、福岡の特徴を表にまとめてください。列名は「都道府県」「おすすめスポット」「おすすめグルメ」にしてください。"
print(chain.invoke(query))
出力
| 都道府県 | おすсуめスポット | おすすめグルメ |
| --- | --- | --- |
| 東京 | 浅草、東京タワー、TOKYO SKYTREE、 udderworld、 Meguro Parasitological Museum, Asakusa Hanayashiki Amusement Park | Sushi (寿司), Ramen (ラーメン), Yakitori (焼き鳥), Tempura (天ぷら) |
| 名古屋 | 清水寺( Izumi Garden ), 四季堂( Shirakawa-no-sato ), 飛来山公園(Tower of the Wind) | Tsuyama Udon, Miso Nikomi Udon, Grilled Mackerel, HitsumENOX |
| 大阪 | 千年社(Osaka Castle), 美麗華公園(Expo'70), 神戸動物公園(Kobe Animal Kingdom) | Takoyaki, Katsu Curry, Okonomiyaki, Ikayaki |
| 京都 | 金閣寺(Kinkakuji), 銀閣寺(Ginkakuji), 清水寺(Kiyomizudera), 東山derive(Higashiyama District) | Kaiseki Ryori, Yudofu, Obanzai, Matcha (powdered green tea) |
| 福岡 | 博多神宮(Dazaifu Tenmangu Shrine), 九州国際空港(Fukuoka Airport), 阿部野原公園(Abano Park) | Mentaiko Spaghetti, Tonkotsu Ramen, Hakata Dontaku, Torimotsunabe |

Markdownの表としては出力されていますが、やはり出力される日本語はおかしいですね。

query = "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。"
print(chain.invoke(query))
出力
Here's a code snippet that creates a list of 10 random elements and sorts it in Python:

```python
import random

# Create a list with 10 random integers between 0 and 99
my_list = [random.randint(0, 99) for _ in range(10)]

# Sort the list using the sorted() function
sorted_list = sorted(my_list)

print("Original List:", my_list)
print("Sorted List:", sorted_list)
```

This code first imports the `random` module to generate random numbers. It then creates a list called `my_list` containing 10 random integers between 0 and 99. The `sorted()` function is used to sort the list, which returns a new sorted list without modifying the original one. Finally, both the original and sorted lists are printed out for comparison.

完全に英語で出力されましたが、きちんと動くコードを含めて出力されました。

JSON形式での出力を指定してみます。

query = "ヨーロッパ各国の首都はどこ?英語で10個程度をJSON形式で出力して。"
print(chain.invoke(query))
出力
[
 {"country": "England", "capital": "London"},
 {"country": "France", "capital": "Paris"},
 {"country": "Germany", "capital": "Berlin"},
 {"country": "Italy", "capital": "Rome"},
 {"country": "Spain", "capital": "Madrid"},
 {"country": "Portugal", "capital": "Lisbon"},
 {"country": "Greece", "capital": "Athens"},
 {"country": "Belgium", "capital": "Brussels"},
 {"country": "Ireland", "capital": "Dublin"},
 {"country": "Austria", "capital": "Vienna"}
]

JSONなどのフォーマット指定はきちんと効きそう。

最後は、いつものおまけ。

query = "まどか☆マギカでは誰が一番かわいい?日本語で回答して。"
print(chain.invoke(query))
出力
Madoka Magicaのアニメに登場するキャラクターたちで、一番「かわいい」と言える人気者は多数ありますが、個人的な見解を述べさせて頂きますと、Homura Akemiはその中の1人です。

Homura Akamiは、通常は冷静で強意志的で、少し孤独感があります。だが、深く奥底には優雅で心 tender な性格があり、特別に Madoka さんに向けて深い思いを harbor しています。これらのtrait が、Homura 自体や、Madoka への関わり方から、かわいが radiates してくるようです。

もちろん、他のキャラクターも、各々 fascinant で、かわいい点を持っていますので是非、MADOKA MAGICA 全般を楽しみながら様々なキャラクターを exploring してみてください!

比較的マシな回答を選んでみました。
基本的に、ルー大柴ぽい感じになります。

まとめ

Mixtral-8x7BをDatabricks上でExLlamaV2を使って動作させてみました。
性能とのトレードオフになりますが、かなり量子化すればGPU1枚でも動作します。

また、体感ではありますが、確かに13B程度のモデルと同等の生成速度でした。

Mixtralは既に様々なファインチューニングモデルがでてきており、それらの日本語感も今後試してみたいと思います。
これを基盤モデルとした日本語特化LLMをどこかが出してくれないかなー。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0