はじめに
書籍「機械学習と深層学習」を購入して勉強していたら、C言語で書いてあるプログラムが簡単そうだったので、Rubyで書き換えてみることにしました。
ちなみに、書籍で紹介されているソースコードは以下のURLからダウンロード可能です。
bp1.c
第4章 ニューラルネット
4.2 バックプロパゲーションによるニューラルネットの学習
上記のbp1.cをRubyで書き換えてみました。
なお、完全に同一ではありませんので悪しからず。
また、内容の詳細についてはぜひ書籍をご購入のうえ確認して頂ければと思います。
bp1.rb
# ライブラリの読込
include Math
# 定数の定義
$inputno = 3
$hiddenno = 3
$alpha = 10
$seed = 65535
$maxinputno = 100
$bignum = 100
$limit = 0.001
$randomseed = 42
# 乱数の初期化
$rnd = Random.new($randomseed)
# 学習データの読み込み
def getdata()
e = [
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 0.0, 1.0],
[1.0, 0.0, 1.0, 1.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0]
]
end
# 中間層の重み学習
def hlearn(wh, wo, hi, e, o)
dj = 0.0
(0..$hiddenno - 1).each do |j|
dj = hi[j] * (1.0 - hi[j]) * wo[j] * (e[$inputno] - o) * o * (1.0 - o)
(0..$inputno - 1).each do |i|
wh[j][i] += $alpha * e[i] * dj
end
wh[j][$inputno] += $alpha * -1.0 * dj
end
return wh
end
# 出力層の重み学習
def olearn(wo, hi, e, o)
d = (e[$inputno] - o) * o * (1.0 - o)
(0..$hiddenno - 1).each do |i|
wo[i] += $alpha * hi[i] * d
end
wo[$hiddenno] += $alpha * -1.0 * d
return wo
end
# 順方向の計算
def forward(wh, wo, hi, e)
(0..$hiddenno - 1).each do |i|
u = 0.0
(0..$inputno - 1).each do |j|
u += e[j] * wh[i][j]
end
u -= wh[i][$inputno]
hi[i] = f(u)
end
o = 0.0
(0..$hiddenno - 1).each do |i|
o += hi[i] * wo[i]
end
o -= wo[$hiddenno]
return f(o)
end
# 結果の出力
def printout(wh, wo)
(0..$hiddenno).each do |i|
(0..$inputno).each do |j|
printf("%lf ",wh[i][j])
end
end
printf("\n")
(0..$hiddenno+1).each do |i|
printf("%lf ",wo[i])
end
printf("\n")
end
# 中間層の重みの初期化
def initwh()
wh = Array.new($hiddenno).map{
Array.new($inputno + 1)
}
(0..$hiddenno - 1).each do |i|
(0..$inputno).each do |j|
wh[i][j] = $rnd.rand(-1.0..1.0)
end
end
return wh
end
# 出力層の重みの初期化
def initwo()
wo = Array.new($hiddenno + 1)
(0..$hiddenno).each do |i|
wo[i] = $rnd.rand(-1.0..1.0)
end
return wo
end
# 伝達関数
def f(u)
return 1.0 / (1.0 + exp(-u))
end
# メイン
# 重みの初期化
wh = initwh()
wo = initwo()
hi = Array.new($hiddenno + 1).map{ 0.0 }
# 学習データの取得
e = getdata()
n_of_e = e.length
puts "学習データの個数: #{n_of_e}"
err = 10.0
cnt = 0
# 学習
while err > $limit && cnt < 200 do
err = 0.0
(0..n_of_e - 1).each do |j|
o = forward(wh, wo, hi, e[j])
wo = olearn(wo, hi, e[j], o)
wh = hlearn(wh, wo, hi, e[j], o)
err += (o - e[j][$inputno]) * (o - e[j][$inputno])
end
cnt += 1
printf("%d\t%f\t\n", cnt, err)
end
# 学習データに対する出力
(0..n_of_e - 1).each do |i|
(0..$inputno).each do |j|
printf("%f ", e[i][j])
end
o = forward(wh, wo, hi, e[i])
printf("%f\n", o)
end
実行
$ ruby bp1.rb
学習データの個数: 8
1 2.048743
2 3.970918
3 3.262440
4 3.788325
5 3.332874
6 2.174208
7 1.978128
8 2.199647
9 2.244221
10 1.826786
11 3.785470
12 4.184682
13 0.545359
14 1.014423
15 0.868709
16 1.920221
17 1.113295
18 0.005492
19 0.004387
20 0.003945
21 0.003669
22 0.003465
23 0.003301
24 0.003161
25 0.003040
26 0.002933
27 0.002836
28 0.002749
29 0.002669
30 0.002596
31 0.002528
32 0.002466
33 0.002407
34 0.002353
35 0.002301
36 0.002253
37 0.002208
38 0.002165
39 0.002124
40 0.002086
41 0.002049
42 0.002014
43 0.001981
44 0.001949
45 0.001918
46 0.001889
47 0.001861
48 0.001834
49 0.001808
50 0.001783
51 0.001759
52 0.001735
53 0.001713
54 0.001691
55 0.001670
56 0.001650
57 0.001630
58 0.001611
59 0.001592
60 0.001574
61 0.001556
62 0.001539
63 0.001523
64 0.001507
65 0.001491
66 0.001476
67 0.001461
68 0.001446
69 0.001432
70 0.001418
71 0.001405
72 0.001391
73 0.001378
74 0.001366
75 0.001353
76 0.001341
77 0.001330
78 0.001318
79 0.001307
80 0.001295
81 0.001285
82 0.001274
83 0.001264
84 0.001253
85 0.001243
86 0.001233
87 0.001224
88 0.001214
89 0.001205
90 0.001196
91 0.001187
92 0.001178
93 0.001169
94 0.001161
95 0.001152
96 0.001144
97 0.001136
98 0.001128
99 0.001120
100 0.001112
101 0.001105
102 0.001097
103 0.001090
104 0.001083
105 0.001075
106 0.001068
107 0.001061
108 0.001055
109 0.001048
110 0.001041
111 0.001035
112 0.001028
113 0.001022
114 0.001016
115 0.001009
116 0.001003
117 0.000997
1.000000 1.000000 1.000000 1.000000 0.993934
1.000000 1.000000 0.000000 1.000000 0.985012
1.000000 0.000000 1.000000 1.000000 0.988907
1.000000 0.000000 0.000000 0.000000 0.006741
0.000000 1.000000 1.000000 1.000000 0.987935
0.000000 1.000000 0.000000 0.000000 0.012637
0.000000 0.000000 1.000000 0.000000 0.015771
0.000000 0.000000 0.000000 0.000000 0.001642
できた!