LoginSignup
0
1

More than 3 years have passed since last update.

CNNに足し算をさせてみたいだけだった

Posted at

はじめに

精度的にはいまいちな結果になってしまいました。
もう少し工夫する必要がありそうです。

誰かやってみたい人がいれば
https://github.com/daikiclimate/mnist_add

1 やりたいこと 

目的としては、
・2枚の画像を入力とするネットワークを構築する(練習)
・本来の予定では、

1+3を学習させずに、1と3の画像を入れたら何を出力するのかを確認してみたかった。

スクリーンショット 2020-02-16 11.59.50.png

この図のように画像2枚を入力として、その足し算結果を出力とするクラス分類器を作成してみました。
当然ですが「数字を予測させてから足し合わせる」ことで簡単に実現可能となります。
この点から特徴抽出を行い、concatなど特徴を合成してから足し算結果を予測させています。

2 データセット

使ったのはMNISTデータセットです。
数字0=9が写った画像が学習用60000枚テスト用10000枚があります。
ここからペアを作成していきましたが、「同じ画像だが異なった組み合わせ」を可能にすると、とんでもない量の学習用データのサイズになります。(60000!組?)

クラスは0+0 = 0 から 9+9 = 18までの計19クラスあります。
大体20クラスになります。
概算ですが、ペア画像にするため、学習用60000 / 2 
数字のペアを00 ~ 99まで作るとすると100通りあるので、各ペアは 60000 / 2 / 100 = 300ペアしかないことになります。
(00は0+0のこと)

そこで今回はランダムに画像を抽出してきてペアにする方法にしました。
各数字は5000枚以上あるので、同じ画像同士のペアができる確率は限りなく低い(はずです)。

テストデータでは、重複した画像は使わない方法でペアを作成しています。
この場合は各ペア00 ~ 99が44組ずつあります。

3 結果

訓練データの精度 テストデータの精度
27.6% 24.4%
27547 / 99968 1074 / 4400

コーディングする前に思ってたよりもかなり悪いです。
中身を確認してみます。
足し算ごとにどの結果が良く無いかを確認してみました。
長いので隠します。

各数字ごとの結果
num1 num2 acc acc[%]
0 0 40/44 90.9%
0 1 28/44 63.6%
0 2 12/44 27.3%
0 3 28/44 63.6%
0 4 12/44 27.3%
0 5 11/44 25.0%
0 6 12/44 27.3%
0 7 8/44 18.2%
0 8 11/44 25.0%
0 9 27/44 61.4%
1 0 29/44 65.9%
1 1 7/44 15.9%
1 2 20/44 45.5%
1 3 1/44 2.3%
1 4 0/44 0.0%
1 5 4/44 9.1%
1 6 22/44 50.0%
1 7 0/44 0.0%
1 8 18/44 40.9%
1 9 24/44 54.5%
2 0 22/44 50.0%
2 1 23/44 52.3%
2 2 0/44 0.0%
2 3 4/44 9.1%
2 4 1/44 2.3%
2 5 5/44 11.4%
2 6 6/44 13.6%
2 7 2/44 4.5%
2 8 0/44 0.0%
2 9 27/44 61.4%
3 0 26/44 59.1%
3 1 3/44 6.8%
3 2 1/44 2.3%
3 3 1/44 2.3%
3 4 3/44 6.8%
3 5 8/44 18.2%
3 6 2/44 4.5%
3 7 12/44 27.3%
3 8 16/44 36.4%
3 9 15/44 34.1%
4 0 9/44 20.5%
4 1 2/44 4.5%
4 2 4/44 9.1%
4 3 6/44 13.6%
4 4 0/44 0.0%
4 5 4/44 9.1%
4 6 0/44 0.0%
4 7 4/44 9.1%
4 8 8/44 18.2%
4 9 16/44 36.4%
5 0 10/44 22.7%
5 1 6/44 13.6%
5 2 4/44 9.1%
5 3 12/44 27.3%
5 4 10/44 22.7%
5 5 1/44 2.3%
5 6 15/44 34.1%
5 7 4/44 9.1%
5 8 5/44 11.4%
5 9 8/44 18.2%
6 0 5/44 11.4%
6 1 15/44 34.1%
6 2 1/44 2.3%
6 3 4/44 9.1%
6 4 0/44 0.0%
6 5 12/44 27.3%
6 6 0/44 0.0%
6 7 2/44 4.5%
6 8 7/44 15.9%
6 9 26/44 59.1%
7 0 6/44 13.6%
7 1 5/44 11.4%
7 2 8/44 18.2%
7 3 5/44 11.4%
7 4 8/44 18.2%
7 5 7/44 15.9%
7 6 3/44 6.8%
7 7 1/44 2.3%
7 8 9/44 20.5%
7 9 26/44 59.1%
8 0 5/44 11.4%
8 1 18/44 40.9%
8 2 2/44 4.5%
8 3 16/44 36.4%
8 4 3/44 6.8%
8 5 3/44 6.8%
8 6 7/44 15.9%
8 7 10/44 22.7%
8 8 18/44 40.9%
8 9 13/44 29.5%
9 0 28/44 63.6%
9 1 19/44 43.2%
9 2 27/44 61.4%
9 3 16/44 36.4%
9 4 18/44 40.9%
9 5 12/44 27.3%
9 6 29/44 65.9%
9 7 20/44 45.5%
9 8 12/44 27.3%
9 9 29/44 65.9%

一言で言えば凄惨です。
Mnistの分類は通常90%くらいは余裕で行くため、
このタスクでも90%くらいは行く前提で
その後に何かをしようと思っていたのですが、
圧倒的にタスクそのものをこなせていません。

期待値的には、おそらく「9」(0+9から1+8まで組み合わせが多い)ので、
9の精度が良いと思えば、6+3はそうでもありません。

WANDBを用いた結果の推移
スクリーンショット 2020-02-16 18.24.19.png

もっとepochを増やせば少しは精度があがるかもしれませんが、30%は超えなさそうな勢いです。

4 反省と展望

感想
不思議やなぁ...って感じです。掛け算99のように数字の組み合わせとその結果を覚えてくれればうまくいくと思ってましたがそううまく行かないようです。

聞いた話では、Cassificationとも言えないタスクでCrossEntoropyだと限界があるとかなんとか。
しかし、回帰なら良いかというとそういう話でも無いと思うので、マルチタスクで複数のロスをつけることを検討中

もしネットワーク構造などで、間違っていることがあればコメント蘭で指摘いただきたいです。
閲覧ありがとうございました。

0
1
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1