LoginSignup
36
34

More than 5 years have passed since last update.

ニューラルネットの学習過程を雑にグラフ化してみた

Last updated at Posted at 2015-08-29

機械学習 - 雑にニューラルネットを実装して二値分類してみた - Qiita

続き。

ソース:deltam/neuro

バイアス項、重要

タイトルと関係ないけど、この資料が超わかりやすくて感動した!!
バイアス項ってなんで必要なのか分かってなかったけど、スッキリ理解できた! バイアス項、重要。

CHAPTER 4 ニューラルネットワークが任意の関数を表現できることの視覚的証明

前記事だとバイアス項の実装が雑すぎて3層以上のNNだと計算がうまくいってませんでした。
そこで重み行列にくっつけるカタチでバイアス項の実装を更新しました。

lein-gorilla、便利

前記事ではグラフにプロットするのに、CSV出力して表計算ソフトでグラフ化、という周りくどいことをしていました。
lein-gorillaというleiningenプラグインを使うと、ブラウザ上でグラフを描けます。
Incanterというライブラリはもっと多機能らしいけど、とりあえず最小限の要素で始めてみたい。

参考:ブラウザ上でデータ分析が出来る!Clojure/Gorilla入門 - あんちべ!

学習過程をグラフでモニタリング

まずlein-gorillaをproject.cljに追加します。

project.clj
(defproject neuro "0.1.0-SNAPSHOT"
   :description "FIXME: write description"
   :url "http://example.com/FIXME"
   :license {:name "Eclipse Public License"
             :url "http://www.eclipse.org/legal/epl-v10.html"}
   :plugins [[cider/cider-nrepl "0.10.0-SNAPSHOT"]
             [lein-gorilla "0.3.4"]]                        ;; <= これ
   :dependencies [[org.clojure/clojure "1.6.0"]
                  [org.clojure/data.generators "0.1.2"]])

そのつぎにサーバを立ち上げてciderとブラウザでつなぐ。

$ lein gorilla :port 9999 :nrepl-port 9998
Gorilla-REPL: 0.3.4
Started nREPL server on port 9999
Running at http://127.0.0.1:9998/worksheet.html .
Ctrl+C to exit.

Emacs:

M-x cider-connect RET
localhost RET
9999 RET

ブラウザ: http://127.0.0.1:9999/worksheet.html

スクリーンショット 2015-08-29 19.20.50.png

こんなふうに誤差関数の値をvectorに貯めるようにしといて、lein-gorillaでプロットできるようにしときました。

train.clj

 (def +train-err-vec+ (atom []))
 (def +test-err-vec+ (atom []))
 (def +learning-rate+ (atom 8.0))
 (def +go-next-batch+ (atom false))

 (defn train-init []
   (reset! +train-err-vec+ [])
   (reset! +test-err-vec+ [])
   (reset! +go-next-batch+ false))

 (defn- monitoring
   "学習過程をレポートする"
   [epoc train-err test-err]
   (if (zero? (mod epoc *report-period*))
     (do (swap! +train-err-vec+ conj train-err)
         (swap! +test-err-vec+ conj test-err))))

(defn- momentum
   "モメンタム項の計算"
   [pre-nn cur-nn nn]
   (nw/map-nn (fn [w l i o]
               (let [dw ( - (nw/wget cur-nn l i o)
                            (nw/wget pre-nn l i o))]
                  (+ w (* *momentum-param* dw))))
              nn))

 (defn train
   "NNの学習を行なう"
   [init-nn efn w-updater dataset testset terminate-f]
   (loop [pre-nn init-nn, cur-nn init-nn, err (efn init-nn dataset) , epoc 0]
     (let [next-nn (momentum pre-nn cur-nn (w-updater cur-nn efn dataset))
           train-err (efn next-nn dataset)
           test-err (efn next-nn testset)]
       (monitoring epoc train-err test-err)
       (if (or @+go-next-batch+ (terminate-f err train-err))
         (do (reset! +go-next-batch+ false)
             cur-nn)
         (recur cur-nn, next-nn, train-err, (inc epoc))))))

サンプルの学習を走らせておく:

