4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

candle とは?

rust で使えるミニマリスト機械学習フレームワークです。(Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use.) 1
低レベルな逆伝播を含むテンソルの計算ができる一方、whisper や LLaMAのような高レベルな機能もある程度は備えているようです。

今回は elyza を試してみます。

手順

  • 環境構築
  • モデルをダウンロードする
  • プロジェクトを作る
    • クレートを追加する
    • コードを書く

環境構築

今回は docker を使用しました。
ベースイメージはホストマシンにインストールされているドライバに合わせて調整が必要です。

FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

RUN apt-get update                                    && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y    \
    curl pkg-config build-essential libssl-dev        && \
    apt-get clean

WORKDIR /home
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y

モデルをダウンロードする

git lfs install
git clone https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B

で取得できます。
今回はdockerホストにダウンロードし、(カレントディレクトリ)/elyza3 にバインドマウントしました。

プロジェクトを作る

cargo new (プロジェクト名)

以上

クレートを追加する

Cargo.toml[dependencies] に以下を追加します。

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.8.0", features = ["cuda"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.8.0", features = ["cuda"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.8.0" }
serde_json = "1.0.133"
tokenizers = "0.21.0"

理由は不明ですが、ガイドでは crates.io ではなく github から取得する手順となっています。

コードを書く

サンプル を参考に削ぎ落していくと、以下のようになりました。

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use std::error::Error;
use tokenizers::Tokenizer;

use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig, LlamaEosToks::Single};

fn main() -> Result<(), Box<dyn Error>> {
    let device = Device::new_cuda(0)?;
    let dtype = DType::BF16;

    let prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
あなたは日本の交通事情・住宅事情に詳しいアシスタントです。<|eot_id|><|start_header_id|>user<|end_header_id|>
三軒茶屋駅に通勤するのにおすすめの駅を3つ挙げてください。<|eot_id|><|start_header_id|>assistant<|end_header_id|>";

    println!("configure start");

    let tokenizer =
        Tokenizer::from_file("elyza3/tokenizer.json").expect("failed to create tokenizer");

    let config = std::fs::read("elyza3/config.json").expect("failed to read config");
    let config: LlamaConfig = serde_json::from_slice(&config).expect("failed to read config");
    let config = config.into_config(false);
    let eos_token_id = config.eos_token_id.clone();

    let mut cache = model::Cache::new(true, dtype, &config, &device)?;

    println!("configure end");

    println!("load weight start");

    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &[
                "elyza3/model-00001-of-00004.safetensors",
                "elyza3/model-00002-of-00004.safetensors",
                "elyza3/model-00003-of-00004.safetensors",
                "elyza3/model-00004-of-00004.safetensors",
            ],
            dtype,
            &device,
        )?
    };
    let llama = Llama::load(vb, &config)?;

    println!("load weight end");

    println!("inference start");

    let mut tokens = tokenizer
        .encode(prompt, true)
        .expect("failed to tokenize")
        .get_ids()
        .to_vec();

    let mut logits_processor =
        LogitsProcessor::from_sampling(42, Sampling::All { temperature: 0.8 });

    let mut index_pos = 0;
    for index in 0..1000 {
        let context_size = if index == 0 {
            tokens.len()
        } else {
            1
        };
        let context = &tokens[tokens.len() - context_size..];
        
        let input = Tensor::new(context, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, index_pos, &mut cache)?;
        let logits = logits.squeeze(0)?;

        index_pos += context.len();

        let next_token = logits_processor.sample(&logits)?;

        tokens.push(next_token);

        match eos_token_id {
            Some(Single(eos_tok_id)) => {
                if next_token == eos_tok_id {
                    break;
                }
            }
            _ => unimplemented!(),
        }
    }

    println!("inference end");

    let text = tokenizer
        .decode(&tokens, false)
        .unwrap_or_else(|x| format!("{}", x));
    println!("{text}");

    Ok(())
}

残念ながら、数行でサクッと、とは行かないようです。これがミニマリストということでしょうか。

実行

cargo run

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
あなたは日本の交通事情・住宅事情に詳しいアシスタントです。<|eot_id|><|start_header_id|>user<|end_header_id|>
三軒茶屋駅に通勤するのにおすすめの駅を3つ挙げてください。<|eot_id|><|start_header_id|>assistant<|end_header_id|>三軒茶屋駅周辺で通勤通学するのにおすすめの駅を3つ挙げます。

1. 渋谷駅: 渋谷駅は三軒茶屋駅から電車で10分強の距離にあり、JR山手線、埼京線、湘南新宿ライン、東急東横線、京王井の頭線、田園都市線が通っています。渋谷駅周辺には多くの企業や商業施設が集中し、通勤や就活などでも非常に便利です。

2. 二子玉川駅: 二子玉川駅は三軒茶屋駅から電車で5分弱の距離にあり、東急田園都市線が通っています。二子玉川駅周辺には再開発が進み、多くの高層マンションや商業施設が建設されています。通勤や買い物などに便利な立地です。

3. 尾山台駅: 尾山台駅は三軒茶屋駅から電車で5分弱の距離にあり、東急大井町線が通っています。尾山台駅周辺は閑静な住宅街が多く、自然も多く残っています。通勤や通学はもちろん、都心部へのアクセスも悪くないです。<|eot_id|>

二子玉川から三軒茶屋までは5分では行けない、など内容は不正確ですが、動いてはいます。

感想

モデルの構築や推論の制御をある程度自分で記述する必要があり、ある程度機械学習に精通している必要があるな、と思いました。
python環境などの不要なものは入れたくない、でも動作をカスタムしたい、という場合には選択肢になりそうです。

  1. https://github.com/huggingface/candle/blob/145aa7193c4e658b184f52706574cc9f115e4674/README.md

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?