(ns prml.prml1-1)

(use 'clojure.set)   ; incanter対策
(use '(incanter core charts))

; データ作成
(defn test-func
  "Target curve sin(2*pi*x) (incanter's sin); used for the noise-free
  values and for the reference curve in the plot."
  [x]
  (let [angle (* 2 Math/PI x)]
    (sin angle)))
;; Input points: ten values evenly spaced over [0, 1] (i/9, rounded to six decimals).
(def x-values
  [0.0 0.111111 0.222222 0.333333 0.444444
   0.555556 0.666667 0.777778 0.888889 1.0])

;; Noise-free targets: test-func applied to every input point (a lazy sequence).
;; The var #'test-func is passed, so the function is looked up through the var
;; each time an element is realized.
(def true-y-values (map #'test-func x-values))

;; Observed target values, one per entry of x-values. They deviate from
;; true-y-values, so they play the role of noisy measurements.
(def observed-y-values
  [0.349486 0.830839 1.007332 0.971507 0.133066
   0.166823 -0.848307 -0.445686 -0.563567 0.261502])

;; Eq. (1.1): polynomial in x whose coefficients are w-list.
(defn func-y
  "Evaluates y(x, w) = sum_j w_j * x^j (PRML eq. 1.1).
  The first element of w-list is the constant term; the degree of the
  polynomial is (dec (count w-list)). An empty w-list yields 0."
  [x, w-list]
  (->> w-list
       (map-indexed (fn [power coeff] (* coeff (Math/pow x power))))
       (reduce +)))

;; Eq. (1.2): sum-of-squares error of func-y.
(defn func-error
  "Sum-of-squares error E(w) = 1/2 * sum_n (y(x_n, w) - t_n)^2 (PRML eq. 1.2).

  x-list - input points x_n
  w-list - polynomial coefficients, constant term first (see func-y)
  t-list - target values t_n, one per input point

  Returns nil when x-list and t-list differ in length."
  [x-list, w-list, t-list]
  ;; Mismatched lengths yield nil rather than an exception.
  (when (== (count x-list) (count t-list))
    (* 0.5
       (reduce +
               ;; Walk both sequences together; avoids O(n) nth on lazy targets.
               (map (fn [x t] (Math/pow (- (func-y x w-list) t) 2.0))
                    x-list t-list)))))

;; Eq. (1.4): regularized sum-of-squares error of func-y.
(defn func-error-with-penalty
  "Regularized error
  E~(w) = 1/2 * sum_n (y(x_n, w) - t_n)^2 + lambda/2 * ||w||^2 (PRML eq. 1.4).
  Every coefficient, including the constant term, is penalized.

  x-list - input points x_n
  w-list - polynomial coefficients, constant term first (see func-y)
  t-list - target values t_n, one per input point
  lambda - weight of the penalty term

  Returns nil when x-list and t-list differ in length."
  [x-list, w-list, t-list, lambda]
  ;; Mismatched lengths yield nil rather than an exception.
  (when (== (count x-list) (count t-list))
    (+ (* 0.5
          (reduce +
                  (map (fn [x t] (Math/pow (- (func-y x w-list) t) 2.0))
                       x-list t-list)))
       ;; Penalty: lambda/2 times the squared Euclidean norm of w.
       (* 0.5 lambda (reduce + (map #(* % %) w-list))))))

; 式 1.123
; 参考	http://d.hatena.ne.jp/aidiary/20100327/1269657354
; 多項式近似の解の計算(演習1.1)
(defn calc-a-sub
  "A_ij = sum_n (x_n)^(i+j): entry (i, j) of the matrix built by calc-a
  (PRML eq. 1.123). An empty x-list yields 0."
  [x-list, i, j]
  (->> x-list
       (map (fn [x] (Math/pow x (+ i j))))
       (reduce +)))

(defn calc-a
  "Builds the (m+1) x (m+1) incanter matrix A with
  A_ij = (calc-a-sub x-list i j) for i, j = 0..m."
  [x-list, m]
  (let [indices (range (inc m))]
    (matrix (for [row indices]
              (map #(calc-a-sub x-list row %) indices)))))

(defn calc-a-with-penalty
  "Builds the regularized (m+1) x (m+1) incanter matrix A + lambda*I.
  It is calc-a with lambda added to every diagonal entry, which is what
  setting the gradient of the eq. 1.4 error to zero produces."
  [x-list, m, lambda]
  (let [indices (range (inc m))]
    (matrix (for [i indices]
              (for [j indices]
                (if (= i j)
                  (+ (calc-a-sub x-list i j) lambda)
                  (calc-a-sub x-list i j)))))))

(defn calc-t-sub
  "T_i = sum_n (x_n)^i * t_n (PRML eq. 1.123).
  Walks x-list by index and reads the matching entry of t-list;
  extra entries in t-list are ignored."
  [x-list, t-list, i]
  (reduce +
          (for [n (range (count x-list))]
            (* (Math/pow (nth x-list n) i)
               (nth t-list n)))))

(defn calc-t
  "Builds the right-hand side T = [T_0 ... T_m] as an incanter matrix,
  with T_i = (calc-t-sub x-list t-list i)."
  [x-list, t-list, m]
  (matrix (for [i (range (inc m))]
            (calc-t-sub x-list t-list i))))

;; Estimation of the coefficient vector.
(defn estimate
  "Fits a degree-m polynomial to (x-list, t-list) by solving A w = T
  (PRML eq. 1.122) with incanter's solve. Returns the solution as
  produced by solve; callers convert it with to-vect. Its entries are
  the coefficients, constant term first."
  [x-list, t-list m]
  (let [a (calc-a x-list m)
        t (calc-t x-list t-list m)]
    (solve a t)))

(defn estimate-with-penalty
  "Regularized counterpart of estimate: solves (A + lambda*I) w = T with
  incanter's solve. Returns the solution as produced by solve; callers
  convert it with to-vect."
  [x-list, t-list m lambda]
  (let [a-reg (calc-a-with-penalty x-list m lambda)
        t (calc-t x-list t-list m)]
    (solve a-reg t)))

;; Error check. Four fits are compared: degree 3, 5 and 9, and degree 9
;; regularized with lambda = exp(-18). For each fit, the eq. 1.2 error is
;; printed first against the observed targets and then against the
;; noise-free targets. Each coefficient vector is solved once and reused
;; for both comparisons.
(let [fits [(to-vect (estimate x-values observed-y-values 3))
            (to-vect (estimate x-values observed-y-values 5))
            (to-vect (estimate x-values observed-y-values 9))
            (to-vect (estimate-with-penalty x-values observed-y-values 9 (Math/exp -18)))]]
  (println "観測値に対する誤差確認")
  (doseq [w fits]
    (println (func-error x-values w observed-y-values)))
  (println "理論値に対する誤差確認")
  (doseq [w fits]
    (println (func-error x-values w true-y-values))))

;; Plot: the observed points, the curve test-func, the noise-free points,
;; and the degree-9 fits without and with regularization (lambda = exp(-18)).
(def plot (scatter-plot x-values observed-y-values :title "PRML1.1" :y-label "y value" :x-label "x value"))
(add-function plot test-func 0.0 1.0)
(add-points plot x-values true-y-values)
;; Solve for the coefficients once, up front. add-function evaluates the
;; supplied fn at each sampled x, so calling estimate inside it would
;; re-solve the linear system for every point.
(let [w9     (to-vect (estimate x-values observed-y-values 9))
      w9-reg (to-vect (estimate-with-penalty x-values observed-y-values 9 (Math/exp -18)))]
  (add-function plot #(func-y % w9) 0.0 1.0)
  (add-function plot #(func-y % w9-reg) 0.0 1.0))
(set-theme plot :dark)

(view plot)