repl
user> (require '[neuro.network :as nw])
nil
user> (require '[neuro.train :as tr])
nil
user> (def nn (nw/gen-nn :rand 2 10 1))

#'user/nn
user> ;; doc/train.mdからコピペして定義

#'user/traindata-2class

user> (def testdata-2class (take 10 (shuffle traindata-2class))) ; テストデータを用意するのが面倒だったので訓練データからサンプリング(雑)
#'user/testdata-2class

user> (def nn2
        (tr/train-sgd nn tr/err-fn-2class traindata-2class testdata-2class (fn [_ e] (< e 0.4))))
batch start  4

gorillaにはこんなコードを書いてShift+Enter。

(ns wandering-stream
  (:require [gorilla-plot.core :as plot])
  (:require [neuro.train :as tr]))
(plot/list-plot (take-last 500 @tr/+train-err-vec+) :joined true)

スクリーンショット 2015-08-29 19.38.59.png

Ctrl+Shift+Enterを押せばgorilla上のコードをすべて再評価してくれるので、連打すればリアルタイムモニタリング!

パラメタの調整

gorillaのブラウザ上でREPLが走っているので、学習を走らせながらパラメタを書き換えて様子を見てみることができます。

(reset! tr/+learning-rate+ 3) ; 学習係数を調整する
(reset! tr/+go-next-batch+ true) ; 次のミニバッチに強制的にすすめる

スクリーンショット 2015-08-29 20.08.46.png

学習誤差の減り具合とかを見て調整してみると面白いです。

学習結果のプロット

前の記事では表計算ソフトを使ってましたが、lein-gorilla上で完結できて便利になりました。
(いまはベタ書きしてるけど、プロットする関数を定義しとくと便利かも)

訓練データをプロット:

(let [data0 (filter (fn [{[a] :ans}] (= a 0)) user/traindata-2class)
      series0 (map (fn [{[x y _] :x}] [x y]) data0)
      data1 (filter (fn [{[a] :ans}] (= a 1)) user/traindata-2class)
      series1 (map (fn [{[x y _] :x}] [x y]) data1)]
(plot/compose
  (plot/list-plot series0 :aspect-ratio 1.0 :plot-range [[0 9] [0 9]] :color :blue)
  (plot/list-plot series1 :aspect-ratio 1.0 :plot-range [[0 9] [0 9]] :color :red)))

スクリーンショット 2015-08-31 20.09.49.png

学習結果を同じようにプロット:

(let [series0 (for [x (range 10) y (range 10)
                    :let [[a] (neuro.core/nn-calc user/nn2 [x y])]
                    :when (< a 0.5)]
                [x y])
      series1 (for [x (range 10) y (range 10)
                    :let [[a] (neuro.core/nn-calc user/nn2 [x y])]
                    :when (>= a 0.5)]
                [x y])]
  (plot/compose
    (plot/list-plot series0 :aspect-ratio 1.0 :plot-range [[0 9] [0 9]] :color :blue)
    (plot/list-plot series1 :aspect-ratio 1.0 :plot-range [[0 9] [0 9]] :color :red)))

スクリーンショット 2015-08-30 7.39.35.png

どの程度学習できているか比較できますね(x,y < 3 あたりが上手く分類できてないとか)

学習誤差とテスト誤差を比較

学習誤差とテスト誤差の差が開いていく場合、過学習になってる可能性があるのでそのミニバッチを打ち切ったほうがいいらしい(early stopping)ので、その様子を監視するため次のようなグラフを表示してみるのもいいですね。

(ns wandering-stream
  (:require [gorilla-plot.core :as plot])
  (:require [neuro.train :as tr]))
(let [range 500]
  (plot/compose
    (plot/list-plot (take-last range @tr/+train-err-vec+) :joined true :color :blue)
    (plot/list-plot (take-last range @tr/+test-err-vec+) :joined true :color :red)))

スクリーンショット 2015-08-30 7.38.53.png

まとめ

  • グラフにプロットするのが楽になって実験が進めやすくなった
  • バイアス項の計算をちゃんと実装しました。

参考資料

ニューラルネットワークと深層学習

  • 日本語訳、超たすかる!!!!

Deep learning実装の基礎と実践

  • 実験の進め方がとくに参考になった。
    • 学習係数の決め方、プロット重要
    • 学習曲線ながめてると日が暮れるので注意

次回

学習が遅くて実験を進めづらいので、そろそろバックプロパゲーションを実装しようかな。

36
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
34