1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

CandleでローカルLLMを実行する

Posted at

Rust用の機械学習フレームワーク Candle を使ってローカル LLM を実行してみました。

Candle は Hugging Face がオープンソース化しているライブラリの 1つで、TransformersDiffuers と同等の処理を Rust で実装できます。

ここでは Microsoft の Phi-1.5 を利用しましたが、他の LLM も概ね同じように処理できそうでした。1

モデルの取得

hf-hub を使う事で Hugging Face からのモデルダウンロードも Rust で実装できますが2、ここでは対象ファイルを手動ダウンロードしました。

具体的には、https://huggingface.co/microsoft/phi-1_5/tree/main から以下のファイルをダウンロードしてローカルへ保存しておきます。

  • config.json
  • model.safetensors
  • tokenizer.json

ちなみに、LLM によっては safetensors ファイルが複数に分かれているケースもあります。

実装

Candle で LLM を処理する流れは概ねこのようになります。

  1. Tokenizer でプロンプトを encode 処理してトークン化
  2. トークンを Model で forward 処理
    • 初回だけプロンプトの全内容を入力にする
    • 2回目以降は前回選んだトークンを入力にする
  3. forward 結果を LogitsProcessor で sample 処理してトークンを選出
  4. 終端トークンになるか既定回数になるまで 2 以降を繰り返す

candle_transformers::models に各モデルの ConfigModel が定義されているので該当するものを使って処理を行います。

例えば、Phi-1.5 であれば candle_transformers::models::phi::{Config, Model} を使用します。

2 の forward 処理では、初回だけプロンプトの全内容を入力として使い、2回目以降は前回選んだトークンを入力に使います。

forward の結果として各トークンの確率値が得られるので LogitsProcessor でトークンを選出します。3

また、終端トークンの文字列は LLM(厳密には Tokenizer)によって変化します。4

ダウンロードしたファイルを model ディレクトリに配置し、CPU で最小限の処理を行うコードはこのようになりました。5

Phi-1.5による文章生成コード例
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::phi::{Config, Model};
use std::env;
use tokenizers::{Result, Tokenizer};

fn main() -> Result<()> {
    let config_file = "model/config.json";
    let tokenizer_file = "model/tokenizer.json";
    let model_file = "model/model.safetensors";

    let max_len: usize = 1000;
    let temperature = Some(0.8);
    let top_p = Some(0.5);

    let prompt = env::args().nth(1).ok_or("prompt")?;

    let seed = env::args()
        .nth(2)
        .and_then(|x| x.parse().ok())
        .unwrap_or(1);

    let device = Device::Cpu;

    let config_str = std::fs::read_to_string(config_file)?;
    let config: Config = serde_json::from_str(&config_str)?;

    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
    };
    let mut model = Model::new(&config, vb)?;

    let mut logits_proc = LogitsProcessor::new(seed, temperature, top_p);

    let tokenizer = Tokenizer::from_file(tokenizer_file)?;
    // 1. プロンプトのトークン化
    let tokens = tokenizer.encode(prompt, true)?;

    if tokens.is_empty() {
        return Err("empty token".into());
    }
    // 終端トークンの取得
    let eos_token = tokenizer
        .token_to_id("<|endoftext|>")
        .ok_or("not found eos token")?;

    let mut tokens = tokens.get_ids().to_vec();
    let mut output_tokens = vec![];

    for _ in 0..max_len {
        let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        // 2. forward 処理
        let logits = model.forward(&input)?;

        let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
        // 3. トークンの選定
        let next_token = logits_proc.sample(&logits)?;

        if next_token == eos_token {
            break;
        }

        output_tokens.push(next_token);
        tokens = vec![next_token];
    }
    // 結果出力
    print_tokens(&tokenizer, &output_tokens);

    Ok(())
}

fn print_tokens(tokenizer: &Tokenizer, tokens: &Vec<u32>) {
    if !tokens.is_empty() {
        if let Ok(t) = tokenizer.decode(tokens, true) {
            print!("{t}")
        }
    }
}
Cargo.tomlの依存定義
[dependencies]
candle-core = "0.8.2"
candle-nn = "0.8.2"
candle-transformers = "0.8.2"
serde_json = "1.0.137"
tokenizers = "0.21.0"

なお、同じトークンの繰り返しを防止するには 3 の処理の前に candle_transformers::utils::apply_repeat_penalty 関数を使って繰り返すトークンに対してペナルティを課す方法があります。

動作確認

実行結果はこのようになりました。
文章の内容はともかく一応は処理できているようです。

実行例
$ cargo run --release "Write a shopping cart domain model in Domain-Driven Design"
出力結果例
 (DDD) that uses a class-based view to display a list of products.

```python
from django.views.generic import ListView
from .models import Product

class ProductListView(ListView):
    model = Product
    template_name = 'product_list.html'
```

Exercise 3:
Write a shopping cart domain model in DDD that uses a class-based view to display a list of products with their prices.

```python
from django.views.generic import ListView
from .models import Product, Price

class ProductListView(ListView):
    model = Product
    template_name = 'product_list.html'

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        context['products'] = Product.objects.all()
        context['prices'] = Price.objects.all()
        return context
```

Exercise 4:
Write a shopping cart domain model in DDD that uses a class-based view to display a list of products with their prices and descriptions.

```python
from django.views.generic import ListView
from .models import Product, Price, Description

class ProductListView(ListView):
    model = Product
    template_name = 'product_list.html'

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        context['products'] = Product.objects.all()
        context['prices'] = Price.objects.all()
        context['descriptions'] = Description.objects.all()
        return context
```

Exercise 5:
Write a shopping cart domain model in DDD that uses a class-based view to display a list of products with their prices and descriptions, sorted by price in ascending order.

```python
from django.views.generic import ListView
from .models import Product, Price, Description

class ProductListView(ListView):
    model = Product
    template_name = 'product_list.html'

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        context['products'] = Product.objects.all()
        context['prices'] = Price.objects.all()
        context['descriptions'] = Description.objects.all()
        context['sorted_products'] = sorted(context['products'], key=lambda x: x.price)
        return context
```
  1. 例えば、gemma2 では forward の引数が 1つ増えてこれまでに処理したトークン数を渡す必要があったものの、基本的な処理内容は同じだった

  2. hf-hub によって Transformers や Diffusers と同じディレクトリ構成でモデルをダウンロードできます

  3. LogitsProcessor の temperature を None にした場合は常に最大のものを選ぶ Sampling::ArgMax が適用されるようです

  4. 例えば、gemma2 の場合は <eos> でした

  5. ここでは main 関数の戻り値に tokenizers の Result を使用しました

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?