1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

N番煎じで東京工業大学のSwallow-MSをEXL2量子化してDatabricksで動かす

Posted at

出てくるときは一気に出てくるなあ。。。

導入

東京工業大学と国立研究開発法人産業技術総合研究所の共同チームがMistral 7B/Mixtral 8x7Bを継続事前学習したLLMであるSwallow-MS/MXを公開しました。
どちらも同規模の日本語LLMの中で最高レベルの性能となっているようです。
また、Mistralのライセンスを継承しており、Apache 2.0ライセンスで公開されています。
※ 現在公開されているSwallow-MS/MXはどちらもベースモデルであり、指示チューニングされたモデルではないことに注意が必要です。

開発にあたっての詳細は下記に記載されております。
非常に詳細かつ丁寧な内容でわかりやすいです。

また、以下のように既に試されている方も多数いらっしゃいますが、N番煎じでこちらも試してみます。

検証はDatabricks on AWSで行いました。
DBRは14.3ML、推論はクラスタタイプg4dn.xlargeです。

モデルはSwallow-MSを試します。
また、今回の推論はExLlamaV2を使ってEXL2形式で量子化して推論します。

Step1. モデルをダウンロード

huggingfaceから以下のモデルをダウンロード。

コードを実行することでUnity CatalogのVolumesに保管します。

%pip install -U transformers accelerate
dbutils.library.restartPython()

from typing import Optional

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        force_download=True,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", recurse=True)

model_id = "tokyotech-llm/Swallow-MS-7b-v0.1"
download_model(model_id)

Step2. モデルをEXL2で量子化

ダウンロードしたモデルをEXL2形式に量子化して保存します。
細かい手順は以下を参照ください。Swallow-MSはsafetensor形式で提供されていますので、そのままコンバートが可能です。

量子化のビット数は6.5bpwで変換しました。また、キャリブレーションデータはデフォルトのものを利用しています。
変換後のファイルサイズは5.74GBでした。

Step3. 推論準備

変換したモデルをロードし、推論準備します。

まずはExLlamaV2に必要なパッケージをインストール。

%pip install -U -qq transformers accelerate "exllamav2>=0.0.15" langchain

dbutils.library.restartPython()

こちらのカスタムクラスを再利用し、変換したモデルファイルをロード。

from exllamav2_chat import ChatExllamaV2Model

model_path = "/Volumes/training/llm/model_snapshots/models--tokyotech-llm--Swallow-MS-7b-v0.1-6.5bpw-h6-exl2/"

model = ChatExllamaV2Model.from_model_dir(
    model_path,
    cache_max_seq_len=4096,
    system_message_template="{}",    
    human_message_template="{}",
    ai_message_template="{}",
    temperature=0,
    top_p=0.0001,
    max_new_tokens=512,
    repetition_penalty = 1.15,
    low_memory=True,
    cache_4bit=True,
)

Step4. 推論

では実際に推論してみます。

量子化しているため、オリジナルモデルより性能が劣化していると思われます。
その上での推論結果だと認識ください。

まずはhuggingface上のサンプル通りの文字列を与えて生成。

model.max_new_tokens=100
result = model.invoke("東京工業大学の主なキャンパスは、")

print(result.content)
出力
大岡山とすずかけ台にあります。
- 大岡山キャンパス:理工学部・工学院・生命理工学院などがあるメインキャンパスです。
- すずかけ台キャンパス:情報理工学院や環境社会理工学院などがあります。
- 田町キャンパス:大学院生向けの研究施設で、主に修士課程以上の学生が利用しています。
- すずかけ台キャンパス:情報理工学院や

あってそうですね。

では、ここからオリジナル。

model.max_new_tokens=100
result = model.invoke("大阪の特徴的な観光名所を3か所、Markdownの表形式で名前と特徴を出力すると、")

