LoginSignup
0
1

More than 5 years have passed since last update.

chainerで九九

Posted at

概要

Chainerの多層パーセプトロンに九九（1×1〜9×9の掛け算）を学習させてみた。入力（2つの数）と出力（積）はどちらもビット列にエンコードしている。

結果

214         0.0205839
215         0.0216847
216         0.0217508
217         0.0217022


493         0.00358729
494         0.00372512
495         0.00368116
496         0.00367013
497         0.0035804
498         0.00389284
499         0.00329972
500         0.00354485
      1   2   3   4   5   6   7   8   9
  1   1   2   3   4   5   6   7   8   9
  2   2   4   6   8  10  12  14  16  18
  3   3   6   9  12  15  18  21  24  27
  4   4   8  12  16  20  24  28  32  36
  5   5  10  15  20  25  30  35  40  45
  6   6  12  18  24  30  36  42  48  54
  7   7  14  21  28  35  42  49  56  63
  8   8  16  24  32  40  48  56  64  72
  9   9  18  27  36  45  54  63  72  81

サンプルコード

import numpy as np
from chainer import Variable, report, datasets, iterators
from chainer import optimizers
from chainer import Chain
from chainer import training
from chainer.training import extensions
import chainer.functions as F
import chainer.links as L
import chainer

def in_encode(i, j):
    """Encode an operand pair as a vector of 8 bits.

    The two operands are combined into the single integer ``j * 16 + i``;
    the 8 lowest bits of that integer are returned as a NumPy array,
    least significant bit first.
    """
    packed = j * 16 + i
    bits = []
    for shift in range(8):
        bits.append((packed >> shift) & 1)
    return np.array(bits)

def out_encode(i, j):
    """Encode the product ``i * j`` as a vector of 7 bits.

    Returns a NumPy array holding the 7 lowest bits of the product,
    least significant bit first.
    """
    product = j * i
    bits = []
    for shift in range(7):
        bits.append((product >> shift) & 1)
    return np.array(bits)

def decode(p):
    """Turn 7 per-bit scores back into an integer.

    ``p[0]`` .. ``p[6]`` are read as bits 0..6 (least significant first).
    A bit counts as set only when its score is strictly greater than 0.5.
    Any elements beyond index 6 are ignored.
    """
    value = 0
    for position in range(7):
        if p[position] > 0.5:
            value += 1 << position
    return value

class MLP(Chain):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, n_in, n_units, n_out):
        """Register the layers ``l1`` (n_in -> n_units) and ``l2`` (n_units -> n_out)."""
        hidden = L.Linear(n_in, n_units)
        output = L.Linear(n_units, n_out)
        super(MLP, self).__init__(l1=hidden, l2=output)

    def __call__(self, x):
        """Apply ``l1``, a ReLU, then ``l2`` to ``x`` and return the result."""
        activated = F.relu(self.l1(x))
        return self.l2(activated)

def make_data():
    """Build the training set for every operand pair with i, j in 1..9.

    Inputs come from ``in_encode`` and targets from ``out_encode``; both
    are converted to float32 arrays and wrapped in a ``TupleDataset``.
    Pairs are ordered with ``i`` as the outer index and ``j`` as the inner.
    """
    pairs = [(i, j) for i in range(1, 10) for j in range(1, 10)]
    inputs = np.array([in_encode(i, j) for i, j in pairs], dtype=np.float32)
    targets = np.array([out_encode(i, j) for i, j in pairs], dtype=np.float32)
    return datasets.TupleDataset(inputs, targets)

def main():
    """Train the network on the 1..9 multiplication table and print its answers.

    The model is an ``MLP(8, 96, 7)`` wrapped in ``L.Classifier`` with mean
    squared error as the loss, optimised with Adam for 500 epochs in
    mini-batches of 10. After training, every operand pair is fed through
    the predictor, the outputs are decoded back to integers and the
    resulting table is printed.
    """
    train = make_data()
    n_epochs = 500
    batch_size = 10

    # The targets are bit vectors, so the loss is MSE and the classifier's
    # accuracy computation is switched off.
    net = L.Classifier(MLP(8, 96, 7), lossfun=F.mean_squared_error)
    net.compute_accuracy = False
    adam = optimizers.Adam()
    adam.setup(net)

    batches = iterators.SerialIterator(train, batch_size)
    updater = training.StandardUpdater(batches, adam)
    trainer = training.Trainer(updater, (n_epochs, 'epoch'), out='result')
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss']))
    trainer.run()

    # Re-encode all pairs (i outer, j inner) and run them through the
    # trained predictor.
    operands = [(i, j) for i in range(1, 10) for j in range(1, 10)]
    x = np.array([in_encode(i, j) for i, j in operands], dtype=np.float32)
    y = net.predictor(x).data

    # Header row, then one row per j; the prediction for (i, j) sits at
    # index (i - 1) * 9 + (j - 1) because of the ordering above.
    lines = ['    ' + ''.join('%3d ' % i for i in range(1, 10))]
    for j in range(1, 10):
        answers = [decode(y[(i - 1) * 9 + (j - 1)]) for i in range(1, 10)]
        lines.append('%3d ' % j + ''.join('%3d ' % k for k in answers))
    print('\n'.join(lines) + '\n')

# Run the training and table printout only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()


以上。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1