pytorchなるものがこの世の中にはあるらしいのだが、rustからでも使えるらしいのでちょっと使ってみた。
サンプルがMNISTとかばっかりだったので、もっとシンプルな実装の例の参考になると良いな。
環境
Lubuntu 24.04 on VMWare
導入
rustをインストール。
cargo init PROJECT_NAME
する。
tch-rs
crateの追加。v0.16.0。お気楽にやるためにfeatures=["download-libtorch"]
。libtorchはv2.3.0。2024年07月末時点。
サンプル
中間層1層、判別関数はsigmoid、最適化はAdam、ミニバッチあり。
main.rs
use tch::{nn, nn::{Module, OptimizerConfig, VarStore}};
use tch::{Device, data::Iter2, Kind, Tensor};
const INPUTSIZE : i64 = 10; // 入力
const HIDDENSIZE : i64 = 64; // 中間層
const OUTPUTSIZE : i64 = 4; // 出力
const ETA : f64 = 0.01; // 学習率
const MINIBATCH : i64 = 16; // ミニバッチ数
// ネットワーク構成
// 重みは乱数で初期化される。
fn net(vs : &nn::Path) -> impl Module {
nn::seq()
.add(nn::linear(vs / "layer1",
INPUTSIZE, // 入力数
HIDDENSIZE, // 中間層
Default::default()))
.add_fn(|xs| xs.sigmoid()) // シグモイド
.add(nn::linear(vs / "layer2",
HIDDENSIZE, // 中間層
OUTPUTSIZE, // 出力数
Default::default()))
}
// 教師データセット数x入力数の教師データを
// 1次元として読み込み
fn loadinputdata() -> Vec<f32> {
let ret = Vec::new();
// 入力(特徴)データを読み込む。
ret
}
// 教師データセット数x出力数の正解の推論結果の教師データを
// 1次元として読み込み
fn loadtargetdata() -> Vec<f32> {
let ret = Vec::new();
// 出力(正解)データを読み込む。
ret
}
fn main() -> Result<(), tch::TchError> {
let inputdata : Vec<f32> = loadinputdata();
let targetdata : Vec<f32> = loadtargetdata();
// 1次元のデータをデータセット数x入力数としてTensorに変換
let input = tch::Tensor::from_slice(&inputdata)
.view(inputdata.len() / INPUTSIZE, INPUTSIZE));
println!("input : {} {:?}", input.dim(), input.size());
// 1次元のデータをデータセット数x出力数としてTensorに変換
let target = tch::Tensor::from_slice(&targetdata)
.view(targetdata.len() / OUTPUTSIZE, OUTPUTSIZE));
println!("target: {} {:?}", target.dim(), target.size());
let mut vs = VarStore::new(Device::Cpu); // CPUで計算
let nnet = net(&vs.root());
if ファイルの値を読み込みたい {
vs.load(ファイル名).unwrap();
}
let mut optm = nn::Adam::default().build(&vs, ETA)?; // Adam
for (key, t) in vs.variables().iter_mut() { // 例:各WEIGHTのサイズの確認
println!("{key}:{:?}", t.size());
}
for ep in 0..100 { // 100世代
let mut dataset = Iter2::new(&input, &target, MINIBATCH);
let dataset = dataset.shuffle(); // シャッフル
for (xs, ys) in dataset {
let loss = nnet.forward(&xs).mse_loss(&ys, tch::Reduction::Mean);
optm.backward_step(&loss);
}
print!("epoch:{ep} ");
std::io::stdout().flush().unwrap();
}
println!("");
// ファイルに保存
vs.save("weight.safetensors");
Ok(())
}
感想
100行程度のコードで簡単なNNの学習が出来てしまってびっくりというかあっけないというか、Adamとか関数呼ぶだけで出来ちゃうもんセコすぎ、そりゃみんなpytorch使いますわ。
参考