13
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ED法でmnistを学習する(BCELoss)

Last updated at Posted at 2024-04-27

これの続き

BCELossがいい感じだったのでmnistをやってみる
2レイヤーで、合計1570個のパラメータを持っている。

0 or 1 の学習

plot.png

すごくいい感じに学習ができた。
テストデータでもいい感じ。

correct: 2114 / 2115 = 0.9995271867612293

4 or 9 の学習

plot.png

こっちもすぐ過学習に陥ってるけどいい感じ

テストデータ

correct: 1874 / 1991 = 0.9412355600200905

実装

pub struct Mnist {
    first_layer: Layer<Sigmoid>,
    layers: Vec<Layer<Sigmoid>>,
    last_layer: Layer<PassThrough>,
}

impl Mnist {
    pub fn new(layer_num: usize, neural_num: usize) -> Self {
        let mut rng = StdRng::seed_from_u64(42);
        Mnist {
            first_layer: Layer::new(&mut rng, 784 * 2, neural_num),
            layers: (0..layer_num)
                .map(|_| Layer::new(&mut rng, neural_num, neural_num))
                .collect(),
            last_layer: Layer::new(&mut rng, neural_num, 1),
        }
    }

    pub fn forward(&mut self, inputs: &[f64]) -> f64 {
        let x = duplicate_elements(inputs.into_iter()).collect();
        let x = self.first_layer.forward(x);
        let x = self.layers.iter_mut().fold(x, |x, layer| layer.forward(x));
        self.last_layer.forward(x)[0]
    }

    pub fn forward_without_train(&self, inputs: &[f64]) -> f64 {
        let x = duplicate_elements(inputs.into_iter()).collect();
        let x = self.first_layer.forward_without_train(x);
        let x = self
            .layers
            .iter()
            .fold(x, |x, layer| layer.forward_without_train(x));
        self.last_layer.forward_without_train(x)[0]
    }

    pub fn backward(&mut self, delta: f64) {
        self.first_layer.backward(delta);
        self.layers
            .iter_mut()
            .for_each(|layer| layer.backward(delta));
        self.last_layer.backward(delta);
    }
}

所感

バッチでの学習がなくても、思いの外いい感じで学習できている。
次はCrossEntropyLossの実装でmnist全体の学習にトライしたい。

こちらの記事ではバッチの実装ができているらしく、そっちの実装にも手をつけたいところ。

蛇足

ED法のモデルは一層目の入力をduplicateする必要があり、それを忘れるときちんと学習ができない。
その実装を忘れていても 0 or 1 は学習できたけど、 4 or 9 の学習はできなかった。
面白い。

13
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?