Rust用の機械学習フレームワーク Candle を使ってローカル LLM を実行してみました。
Candle は Hugging Face がオープンソース化しているライブラリの 1つで、Transformers や Diffuers と同等の処理を Rust で実装できます。
ここでは Microsoft の Phi-1.5 を利用しましたが、他の LLM も概ね同じように処理できそうでした。1
モデルの取得
hf-hub を使う事で Hugging Face からのモデルダウンロードも Rust で実装できますが2、ここでは対象ファイルを手動ダウンロードしました。
具体的には、https://huggingface.co/microsoft/phi-1_5/tree/main から以下のファイルをダウンロードしてローカルへ保存しておきます。
- config.json
- model.safetensors
- tokenizer.json
ちなみに、LLM によっては safetensors ファイルが複数に分かれているケースもあります。
実装
Candle で LLM を処理する流れは概ねこのようになります。
- Tokenizer でプロンプトを
encode
処理してトークン化 - トークンを Model で
forward
処理- 初回だけプロンプトの全内容を入力にする
- 2回目以降は前回選んだトークンを入力にする
- forward 結果を LogitsProcessor で
sample
処理してトークンを選出 - 終端トークンになるか既定回数になるまで
2
以降を繰り返す
candle_transformers::models
に各モデルの Config
や Model
が定義されているので該当するものを使って処理を行います。
例えば、Phi-1.5 であれば candle_transformers::models::phi::{Config, Model}
を使用します。
2
の forward 処理では、初回だけプロンプトの全内容を入力として使い、2回目以降は前回選んだトークンを入力に使います。
forward の結果として各トークンの確率値が得られるので LogitsProcessor
でトークンを選出します。3
また、終端トークンの文字列は LLM(厳密には Tokenizer)によって変化します。4
ダウンロードしたファイルを model ディレクトリに配置し、CPU で最小限の処理を行うコードはこのようになりました。5
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::phi::{Config, Model};
use std::env;
use tokenizers::{Result, Tokenizer};
fn main() -> Result<()> {
let config_file = "model/config.json";
let tokenizer_file = "model/tokenizer.json";
let model_file = "model/model.safetensors";
let max_len: usize = 1000;
let temperature = Some(0.8);
let top_p = Some(0.5);
let prompt = env::args().nth(1).ok_or("prompt")?;
let seed = env::args()
.nth(2)
.and_then(|x| x.parse().ok())
.unwrap_or(1);
let device = Device::Cpu;
let config_str = std::fs::read_to_string(config_file)?;
let config: Config = serde_json::from_str(&config_str)?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
};
let mut model = Model::new(&config, vb)?;
let mut logits_proc = LogitsProcessor::new(seed, temperature, top_p);
let tokenizer = Tokenizer::from_file(tokenizer_file)?;
// 1. プロンプトのトークン化
let tokens = tokenizer.encode(prompt, true)?;
if tokens.is_empty() {
return Err("empty token".into());
}
// 終端トークンの取得
let eos_token = tokenizer
.token_to_id("<|endoftext|>")
.ok_or("not found eos token")?;
let mut tokens = tokens.get_ids().to_vec();
let mut output_tokens = vec![];
for _ in 0..max_len {
let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
// 2. forward 処理
let logits = model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
// 3. トークンの選定
let next_token = logits_proc.sample(&logits)?;
if next_token == eos_token {
break;
}
output_tokens.push(next_token);
tokens = vec![next_token];
}
// 結果出力
print_tokens(&tokenizer, &output_tokens);
Ok(())
}
fn print_tokens(tokenizer: &Tokenizer, tokens: &Vec<u32>) {
if !tokens.is_empty() {
if let Ok(t) = tokenizer.decode(tokens, true) {
print!("{t}")
}
}
}
[dependencies]
candle-core = "0.8.2"
candle-nn = "0.8.2"
candle-transformers = "0.8.2"
serde_json = "1.0.137"
tokenizers = "0.21.0"
なお、同じトークンの繰り返しを防止するには 3
の処理の前に candle_transformers::utils::apply_repeat_penalty
関数を使って繰り返すトークンに対してペナルティを課す方法があります。
動作確認
実行結果はこのようになりました。
文章の内容はともかく一応は処理できているようです。
$ cargo run --release "Write a shopping cart domain model in Domain-Driven Design"
(DDD) that uses a class-based view to display a list of products.
```python
from django.views.generic import ListView
from .models import Product
class ProductListView(ListView):
model = Product
template_name = 'product_list.html'
```
Exercise 3:
Write a shopping cart domain model in DDD that uses a class-based view to display a list of products with their prices.
```python
from django.views.generic import ListView
from .models import Product, Price
class ProductListView(ListView):
model = Product
template_name = 'product_list.html'
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context['products'] = Product.objects.all()
context['prices'] = Price.objects.all()
return context
```
Exercise 4:
Write a shopping cart domain model in DDD that uses a class-based view to display a list of products with their prices and descriptions.
```python
from django.views.generic import ListView
from .models import Product, Price, Description
class ProductListView(ListView):
model = Product
template_name = 'product_list.html'
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context['products'] = Product.objects.all()
context['prices'] = Price.objects.all()
context['descriptions'] = Description.objects.all()
return context
```
Exercise 5:
Write a shopping cart domain model in DDD that uses a class-based view to display a list of products with their prices and descriptions, sorted by price in ascending order.
```python
from django.views.generic import ListView
from .models import Product, Price, Description
class ProductListView(ListView):
model = Product
template_name = 'product_list.html'
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context['products'] = Product.objects.all()
context['prices'] = Price.objects.all()
context['descriptions'] = Description.objects.all()
context['sorted_products'] = sorted(context['products'], key=lambda x: x.price)
return context
```
-
例えば、gemma2 では
forward
の引数が 1つ増えてこれまでに処理したトークン数を渡す必要があったものの、基本的な処理内容は同じだった ↩ -
hf-hub によって Transformers や Diffusers と同じディレクトリ構成でモデルをダウンロードできます ↩
-
LogitsProcessor の temperature を None にした場合は常に最大のものを選ぶ
Sampling::ArgMax
が適用されるようです ↩ -
例えば、gemma2 の場合は
<eos>
でした ↩ -
ここでは main 関数の戻り値に tokenizers の Result を使用しました ↩