4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

15分で3層NNをフルスクラッチしたかった(無理)

Posted at

動機

という@odashi_tさんのツイートを真に受けたので実際にやってみました。15分で出来たら凄いということを達成したかったけど、結果的に3時間くらいかかったので私はニューラルネットの理解が足りないなとつくづく感じました。

やったこと

3層NNを作るために必要なクラスを一通り実装し、それに基づいて学習ループを書いた。Chainerしか知らないのでどうしてもChainer風になってしまった。

タスクは「3入力1出力の足し算」です。訓練データは10個を使って学習し、最後にテストデータ4個で結果を見ている。どのテストも正解から0.2くらいズレている。

出力結果
epoch 0: y = 6.944188825159266
epoch 1: y = 6.838310105649193
epoch 2: y = 6.743423697756349
...
epoch 96: y = 5.853637753722643
epoch 97: y = 5.853626668028286
epoch 98: y = 5.853616573368178
epoch 99: y = 5.853607381144707
test:
8 7.896960108841725
3 2.7903926894351105
21 21.174035399298916
22 22.195348883180234

所感

データ少ないし精度は微妙だが、まあまあそれっぽい結果が出ている。

ニューラルネット楽しいよの気持ちを再認識できたので良かった。

以上。

コード

3LNN.py
import random
napier = 2.71828
lrate = 0.0001
epoch = 100

class node():
  def __init__(self, edge):
    self.b = 0
    self.edge = edge
    
  def forward(self, x):
    if self.edge == None:
      return self.b
    return self.edge.w * x + self.b
  
  def backward(self, dx):
    self.b += lrate * dx
    if self.edge == None:
      return
    self.edge.w += lrate * dx
  
class edge():
  def __init__(self):
    self.w = random.random()

class layer():
  def __init__(self, node_num):
    self.nodes = []
    for _ in range(node_num):
      self.nodes.append(node(edge()))
      
  def __call__(self, x):
    self.x = x
    ans = []
    for n in self.nodes:
      a = 0
      for e in x:
        a += n.forward(e)
      ans.append(a)
    return ans
  
  def update(self, dx):
    for n in self.nodes:
      n.backward(dx)
      
def ewise(x, func):
  for e in x:
    e = func(e)
  return x

def sigmoid(x):
  return 1 / (1 + napier**(-x))
  
class LNN():
  def __init__(self, inputs, hidden, outputs):
    self.l1 = layer(inputs)
    self.l2 = layer(hidden)
    self.l3 = layer(outputs)
    
  def forward(self, x):
    self.x = x
    h = ewise(self.l1(x), sigmoid)
    h = ewise(self.l2(h), sigmoid)
    h = ewise(self.l3(h), sigmoid)
    return h[0]
    
  def backward(self, y, t):
    dx = t-y
    self.l3.update(dx)
    self.l2.update(dx)
    self.l1.update(dx)


# main #
data = [[1, 2, 3], [4, 4, 5], [5, 1, 2], [3, 5, 7], [2, 8, 9],   # add 3 elements
        [2, 3, 6], [3, 4, 5], [6, 7, 8], [1, 9, 7], [2, 2, 2]]
ans = [6, 13, 8, 15, 19, 11, 12, 21, 17, 6]

def train(data, ans, model):
  for x, t in zip(data, ans):
    y = model.forward(x)
    model.backward(y, t)
  return model, y
    
def trainer(epoch):
  model = LNN(3, 3, 1)
  for i in range(epoch):
    model, y = train(data, ans, model)
    print('epoch {}: y = {}'.format(i, y))
  return model
    
model = trainer(epoch)

print('test:')
test = [[[3, 4, 1], 8], [[1, 1, 1], 3], [[5, 8, 8], 21], [[10, 4, 8], 22]]
for x in test:
  print(x[1], model.forward(x[0]))

参考

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?