candle とは?
rust で使えるミニマリスト機械学習フレームワークです。(Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use.) 1
低レベルな逆伝播を含むテンソルの計算ができる一方、whisper や LLaMAのような高レベルな機能もある程度は備えているようです。
今回は elyza を試してみます。
手順
- 環境構築
- モデルをダウンロードする
- プロジェクトを作る
- クレートを追加する
- コードを書く
環境構築
今回は docker を使用しました。
ベースイメージはホストマシンにインストールされているドライバに合わせて調整が必要です。
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y \
curl pkg-config build-essential libssl-dev && \
apt-get clean
WORKDIR /home
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
モデルをダウンロードする
git lfs install
git clone https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B
で取得できます。
今回はdockerホストにダウンロードし、(カレントディレクトリ)/elyza3
にバインドマウントしました。
プロジェクトを作る
cargo new (プロジェクト名)
以上
クレートを追加する
Cargo.toml
の [dependencies]
に以下を追加します。
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.8.0", features = ["cuda"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.8.0", features = ["cuda"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.8.0" }
serde_json = "1.0.133"
tokenizers = "0.21.0"
理由は不明ですが、ガイドでは crates.io ではなく github から取得する手順となっています。
コードを書く
サンプル を参考に削ぎ落していくと、以下のようになりました。
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use std::error::Error;
use tokenizers::Tokenizer;
use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig, LlamaEosToks::Single};
fn main() -> Result<(), Box<dyn Error>> {
let device = Device::new_cuda(0)?;
let dtype = DType::BF16;
let prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
あなたは日本の交通事情・住宅事情に詳しいアシスタントです。<|eot_id|><|start_header_id|>user<|end_header_id|>
三軒茶屋駅に通勤するのにおすすめの駅を3つ挙げてください。<|eot_id|><|start_header_id|>assistant<|end_header_id|>";
println!("configure start");
let tokenizer =
Tokenizer::from_file("elyza3/tokenizer.json").expect("failed to create tokenizer");
let config = std::fs::read("elyza3/config.json").expect("failed to read config");
let config: LlamaConfig = serde_json::from_slice(&config).expect("failed to read config");
let config = config.into_config(false);
let eos_token_id = config.eos_token_id.clone();
let mut cache = model::Cache::new(true, dtype, &config, &device)?;
println!("configure end");
println!("load weight start");
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
&[
"elyza3/model-00001-of-00004.safetensors",
"elyza3/model-00002-of-00004.safetensors",
"elyza3/model-00003-of-00004.safetensors",
"elyza3/model-00004-of-00004.safetensors",
],
dtype,
&device,
)?
};
let llama = Llama::load(vb, &config)?;
println!("load weight end");
println!("inference start");
let mut tokens = tokenizer
.encode(prompt, true)
.expect("failed to tokenize")
.get_ids()
.to_vec();
let mut logits_processor =
LogitsProcessor::from_sampling(42, Sampling::All { temperature: 0.8 });
let mut index_pos = 0;
for index in 0..1000 {
let context_size = if index == 0 {
tokens.len()
} else {
1
};
let context = &tokens[tokens.len() - context_size..];
let input = Tensor::new(context, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos, &mut cache)?;
let logits = logits.squeeze(0)?;
index_pos += context.len();
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
match eos_token_id {
Some(Single(eos_tok_id)) => {
if next_token == eos_tok_id {
break;
}
}
_ => unimplemented!(),
}
}
println!("inference end");
let text = tokenizer
.decode(&tokens, false)
.unwrap_or_else(|x| format!("{}", x));
println!("{text}");
Ok(())
}
残念ながら、数行でサクッと、とは行かないようです。これがミニマリストということでしょうか。
実行
cargo run
<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>
あなたは日本の交通事情・住宅事情に詳しいアシスタントです。<|eot_id|><|start_header_id|>user<|end_header_id|>
三軒茶屋駅に通勤するのにおすすめの駅を3つ挙げてください。<|eot_id|><|start_header_id|>assistant<|end_header_id|>三軒茶屋駅周辺で通勤通学するのにおすすめの駅を3つ挙げます。
1. 渋谷駅: 渋谷駅は三軒茶屋駅から電車で10分強の距離にあり、JR山手線、埼京線、湘南新宿ライン、東急東横線、京王井の頭線、田園都市線が通っています。渋谷駅周辺には多くの企業や商業施設が集中し、通勤や就活などでも非常に便利です。
2. 二子玉川駅: 二子玉川駅は三軒茶屋駅から電車で5分弱の距離にあり、東急田園都市線が通っています。二子玉川駅周辺には再開発が進み、多くの高層マンションや商業施設が建設されています。通勤や買い物などに便利な立地です。
3. 尾山台駅: 尾山台駅は三軒茶屋駅から電車で5分弱の距離にあり、東急大井町線が通っています。尾山台駅周辺は閑静な住宅街が多く、自然も多く残っています。通勤や通学はもちろん、都心部へのアクセスも悪くないです。<|eot_id|>
二子玉川から三軒茶屋までは5分では行けない、など内容は不正確ですが、動いてはいます。
感想
モデルの構築や推論の制御をある程度自分で記述する必要があり、ある程度機械学習に精通している必要があるな、と思いました。
python環境などの不要なものは入れたくない、でも動作をカスタムしたい、という場合には選択肢になりそうです。