Qiitaに新規登録させていただいたので、とっても恥ずかしいレベルですが、PRMLの1章の始めのところをClojureで書いてみたものを共有します(まだ、ここまでしか進んでいません)。たぶん、もっとずっと綺麗な書き方があると思いますが、大目に見てください。
prml1_1.clj
(ns prml.prml1-1)
(use 'clojure.set) ; incanter対策
(use '(incanter core charts))
; データ作成
(defn test-func [x] (sin (* 2 Math/PI x)))
(def x-values [0.0, 0.111111, 0.222222, 0.333333, 0.444444, 0.555556, 0.666667, 0.777778, 0.888889, 1.0])
(def true-y-values (map #'test-func x-values))
(def observed-y-values [0.349486, 0.830839, 1.007332, 0.971507, 0.133066, 0.166823, -0.848307, -0.445686, -0.563567, 0.261502])
; 式 1.1
; 係数がw-listであるようなxについての多項式。w-listの最初の項は定数項、次数は(- (count w-list) 1)
(defn func-y [x, w-list]
(reduce + (map #(* (nth w-list %) (Math/pow x %)) (range 0 (count w-list)))))
; 式 1.2
; func-yについての二乗誤差関数。x-listは観測値のベクタ、t-listが真の値のベクタ。 w-listは係数項のベクタ。
(defn func-error [x-list, w-list, t-list]
(if (== (count x-list) (count t-list))
(* 0.5
(reduce +
(map #(Math/pow (- (func-y (nth x-list %) w-list) (nth t-list %)) 2.0)
(range 0 (count x-list)))))
nil
)
)
; 式 1.4
; func-yについての二乗誤差関数を正則化。x-listは観測値のベクタ、t-listが真の値のベクタ。 w-listは係数項のベクタ。
; lambdaは罰金項
(defn func-error-with-penalty [x-list, w-list, t-list, lambda]
(if (== (count x-list) (count t-list))
(+ (* 0.5
(reduce +
(map #(Math/pow (- (func-y (nth x-list %) w-list) (nth t-list %)) 2.0)
(range 0 (count x-list)))))
(* 0.5 lambda (reduce + (map #(* (nth w-list %) (nth w-list %)) (range 0 (count w-list)))))
)
nil
)
)
; 式 1.123
; 参考 http://d.hatena.ne.jp/aidiary/20100327/1269657354
; 多項式近似の解の計算(演習1.1)
(defn calc-a-sub [x-list, i, j]
(reduce + (map #(Math/pow (nth x-list %) (+ i j)) (range 0 (count x-list)))))
(defn calc-a [x-list, m]
(matrix (for [i (range 0 (inc m))]
(for [j (range 0 (inc m))]
(calc-a-sub x-list i j)))))
(defn calc-a-with-penalty [x-list, m, lambda]
(matrix (for [i (range 0 (inc m))]
(for [j (range 0 (inc m))]
(if (= i j)
(+ (calc-a-sub x-list i j) lambda)
(calc-a-sub x-list i j)
)))))
(defn calc-t-sub [x-list, t-list, i]
(reduce + (map #(* (Math/pow (nth x-list %) i) (nth t-list %))(range 0 (count x-list)))))
(defn calc-t [x-list, t-list, m]
(matrix (map #(calc-t-sub x-list t-list %) (range 0 (inc m)))))
; 係数項ベクタの推定
(defn estimate [x-list, t-list m]
(solve (calc-a x-list m) (calc-t x-list t-list m)))
(defn estimate-with-penalty [x-list, t-list m lambda]
(solve (calc-a-with-penalty x-list m lambda) (calc-t x-list t-list m)))
; 誤差確認
(println "観測値に対する誤差確認")
(println (func-error x-values (to-vect (estimate x-values observed-y-values 3)) observed-y-values))
(println (func-error x-values (to-vect (estimate x-values observed-y-values 5)) observed-y-values))
(println (func-error x-values (to-vect (estimate x-values observed-y-values 9)) observed-y-values))
(println (func-error x-values (to-vect (estimate-with-penalty x-values observed-y-values 9 (Math/exp -18))) observed-y-values))
(println "理論値に対する誤差確認")
(println (func-error x-values (to-vect (estimate x-values observed-y-values 3)) true-y-values))
(println (func-error x-values (to-vect (estimate x-values observed-y-values 5)) true-y-values))
(println (func-error x-values (to-vect (estimate x-values observed-y-values 9)) true-y-values))
(println (func-error x-values (to-vect (estimate-with-penalty x-values observed-y-values 9 (Math/exp -18))) true-y-values))
; グラフ描画
(def plot (scatter-plot x-values observed-y-values :title "PRML1.1" :y-label "y value" :x-label "x value"))
(add-function plot test-func 0.0 1.0)
(add-points plot x-values true-y-values)
(add-function plot #(func-y % (to-vect (estimate x-values observed-y-values 9))) 0.0 1.0)
(add-function plot #(func-y % (to-vect (estimate-with-penalty x-values observed-y-values 9 (Math/exp -18)))) 0.0 1.0)
(set-theme plot :dark)
(view plot)