概要
書籍『ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装』4章のコードを参考に、損失関数 (loss function) として2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) を Python と Ruby で実装する。
計算処理では外部ライブラリを利用する。Python では NumPy を、Ruby では Numo::NArray を使用する。
環境構築が必要な場合はこちらを参照。
→ Python vs Ruby 『ゼロから作るDeep Learning』 1章 sin関数とcos関数のグラフ - Qiita
2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) の実装
Python
import numpy as np
# 2乗和誤差
def mean_squared_error(y, t):
# ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
return 0.5 * np.sum((y-t)**2)
# 交差エントロピー誤差
def cross_entropy_error(y, t):
delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
return -np.sum(t * np.log(y + delta))
# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
print(mean_squared_error(np.array(y1), np.array(t)))
print(mean_squared_error(np.array(y2), np.array(t)))
print(cross_entropy_error(np.array(y1), np.array(t)))
print(cross_entropy_error(np.array(y2), np.array(t)))
Ruby
require 'numo/narray'
# 2乗和誤差
def mean_squared_error(y, t)
# ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
return 0.5 * ((y-t)**2).sum
end
# 交差エントロピー誤差
def cross_entropy_error(y, t)
delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
return -(t * Numo::NMath.log(y + delta)).sum
end
# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
puts mean_squared_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts mean_squared_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))
実行結果
Python
0.0975
0.5975
0.510825457099
2.30258409299
Ruby
0.09750000000000003
0.5974999999999999
0.510825457099338
2.302584092994546
参考資料
- Python vs Ruby 『ゼロから作るDeep Learning』 まとめ - Qiita http://qiita.com/niwasawa/items/b8191f13d6dafbc2fede