14
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

darknetでMNISTを学習する

Last updated at Posted at 2016-11-27

Cで書かれたニューラルネット環境darknetですが、訓練済のネットを使う例はあちこちで見かけますが、訓練をしている例が見当たりません。darknet本家にもこの手のドキュメントはないので、ソースを読みながら見よう見まねでやってみることにしました。ソースはすっきりしていて読みやすいですね。

githubにもろもろ素材を上げておきました(こちら)ので参考にしてください。

準備

元データのダウンロード

data/mnist/以下に THE MNIST DATABASEから下記をダウンロードします。

  • t10k-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz
  • train-images-idx3-ubyte.gz
  • train-labels-idx1-ubyte.gz

gunzipで展開しておきます。

データフォーマット調整

データ・フォーマットはこちらの下のほうに書かれています。ヘッダがあってあとはバイトの羅列のようなのでnumpyで簡単に抜き出せます。

ラベルファイルという、クラス番号とクラスIDを紐付けるものが必要なようです。今回は下記のようにします。行がクラス番号0-9に相当します。mnist.label.listというファイルに書いておきます。

c0
c1
..
c9

画像についてはファイル名のリストをtxtで羅列しますが、classifierの場合はパスの文字列からクラスIDの文字列を検索することでクラスと対応付けているようなので、trainとvalidをそれぞれ下記のように、"[trainかvalidか]_[画像ID]_[クラスID].png"のように命名します。

  • t_00000_c0.png
  • v_00000_c0.png

ファイル名一覧をmnist.train.list, mnist.valid.listに書き出します。

ダウンロードから変換までを行うスクリプトがこちらです。

datasetファイル(データのありかなど)

cfg/mnist.dataset ファイルに記述します。

classes=10
train  = data/mnist/mnist.train.list
valid  = data/mnist/mnist.valid.list
backup = /tmp/backup/
labels = data/mnist.labels.list
names  = data/mnist.names.list
top = 5

train, valid, labels はさきほど作りましたね。namesはどのclass_idに名前をつける場合に使うテキストです。MNISTなので0から9と書いておけばよいですね。

0
1
...
9

cfgファイル(ネットワーク構成)

いよいよニューラルネットワークの構成を書きます。cfg/mnist_lenet.cfg ファイルに記述します。

冒頭には[net]という項目に、学習条件などが含まれています。注意ですが、darknetのclassifierはカラー画像を前提としていますので、channel=3と指定する必要があります。元画像が1チャネルでもopencvが3チャネル化しながら読み込みます。

angleからaspectはaugmentation用の条件ですね。これはclassifierでは使われていないように見えますが、ひとまず変形はしないということで全部1にしておきます。学習のループ回数は max_batches回です。

[net]
batch=100
subdivisions=1
height=28
width=28
channels=3
momentum=0.9
decay=0.00005
max_crop=28

learning_rate=0.01
policy=poly
power=4
max_batches=500

angle=1
hue=1
saturation=1
exposure=1
aspect=1

これ以降にはネットワーク構造を書きます。羅列しておくとSequencialな構造になるようですね。逆にSequencialな構造しか書けないのかな。

今回は TensorFlowのチュートリアルにある4層のLe-Netを記述してみます。

[convolutional]
filters=32
size=5
stride=1
pad=1
activation=relu

[maxpool]
size=2
stride=2

[convolutional]
filters=64
size=5
stride=1
pad=1
activation=relu

[maxpool]
size=2
stride=2

[connected]
output= 1024
activation=relu

[dropout]
probability=.5

[connected]
output= 10
activation=linear

[softmax]
groups=1

[cost]
type=sse

訓練

./darknet classifier train cfg/mnist.dataset cfg/mnist_lenet.cfg

costとか学習率とかが表示されていきます。最終的に、/tmp/backupに重みが出力されます。

推論

/tmp/backout/mnist_lenet.weightsをコピーしてきます。

predict.c/predict_classifier()のresize_network()でエラーになるので
network.c/resize_network()の

if(w == net->w && h == net->h) return 0;
のコメントを外してdarknetをビルドしなおします。

(2018/10/20追記:twitterで指摘いただいてチェックしましたが、現時点での最新のdarknetではエラーにはなりませんでした。)

./darknet classifier predict cfg/mnist.dataset cfg/mnist_lenet.cfg ./mnist_lenet.weights data/mnist/images/v_00000_c7.png

とすると、下記が出力されます。

layer     filters    size              input                output
    0 conv     32  5 x 5 / 1    28 x  28 x   3   ->    28 x  28 x  32
    1 max          2 x 2 / 2    28 x  28 x  32   ->    14 x  14 x  32
    2 conv     64  5 x 5 / 1    14 x  14 x  32   ->    14 x  14 x  64
    3 max          2 x 2 / 2    14 x  14 x  64   ->     7 x   7 x  64
    4 connected                            3136  ->  1024
    5 dropout       p = 0.50               1024  ->  1024
    6 connected                            1024  ->    10
    7 softmax                                          10
    8 cost                                             10
Loading weights from ./mnist_lenet.weights...Done!
data/mnist/images/v_00000_c7.png: Predicted in 0.004880 seconds.
7: 0.995198
5: 0.003397
9: 0.000618
0: 0.000574
3: 0.000106

7の画像が「99.5%で7」と認識されました。いいですね。

14
18
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?