Posted at

Python vs Ruby 『ゼロから作るDeep Learning』 4章 損失関数 (loss function) の実装

More than 1 year has passed since last update.


概要

書籍『ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装』4章のコードを参考に、損失関数 (loss function) として2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) を Python と Ruby で実装する。

計算処理では外部ライブラリを利用する。Python では NumPy を、Ruby では Numo::NArray を使用する。

環境構築が必要な場合はこちらを参照。

Python vs Ruby 『ゼロから作るDeep Learning』 1章 sin関数とcos関数のグラフ - Qiita


2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) の実装


Python

import numpy as np

# 2乗和誤差
def mean_squared_error(y, t):
# ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
return 0.5 * np.sum((y-t)**2)

# 交差エントロピー誤差
def cross_entropy_error(y, t):
delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
return -np.sum(t * np.log(y + delta))

# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
print(mean_squared_error(np.array(y1), np.array(t)))
print(mean_squared_error(np.array(y2), np.array(t)))
print(cross_entropy_error(np.array(y1), np.array(t)))
print(cross_entropy_error(np.array(y2), np.array(t)))


Ruby

require 'numo/narray'

# 2乗和誤差
def mean_squared_error(y, t)
# ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
return 0.5 * ((y-t)**2).sum
end

# 交差エントロピー誤差
def cross_entropy_error(y, t)
delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
return -(t * Numo::NMath.log(y + delta)).sum
end

# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
puts mean_squared_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts mean_squared_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))


実行結果


Python

0.0975

0.5975
0.510825457099
2.30258409299


Ruby

0.09750000000000003

0.5974999999999999
0.510825457099338
2.302584092994546


参考資料