LoginSignup
1
2

More than 5 years have passed since last update.

Lua版 ゼロから作るDeep Learning その8[損失関数]

Last updated at Posted at 2017-06-25

過去記事

Lua版 ゼロから作るDeep Learning その1[パーセプトロンの実装]
Lua版 ゼロから作るDeep Learning その2[活性化関数]
Lua版 ゼロから作るDeep Learning その3[3層ニューラルネットワークの実装]
Lua版 ゼロから作るDeep Learning その4[ソフトマックス関数の実装]
Lua版 ゼロから作るDeep Learning その5[MNIST画像の表示]
Lua版 ゼロから作るDeep Learning その5.5[pklファイルをLua Torchで使えるようにする]
Lua版 ゼロから作るDeep Learning その6[ニューラルネットワークの推論処理]
Lua版 ゼロから作るDeep Learning その7[バッチ処理]

損失関数

 今回は原書4章の損失関数を実装します。 
 
 スクリプトは以下のようになります。

lossfunc.lua
---2乗和誤差算出関数
-- テンソル同士の2乗和誤差(∑(yi-ti)^2)/2を求める
-- @param y 入力1、今回はNNが出力する確率リスト {Type:Tensor}
-- @param t 入力2、今回は正解ラベルリスト {Type:ByteTensor}
-- @return 2乗和誤差 {Type:number}
function mean_squared_error(y, t)
    return ( y:double() - t:double() ):pow(2):sum() * 0.5
end

local t = torch.Tensor({0,0,1,0,0,0,0,0,0,0})
local y = torch.Tensor({0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0})
print(mean_squared_error(y,t))
y = torch.Tensor({0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.5, 0.0})
print(mean_squared_error(y,t))


---交差エントロピー誤差算出関数
-- テンソル同士の交差エントロピー誤差(-∑tilogyi)を求める
-- @param y 入力1、今回はNNが出力する確率リスト {Type:Tensor}
-- @param t 入力2、今回は正解ラベルリスト {Type:ByteTensor}
-- @return 交差エントロピー誤差 {Type:number}
function cross_entropy_error(y, t)
    local delta = 1e-7
    return -torch.cmul(t:double(), ( y:double() + delta ):log() ):sum()
end

local t = torch.Tensor({0,0,1,0,0,0,0,0,0,0})
local y = torch.Tensor({0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0})
print(cross_entropy_error(y,t))
y = torch.Tensor({0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.5, 0.0})
print(cross_entropy_error(y,t))
出力結果
$ th lossfunc.lua
0.0975  
0.7225  
0.51082545709934    
2.3025840929945 

 特につまづくところもないかと思います。
 

おわりに

 今回は以上です。

 次回は交差エントロピー誤差をバッチ処理に対応させましょう。
 
 ありがとうございました。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2