4
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Python vs Ruby 『ゼロから作るDeep Learning』 4章 損失関数 (loss function) の実装

Posted at

概要

書籍『ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装』4章のコードを参考に、損失関数 (loss function) として2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) を Python と Ruby で実装する。

計算処理では外部ライブラリを利用する。Python では NumPy を、Ruby では Numo::NArray を使用する。

環境構築が必要な場合はこちらを参照。
Python vs Ruby 『ゼロから作るDeep Learning』 1章 sin関数とcos関数のグラフ - Qiita

2乗和誤差 (mean squared error) と交差エントロピー誤差 (cross entropy error) の実装

Python

import numpy as np

# 2乗和誤差
def mean_squared_error(y, t):
  # ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
  return 0.5 * np.sum((y-t)**2)

# 交差エントロピー誤差
def cross_entropy_error(y, t):
  delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
  return -np.sum(t * np.log(y + delta))

# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
print(mean_squared_error(np.array(y1), np.array(t)))
print(mean_squared_error(np.array(y2), np.array(t)))
print(cross_entropy_error(np.array(y1), np.array(t)))
print(cross_entropy_error(np.array(y2), np.array(t)))

Ruby

require 'numo/narray'

# 2乗和誤差
def mean_squared_error(y, t)
  # ニューラルネットワークの出力と教師データの各要素の差の2乗、の総和
  return 0.5 * ((y-t)**2).sum
end

# 交差エントロピー誤差
def cross_entropy_error(y, t)
  delta = 1e-7 # マイナス無限大を発生させないように微小な値を追加する
  return -(t * Numo::NMath.log(y + delta)).sum
end

# テスト
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] # 正解が1, それ以外が0
y1 = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0] # 2の確率が最も高い場合(0.6)
y2 = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0] # 7の確率が最も高い場合(0.6)
puts mean_squared_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts mean_squared_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y1), Numo::DFloat.asarray(t))
puts cross_entropy_error(Numo::DFloat.asarray(y2), Numo::DFloat.asarray(t))

実行結果

Python

0.0975
0.5975
0.510825457099
2.30258409299

Ruby

0.09750000000000003
0.5974999999999999
0.510825457099338
2.302584092994546

参考資料

4
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?