print(result.content)
出力
以下のようになります。
- 通天閣:高さ108mの展望台があるタワー。
- 道頓堀:川沿いに飲食店や劇場などが立ち並ぶ繁華街。
- あべのハルカス:日本一高いビル。
このように、Markdownの表は、データを整理してわかりやすく表示できるため、情報をまとめて伝える際に便利です。
また、Markdownの

Markdownの表ではないのが残念ですが、内容は妥当そうです。
(あべのハルカスはもう日本一ではないはずですが)

model.max_new_tokens=200
result = model.invoke("Databricksの特長をJSON形式で出力すると、")

print(result.content)
出力
その中には「Spark」や「Delta Lake」などが含まれている。
これらのキーワードから、データレイクプラットフォームとしてのDatabricksの特徴を理解できるだろう。
また、このような機能を提供しながらも、クラウド上で実行されるため、インフラストラクチャ管理の手間がかからないことも大きなメリットだ。
さらに、Databricksはオープンソースソフトウェア(OSS)であるApache Sparkをベースに開発されたサービスであり、Apache Sparkエコシステムのコミュニティとも連携している。
そのため、Apache Sparkの最新情報をいち早くキャッチアップできたり、Apache Sparkの開発者たちと交流したりする機会も多い。
こうした点から、Databricksはデータエンジニアリングの世界で注目度の高いツールとなっており、今後ますます活用される可能性が高い。

こちらもJSON形式では出力されませんが、内容は正しいものが出力されています。

model.max_new_tokens=100
result = model.invoke("葬送のフリーレンとは、")

print(result.content)
出力
2015年から週刊少年サンデーで連載されている漫画です。
主人公のエルフ・フリーレンが人間たちと旅をしながら、仲間や自分自身について学んでいく物語です。
この作品は、ファンタジー要素だけでなく、深いテーマ性も持ち合わせています。
特に、死生観や人間関係など、現代社会でも問題視されるようなテーマが描かれており、読者に深く

連載開始時期は間違っていますが(正しくは2020年)ほかは妥当そう。日本語Wikipediaなどを利用して学習しているため、このあたりの知識は保有していますね。

model.max_new_tokens=100
result = model.invoke("魔法少女まどか☆マギカで最もかわいい人の名前をあげると、それは")

print(result.content)
出力
鹿目まどかです。
彼女が登場するたびに、視聴者たちはそのかわいさに魅了されてしまいます。
そんな鹿目まどかの魅力を徹底的に解説していきたいと思います!
目次
- 1 鹿目まどか(かなめまどか)ってどんなキャラ?
- 2 鹿目まどかの声優さんは誰?
- 3 鹿目まどか

ノーコメント。

その他

Swallow-MXもEXL2での量子化を試みたのですが、こちらはA10G一枚で推論させたくて3.0bpwで量子化した結果、まともに生成してくれなくなりました。4bit以上のサイズで量子化するか、キャリブレーションデータセットを工夫するなど必要がありそうです。
(もしくは素直にGGUFで量子化したモデルを利用するか)

指示チューニングしたSwallow MXが出たらリベンジしたいと思います。

まとめ

Swallow MSを試してみました。
個人的に7Bサイズの日本語LLMとしてはMistralをベース使うのがいいんじゃないかなと勘で思っていて、使いやすい日本語モデルが出てきたなと思います。ライセンスも商用利用面考えても使いやすいし。
今回はうまくいきませんでしたが、Mixtralを使った日本語LLMも期待をしていて、指示チューニングしたモデルの公開も楽しみです。(自分でやれよって話ですが)

そして翌日にElyza社から70bのLLMが出てくるんだから、この分野の競争はまだまだ続きそうですね。

なお、デモを触ってみている感じ、確かに非常によい回答生成や指示への追随が見られるので面白いなと思います。
(ただし、現時点でこのモデルがオープンに公開されるかは不明)

いろんな動きがまだまだ起きてきそうなので、今後も楽しみですね。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?