0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

chainer3で偶数と奇数を分類するための初心者の備忘録

0
Posted at

教師データをどれだけ充実させるかとか
GPUを使うかとか、Trainerを使うとか、そういうのは気にせず、
まずは奇数と偶数を分類できるようにchainer3で単純に学習させてみました。

お題: 偶数と奇数の分類 by Chainer3

入力: 0 〜 9 の整数
出力: [1,0]:偶数、[0,1]:奇数

実験

下記のコードで実験してみました。

test.py
import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import Variable, optimizers

class MyModel(chainer.Chain):

    def __init__(self):
        super(MyModel, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(10, 10)
            self.l2 = L.Linear(10, 10)
            self.l3 = L.Linear(10, 2)

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

if __name__ == '__main__':

    model = MyModel()
    optimizer = optimizers.SGD()
    optimizer.setup(model)

    # バッチ処理用のデータを作成
    listX = []
    listY = []
    for i in range(10):
        l = np.zeros(10)
        l[i] = 1
        listX.append(l)
        listY.append(i % 2)

    x = Variable(np.array(listX, dtype=np.float32))
    y = Variable(np.array(listY, dtype=np.int32))

    # 学習
    for i in range(2000):
        model.zerograds()
        loss = F.softmax_cross_entropy(model(x), y)
        loss.backward()
        optimizer.update()

    # テスト
    for i in range(10):
        l = np.zeros(10)
        l[i] = 1
        x = Variable(np.array(l, dtype=np.float32).reshape(1,-1))
        print(F.softmax(model(x)))

結果

以下のように、0 〜 9 の入力に対し、偶数->奇数->偶数->奇数の順に数値が高くなりました。

$ python test.py 
variable([[ 0.96782762  0.03217238]])
variable([[ 0.02553093  0.97446907]])
variable([[ 0.95543504  0.04456493]])
variable([[ 0.10487143  0.89512861]])
variable([[ 0.97594351  0.02405647]])
variable([[ 0.03047855  0.96952146]])
variable([[ 0.91441429  0.08558577]])
variable([[ 0.19360393  0.80639613]])
variable([[ 0.98645091  0.01354905]])
variable([[ 0.01921912  0.9807809 ]])
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?