1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

rustでpytorchしてみた

Posted at

pytorchなるものがこの世の中にはあるらしいのだが、rustからでも使えるらしいのでちょっと使ってみた。
サンプルがMNISTとかばっかりだったので、もっとシンプルな実装の例の参考になると良いな。

環境

Lubuntu 24.04 on VMWare

導入

rustをインストール。
cargo init PROJECT_NAMEする。
tch-rscrateの追加。v0.16.0。お気楽にやるためにfeatures=["download-libtorch"]。libtorchはv2.3.0。2024年07月末時点。

サンプル

中間層1層、判別関数はsigmoid、最適化はAdam、ミニバッチあり。

main.rs
use tch::{nn, nn::{Module, OptimizerConfig, VarStore}};
use tch::{Device, data::Iter2, Kind, Tensor};

const INPUTSIZE : i64 = 10;  // 入力
const HIDDENSIZE : i64 = 64;  // 中間層
const OUTPUTSIZE : i64 = 4;  // 出力
const ETA : f64 = 0.01;  // 学習率
const MINIBATCH : i64 = 16;  // ミニバッチ数

// ネットワーク構成
// 重みは乱数で初期化される。
fn net(vs : &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "layer1",
            INPUTSIZE,  // 入力数
            HIDDENSIZE,  // 中間層
            Default::default()))
        .add_fn(|xs| xs.sigmoid())  // シグモイド
        .add(nn::linear(vs / "layer2",
            HIDDENSIZE,  // 中間層
            OUTPUTSIZE,  // 出力数
            Default::default()))
}

// 教師データセット数x入力数の教師データを
// 1次元として読み込み
fn loadinputdata() -> Vec<f32> {
    let ret = Vec::new();
    // 入力(特徴)データを読み込む。
    ret
}

// 教師データセット数x出力数の正解の推論結果の教師データを
// 1次元として読み込み
fn loadtargetdata() -> Vec<f32> {
    let ret = Vec::new();
    // 出力(正解)データを読み込む。
    ret
}

fn main() -> Result<(), tch::TchError> {
    let inputdata : Vec<f32> = loadinputdata();
    let targetdata : Vec<f32> = loadtargetdata();

    // 1次元のデータをデータセット数x入力数としてTensorに変換
    let input = tch::Tensor::from_slice(&inputdata)
            .view(inputdata.len() / INPUTSIZE, INPUTSIZE));
    println!("input : {} {:?}", input.dim(), input.size());

    // 1次元のデータをデータセット数x出力数としてTensorに変換
    let target = tch::Tensor::from_slice(&targetdata)
            .view(targetdata.len() / OUTPUTSIZE, OUTPUTSIZE));
    println!("target: {} {:?}", target.dim(), target.size());

    let mut vs = VarStore::new(Device::Cpu);  // CPUで計算
    let nnet = net(&vs.root());

    if ファイルの値を読み込みたい {
        vs.load(ファイル名).unwrap();
    }

    let mut optm = nn::Adam::default().build(&vs, ETA)?;  // Adam
    for (key, t) in vs.variables().iter_mut() {  // 例:各WEIGHTのサイズの確認
        println!("{key}:{:?}", t.size());
    }
    for ep in 0..100 {  // 100世代
        let mut dataset = Iter2::new(&input, &target, MINIBATCH);
        let dataset = dataset.shuffle();  // シャッフル
        for (xs, ys) in dataset {
            let loss = nnet.forward(&xs).mse_loss(&ys, tch::Reduction::Mean);
            optm.backward_step(&loss);
        }
        print!("epoch:{ep} ");
        std::io::stdout().flush().unwrap();
    }
    println!("");

    // ファイルに保存
    vs.save("weight.safetensors");

    Ok(())
}

感想

100行程度のコードで簡単なNNの学習が出来てしまってびっくりというかあっけないというか、Adamとか関数呼ぶだけで出来ちゃうもんセコすぎ、そりゃみんなpytorch使いますわ。

参考

tch-rsでググる

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?