前回の記事はこちらです。
Rust用に変換したGPT-2の重みを使って、実際にTextGenerationタスクをしてみました。
※使用環境などは前回と同じです。
プロジェクトの作成
Rustではcargo
コマンドでプロジェクトの作成や実行、公開などができるらしく、すごく便利だなぁと感動しました。
プロジェクトを作成します!
cargo new gpt2_test
すると、以下の構成になっていると思います
workspace
├── gpt2/
├── rust-bert/
└── gpt2_test/
├──Cargo.toml
└──src/
└──main.rs
Cargo.tomlの編集
Cargo.toml
にはプロジェクトの情報や、依存関係を記述します。
今回は2つのクレートを使っているので、[dependencies]の下に追記します。
rust-bert = "0.17.0"
anyhow = "1.0.55"
※ 各クレートのバージョンやドキュメントはPythonでいうPyPIの、crates.ioというところから確認できます。
main.rsの編集
15行目あたりの、<モデルのパス>
を、前回Rust用に変換したモデルの場所に置き換えてくだい。
(Rust始めて1ヶ月もたたない初心者なので、文法おかしいところがあると思いますが、動くので許してください...)
use std::path::PathBuf;
use rust_bert::resources::LocalResource;
use rust_bert::resources::Resource::Local;
use rust_bert::pipelines::text_generation::{ TextGenerationConfig, TextGenerationModel };
use rust_bert::resources::Resource;
fn input() -> String {
let mut text = String::new();
println!("Input: ");
std::io::stdin().read_line(&mut text).unwrap();
return text.trim().to_string();
}
fn get_resource(item: String) -> Resource {
let mut model_dir = PathBuf::from("<モデルのパス>/gpt2/");
model_dir.push(&item);
println!("{:?}", model_dir);
let resource = Local(LocalResource{
local_path: model_dir,
});
return resource;
}
fn main() -> anyhow::Result<()> {
let model_resource = get_resource(String::from("rust_model.ot"));
let vocab_resource = get_resource(String::from("vocab.json"));
let config_resource = get_resource(String::from("config.json"));
let merges_resource = get_resource(String::from("merges.txt"));
// configの作成
let generate_config = TextGenerationConfig {
model_type: rust_bert::pipelines::common::ModelType::GPT2,
model_resource,
config_resource,
vocab_resource,
merges_resource,
// パラメーター調整
repetition_penalty: 1.6,
max_length: 30,
do_sample: false,
num_beams: 1,
temperature: 1.0,
..Default::default()
};
// 上のconfigからモデル作成
let model = TextGenerationModel::new(generate_config)?;
//model.set_device(Device::cuda_if_available());
loop {
let input_text = input();
// QUITで終了できるように
if input_text == "QUIT" { break; }
// 時間測定スタート
let start = std::time::Instant::now();
println!("Generating...");
// 推論
let output = model.generate(&[input_text], None);
for sentence in output {
println!("「{:?}」", sentence);
}
// 時間測定。差分を取る
let stop = std::time::Instant::now();
println!("<Time: {:.3}s>", (stop.duration_since(start).as_millis() as f64) / 1000.0);
println!("\n");
}
Ok(())
}
実行!!
gpt2_test
下で実行します。
cargo run
cargo run
は、『コードをコンパイルして実行する』というお得なコマンドらしいです。すごい。
※コンパイルだけしたい場合は、cargo build
でいけます。
初回はコンパイルと、外部クレートのダウンロードがあるため、少々時間がかかります。
次回以降は、target/debug/
に実行ファイルが生成されているので、
./target/debug/gpt2_test
でサクッと実行できます!
実行結果
「I am a」が私の入力です。
GPT-2は文の最初を入力すると、次に続く単語を順に推論していきます。
推論時間:0.495秒
感想
rust-bertについて、GPT-2によるTextGenerationタスクを試してみました。
他にもQuestion Answering
やTranslation
も用意されているようです!
rust-bert 公式ドキュメント
日本語モデルでの生成はまだ上手く行っていないので、できたらまた記事にしようと思います。
環境構築も(libtorchを除けば)簡単だと思うので、ぜひ興味もって試してくれたら嬉しいです。
URLなど
crates.io
rust_bert
Rust 日本語ドキュメント
pythonで学習したDNNモデルをC++から利用する(PyTorch & libtorch版)
※libtorchのインストールで参考にさせていただきました。ありがとうございましたmm