LoginSignup
1
0

More than 1 year has passed since last update.

RustでBERT!?rust-bertで遊んでみた (後半)

Last updated at Posted at 2022-03-07

前回の記事はこちらです。
Rust用に変換したGPT-2の重みを使って、実際にTextGenerationタスクをしてみました。

※使用環境などは前回と同じです。

プロジェクトの作成

Rustではcargoコマンドでプロジェクトの作成や実行、公開などができるらしく、すごく便利だなぁと感動しました。
プロジェクトを作成します!

bash
cargo new gpt2_test

すると、以下の構成になっていると思います

workspace
    ├── gpt2/
    ├── rust-bert/
    └── gpt2_test/
            ├──Cargo.toml
            └──src/
                └──main.rs

Cargo.tomlの編集

Cargo.tomlにはプロジェクトの情報や、依存関係を記述します。
今回は2つのクレートを使っているので、[dependencies]の下に追記します。

./Cargo.toml
rust-bert = "0.17.0"
anyhow = "1.0.55"

※ 各クレートのバージョンやドキュメントはPythonでいうPyPIの、crates.ioというところから確認できます。

main.rsの編集

15行目あたりの、<モデルのパス>を、前回Rust用に変換したモデルの場所に置き換えてくだい。
(Rust始めて1ヶ月もたたない初心者なので、文法おかしいところがあると思いますが、動くので許してください...)

./src/main.rs
use std::path::PathBuf;
use rust_bert::resources::LocalResource;
use rust_bert::resources::Resource::Local;
use rust_bert::pipelines::text_generation::{ TextGenerationConfig, TextGenerationModel };
use rust_bert::resources::Resource;

fn input() -> String {
    let mut text = String::new();
    println!("Input: ");
    std::io::stdin().read_line(&mut text).unwrap();
    return text.trim().to_string();
}

fn get_resource(item: String) -> Resource {
    let mut model_dir = PathBuf::from("<モデルのパス>/gpt2/");
    model_dir.push(&item);
    println!("{:?}", model_dir);
    let resource = Local(LocalResource{
        local_path: model_dir,
    });
    return resource;
}

fn main() -> anyhow::Result<()> {

    let model_resource = get_resource(String::from("rust_model.ot"));
    let vocab_resource = get_resource(String::from("vocab.json")); 
    let config_resource = get_resource(String::from("config.json"));
    let merges_resource = get_resource(String::from("merges.txt"));

    // configの作成
    let generate_config = TextGenerationConfig {     
        model_type: rust_bert::pipelines::common::ModelType::GPT2,
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource,

        // パラメーター調整
        repetition_penalty: 1.6,
        max_length: 30,
        do_sample: false,
        num_beams: 1,
        temperature: 1.0,
        ..Default::default()
    };
    
    // 上のconfigからモデル作成
    let model = TextGenerationModel::new(generate_config)?;
    //model.set_device(Device::cuda_if_available());
    loop {
        let input_text = input();
        // QUITで終了できるように
        if input_text == "QUIT" { break; }
        // 時間測定スタート
        let start = std::time::Instant::now();
        println!("Generating...");
        // 推論
        let output = model.generate(&[input_text], None);

        for sentence in output {
            println!("「{:?}」", sentence);
        }
        // 時間測定。差分を取る
        let stop = std::time::Instant::now();
        println!("<Time: {:.3}s>", (stop.duration_since(start).as_millis() as f64) / 1000.0);
        println!("\n");
    }
    Ok(())
}

実行!!

gpt2_test下で実行します。

bash
cargo run

cargo runは、『コードをコンパイルして実行する』というお得なコマンドらしいです。すごい。
※コンパイルだけしたい場合は、cargo buildでいけます。

初回はコンパイルと、外部クレートのダウンロードがあるため、少々時間がかかります。
次回以降は、target/debug/に実行ファイルが生成されているので、

bash
./target/debug/gpt2_test

でサクッと実行できます!

実行結果

「I am a」が私の入力です。
GPT-2は文の最初を入力すると、次に続く単語を順に推論していきます。
Screen Shot 2022-03-07 at 19.08.20.png

推論時間:0.495秒

感想

rust-bertについて、GPT-2によるTextGenerationタスクを試してみました。
他にもQuestion AnsweringTranslationも用意されているようです!
rust-bert 公式ドキュメント

日本語モデルでの生成はまだ上手く行っていないので、できたらまた記事にしようと思います。
環境構築も(libtorchを除けば)簡単だと思うので、ぜひ興味もって試してくれたら嬉しいです。

URLなど

crates.io
rust_bert
Rust 日本語ドキュメント

pythonで学習したDNNモデルをC++から利用する(PyTorch & libtorch版)
※libtorchのインストールで参考にさせていただきました。ありがとうございましたmm

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0