Help us understand the problem. What is going on with this article?

Python vs Ruby 『ゼロから作るDeep Learning』 4章 損失関数 (loss function) の実装

More than 3 years have passed since last update.

概要

書籍『ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装』4章のコードを参考に、損失関数 (loss function) として2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) を Python と Ruby で実装する。

計算処理では外部ライブラリを利用する。Python では NumPy を、Ruby では Numo::NArray を使用する。

環境構築が必要な場合はこちらを参照。
Python vs Ruby 『ゼロから作るDeep Learning』 1章 sin関数とcos関数のグラフ - Qiita

2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) の実装

Python

import numpy as np

# 2乗和誤差
def mean_squared_error(y, t):
  # ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
  return 0.5 * np.sum((y-t)**2)

# 交差エントロピー誤差
def cross_entropy_error(y, t):
  delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
  return -np.sum(t * np.log(y + delta))

# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
print(mean_squared_error(np.array(y1), np.array(t)))
print(mean_squared_error(np.array(y2), np.array(t)))
print(cross_entropy_error(np.array(y1), np.array(t)))
print(cross_entropy_error(np.array(y2), np.array(t)))

Ruby

require 'numo/narray'

# 2乗和誤差
def mean_squared_error(y, t)
  # ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
  return 0.5 * ((y-t)**2).sum
end

# 交差エントロピー誤差
def cross_entropy_error(y, t)
  delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
  return -(t * Numo::NMath.log(y + delta)).sum
end

# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
puts mean_squared_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts mean_squared_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))

実行結果

Python

0.0975
0.5975
0.510825457099
2.30258409299

Ruby

0.09750000000000003
0.5974999999999999
0.510825457099338
2.302584092994546

参考資料

niwasawa
迷子になりがちな地図・位置情報系プログラマ。
http://niwasawa.hatenablog.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away