4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Rustオンリーで簡素なニューラルネットワークを構成

Last updated at Posted at 2020-02-19

表題の通り、Rustオンリーで簡素なニューラルネットワークを構築してみました
下記のソースコードではXORを学習するようになっています。

連鎖律を用いた誤差逆伝播などは行っておらず、損失関数より算出された誤差を元に
weightをちまちま更新していく作りになっています(なので遅いです…)

#ソースコード
外部ライブラリ等は必要なく、ベタ貼り後rustcなりcargoなりで
動きます


use std::ops::Fn;

// general
fn relu(x : &Vec<f64>) -> Vec<f64>{
    let mut ret = x.to_vec();
    for i in &mut ret{
        *i = if 0.0 < *i {*i} else {0.0};
    }
    ret
}

fn num_diff<F>(f:&F,w:&mut Vec<Vec<Vec<f64>>>,tgt_p1:usize,tgt_p2:usize) -> Vec<f64> where F : (Fn(&Vec<Vec<Vec<f64>>>) -> f64){
    let h = 1e-04;
    let mut grad = vec![0.0_f64;w[tgt_p1][tgt_p2].len()];
    
    for i in 0..w[tgt_p1][tgt_p2].len(){
        let tmp = w[tgt_p1][tgt_p2][i];

        w[tgt_p1][tgt_p2][i] = tmp + h;
        let fw = f(&w);

        w[tgt_p1][tgt_p2][i] = tmp - h;
        let bk = f(&w);

        grad[i] = (fw - bk) / (2.0 * h);
        w[tgt_p1][tgt_p2][i] = tmp;
    }

    grad
}

fn mean_squared_error(y:&Vec<f64>,t:&Vec<f64>) -> f64{
    if y.len() != t.len(){
        panic!("unmatch dim. y:{},t:{}",y.len(),t.len());
    }
    
    let mut ret= 0.0_f64;
    for i in 0..y.len(){
        ret += (y[i] - t[i]).powf(2.0);
    }

    let ret_val = ret / 2.0;

    return ret_val;
}

fn dot(x:&Vec<f64>,w:&Vec<Vec<f64>>) -> Vec<f64>{
    if w.len() <= 0 || x.len() != w.len(){
        panic!("unmatch dim. x:{},y:{}",x.len(),w.len());
    }

    let mut ret = vec![0.0_f64;w[0].len()];

    for i in 0..w[0].len(){
        for j in 0..x.len(){
            ret[i] += x[j] * w[j][i];
        }
    }

    ret
}

fn predict(x:&Vec<f64>,w:&Vec<Vec<Vec<f64>>>) -> Vec<f64>{
    let x1 = dot(x,&w[0]);
    let z1 = relu(&x1);
    let x2 = dot(&z1,&w[1]);
    let z2 = relu(&x2);
    let x3 = dot(&z2,&w[2]);

    x3
}

fn backward(x:&Vec<f64>,t:&Vec<f64>,w:&mut Vec<Vec<Vec<f64>>>) -> Vec<Vec<Vec<f64>>>{
    let lt = 0.05;
    let f = |wh:&Vec<Vec<Vec<f64>>>| -> f64{
        let pr = predict(x,wh);
        let ret = mean_squared_error(&pr,t);
        return ret;
    };

    let mut diff_ret : Vec<Vec<Vec<f64>>> = Vec::new();

    for i in 0..w.len(){
        diff_ret.push(vec![Vec::new();w[i].len()]);
        for j in 0..w[i].len(){
            diff_ret[i][j] = num_diff(&f,w,i,j);
            for k in 0..diff_ret[i][j].len(){
                w[i][j][k] -= lt * diff_ret[i][j][k];
            }
        }
    }
    

    diff_ret
}

fn forward(x:&Vec<f64>,w:&Vec<Vec<Vec<f64>>>) -> Vec<f64>{
    predict(x,w)
}

fn training(input : &Vec<Vec<f64>>,answer : &Vec<Vec<f64>>,weight : &mut Vec<Vec<Vec<f64>>>,epoch :usize,view_status :usize){
    let epoch_val = epoch + 1;
    let view_status_val = if view_status <= 0 {1} else {view_status};
    for i in 1..epoch_val{
        for dt in 0..input.len(){
            backward(&input[dt],&answer[dt],weight);
        }
        if i % view_status_val == 0{
            println!("epoch:{},end",i);
        }
    }
}

fn test(input : &Vec<Vec<f64>>,weight : &Vec<Vec<Vec<f64>>>){
    for inp in input{
        let ret = forward(&inp,&weight);
        for i in ret{
            println!("ret[{},{}]:{}",inp[0],inp[1],i);
        }
    }
}

fn main(){
    let input = vec![
        vec![0.0_f64,0.0_f64],
        vec![1.0_f64,0.0_f64],
        vec![0.0_f64,1.0_f64],
        vec![1.0_f64,1.0_f64]
    ];
    let mut weight = vec![
        vec![vec![0.1_f64,0.10_f64],vec![0.1_f64,0.1_f64]],
        vec![vec![0.1_f64,0.14_f64],vec![0.1_f64,0.1_f64]],
        vec![vec![0.1_f64;1],vec![0.1_f64;1]]];
    let res = vec![
        vec![0.0_f64],
        vec![1.0_f64],
        vec![1.0_f64],
        vec![0.0_f64]
    ];

    training(&input,&res,&mut weight,30000,100);
    test(&input,&weight);
}

reluで伝播させた後、恒等関数で結果を出力し
平均2乗誤差にて誤差を算出しています。

結果は以下のようになります

epoch:100,end
epoch:200,end
epoch:300,end
…(省略)
epoch:29800,end
epoch:29900,end
epoch:30000,end
ret[0,0]:0
ret[1,0]:0.9999999999911916
ret[0,1]:1.0000000000002078
ret[1,1]:0

各値に対するXORが算出されてますね!
正しく学習できたようです!

おわりに

上記ソースのライセンスは
よく分からないのですが一先ずMITとします

また、下記githubにもソースコードをアップしました
https://github.com/kakisaouns/Rust_Easy_NeuralNet

誤り、指摘などありましたらコメント頂けると嬉しいです。

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?