0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

『機械学習のエッセンス(http://isbn.sbcr.jp/93965/)』のPythonサンプルをJuliaで書き換えてみる。(第05章05ラッソ回帰)

Last updated at Posted at 2019-03-10

はじめに

『機械学習のエッセンス(http://isbn.sbcr.jp/93965/)』のPythonサンプルをJuliaで書き換えてみる。(第05章04汎化と過学習)の続きです。

記号の説明

第05章02回帰の特徴量ベクトルが多次元の場合の記号の説明をもう一度記載しておきます。

y = w_0 + w_1x_1 + w_2x_2 + \cdots w_dx_d + \varepsilon

$(x_0,\cdots,x_d)^T$は入力変数、$w_0, w_1, \cdots, w_d$はパラメータ、$y$はターゲット、$\varepsilon$はノイズ。

ベクトル$x=(x_1,x_2,\cdots,x_d)^T$に対して要素1を付加したベクトル$\tilde{x}$、ベクトル$w=(w_0,w_1,\cdots,w_d)^T$とすると、

y = w^T\tilde{x}

行列$X$について、左に1列追加してその要素をすべて1としたものを$\tilde{X}$とする。

\hat{y}(w) = \tilde{X}w 

ラッソ回帰

ラッソ回帰はパラメータ$\lambda$を追加して、

E(w) = ||y - \hat{y}(w)||^2 + \lambda|w|_1

を最小化する$w$を決定する。(リッジ回帰の時は最後の項が$\lambda||w||^2$だった。)
$|\cdot|_1$はL1ノルムと呼ばれ、次で定義される。

|w|_1 = \sum_{i=1}^d|w_i|

座標降下法によるラッソ回帰の実装

※ソースコードの補足程度しか載せていないので、詳細は本をご参照ください。

w_0 = \frac{1}{n}\sum_{i=1}^n(y_i - \sum_{j=1}^d x_{ij}w_j)

ソフト閾値関数$S$を次のように定義する。

S(p,q) = sgn(p)max\{0, ||p|| - q\}

ただし、

sgn(x) = 
\begin{cases}
  -1 & (x \lt 0) \\
  0 & (x = 0) \\
  1 & (x \gt 0)  
\end{cases}

更新後の$w_k$の値$\overline{w}_k$は次で表される。

\overline{w}_k = \frac{S(\sum_{i=1}^n(y_i - w_0 - \sum_{j \neq k}x_{ij}w_j)x_{ik}, \lambda)}{\sum_{i=1}^nx_{ik}^2w_k}
lasso.jl
module lasso
using LinearAlgebra

soft_thresholding(x, y) = sign(x) * max(abs(x) - y, 0)

mutable struct Lasso
    λ_
    tol
    max_iter
    w_
    function Lasso(λ_, tol = 0.0001, max_iter = 1000)
        new(λ_, tol, max_iter, Nothing)
    end
end

function fit(s::Lasso, X::Array, t)
    n = size(X)[1]
    if ndims(X) == 1
        d = 0
    else
        d = size(X)[2]
    end
    s.w_ = zeros(d + 1)
    avgl1 = 0.
    for i in collect(0:s.max_iter)
        avgl1_prev = avgl1
        _update(s, n, d, X, t)
        avgl1 = sum(abs.(s.w_)) / size(s.w_)[1]
        if abs(avgl1 - avgl1_prev) <= s.tol
            break
        end
    end
end

function _update(s::Lasso, n, d, X, t)
    s.w_[1] = sum(t .- X * s.w_[2:end]) ./ n
    w0vec = ones(n) .* s.w_[1]
    for k in collect(1:d)
        ww = s.w_[2:end]
        ww[k] = 0
        q = dot((t - w0vec - X * ww) ,X[:, k])
        r = dot(X[:, k], X[:, k])
        s.w_[k + 1] = soft_thresholding(q / r, s.λ_)
    end
end

function predict(s::Lasso, X)
    if ndims(X) == 1
        X = reshape(X, 1, :)
    end
    Xtil = hcat(ones(size(X)[1]), X)
    return Xtil * s.w_
end

end

ワインの品質データにラッソ回帰を適用

第05章02回帰
のところで使ったワインの品質データを使います。

lasso_winequality1.jl
include("lasso.jl")
using .lasso
using CSV
using Random
using Statistics
using Printf

dataframe = CSV.read("winequality-red.csv", header=true, delim=';')
row,col=size(dataframe)

Xy = Float64[dataframe[r,c] for r in 1:row, c in 1:col]

Random.seed!(0)
shuffle(Xy)

train_X = Xy[1:row-1000, 1:col-1]
train_y = Xy[1:row-1000, col-1]
test_X = Xy[row-999:row, 1:col-1]
test_y = Xy[row-999:row, col-1]

for λ_ in [1., 0.1, 0.01]
  model = lasso.Lasso(λ_)
  lasso.fit(model, train_X, train_y)
  y = lasso.predict(model, test_X)
  println("--- λ = $(λ_) ---")
  println("coefficients:")
  println(model.w_)
  mse = mean((y - test_y).^2)
  println("MSE:$(@sprintf("%0.3f", mse))")
end

実行結果

julia> include("lasso_winequality1.jl")
--- λ = 1.0 ---
coefficients:
[10.1618, 0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0]
MSE:1.304
--- λ = 0.1 ---
coefficients:
[10.1755, 0.0, -0.0, 0.845283, 0.0, -3.07043, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0]
MSE:1.266
--- λ = 0.01 ---
coefficients:
[9.8547, -0.0, 0.0458227, 1.09543, 0.0288596, -4.17312, -0.0, -0.0, -0.0, 0.0, 0.359558, 0.0]
MSE:1.271

本と、結果の配列の値もMSEの値も違うので不安です。(本では、$\lambda$の3つの値に対して、配列の最初がだいたい5.5〜5.7くらいで、MSEがそれぞれ、0.6910.6360.539です。)ただし、配列の0でない程度(疎の程度)は同じような感じになりました。
第05章02回帰の実践的な例のところでも誤差の範囲内と見なしたのでここでもOKと見なすことにします。(RMSEは1以上ずれていないならまあままの意味としていました。MSEは基準が書いていないのですが、平方根を取っていないだけなので同じと見なしました。)

補足

lasso.jlについて

最初、本のソースコードに近い形でそのまま書いたところたくさんエラーが出ました。原因は関数の使い方で、dotNumpyでは行列積も内積も計算できますが、第04章03配列の基本計算で書いたように、Juliaでは行列積は*で、内積はdot、あるいは$\cdot$記号を使います。もともとのPythonのサンプルでdotと書いてある箇所が、行列積か内積かを判別して書き換えないといけないので、数式との対比など結構難しかったです。単純に*でエラーになった部分をブロードキャストで.*と書いてしまうと、本来内積でスカラー値になるもの(例えばここではqrなど)が配列になってしまいおかしな結果になってしまいます。
だんだん複雑になってきたので、Pythonの配列が0始まりで、Juliaが1始まりというのも書き換えのときに結構苦労する箇所でした。(最初からJuliaで書く場合は特に意識はしないかもしれませんが。)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?