Help us understand the problem. What is going on with this article?


More than 3 years have passed since the last update.

import numpy as np
import matplotlib.pyplot as plt
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L

class MyChain(Chain):
    """Two-layer perceptron mapping one input value to one output value.

    The network is ``Linear(1, 10) -> sigmoid -> Linear(10, 1)``.
    """

    def __init__(self):
        # Register both fully connected layers as child links so their
        # parameters are visible to optimizers and serializers.
        super(MyChain, self).__init__(
            l1=L.Linear(1, 10),
            l2=L.Linear(10, 1),
        )

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: Input passed to the first linear layer (``l1``), which was
                constructed with an input size of 1.

        Returns:
            The output of the second linear layer (``l2``).
        """
        h = F.sigmoid(self.l1(x))
        out = self.l2(h)
        return out

class MyModel_fitting(Chain):
    """Regression wrapper that attaches a halved mean-squared-error loss.

    The wrapped network is registered as the child link ``predictor``.
    """

    def __init__(self, predictor):
        super(MyModel_fitting, self).__init__(predictor=predictor)

    def __call__(self, x, t):
        """Return half of the mean squared error between ``predictor(x)`` and ``t``."""
        return F.mean_squared_error(self.predictor(x), t) * 0.5

    def predict(self, x):
        """Return the raw output of the wrapped predictor for ``x``."""
        return self.predictor(x)

if __name__ == "__main__":
    # --- Training data: 1000 samples of y = x^2 on [0, 1] ---
    train_x = np.linspace(0.0, 1.0, num=1000, dtype=np.float32)
    train_y = train_x * train_x
    n_epoch = 1000
    n_batch = 100  # number of single-sample updates performed per epoch

    model = MyModel_fitting(MyChain())
    # To resume from a previously saved model, uncomment the next line.
    # serializers.load_hdf5('MyChain.model', model)
    optimizer = optimizers.Adam()
    # The optimizer has to be bound to the model; without this call it has
    # no parameters to update.
    optimizer.setup(model)

    for epoch in range(n_epoch):
        # Single-argument print(...) is valid on both Python 2 and Python 3.
        print('epoch :  {}'.format(epoch))
        # Draw a fresh random subset of the samples for this epoch.
        indexes = np.random.permutation(np.size(train_x))
        for i in indexes[:n_batch]:
            x = Variable(np.array([[train_x[i]]], dtype=np.float32))
            t = Variable(np.array([[train_y[i]]], dtype=np.float32))
            # Gradients accumulate across backward() calls, so reset them
            # before every step.
            model.cleargrads()
            loss = model(x, t)
            loss.backward()
            optimizer.update()
        # Loss of the last sample processed in this epoch.
        print('loss :  {}'.format(loss.data))

    # Evaluate the fitted curve on all inputs in one batched forward pass;
    # the network expects a column of shape (n_samples, 1).
    learned_y = model.predict(Variable(train_x.reshape(-1, 1))).data[:, 0]

    plt.plot(train_x, train_y, 'o')
    plt.plot(train_x, learned_y)
    # Save before showing the figure so the model is written even if the
    # plot window is left open or the process is interrupted there.
    serializers.save_hdf5('MyChain.model', model)
    plt.show()


Why not register and get more from Qiita?
  1. We will deliver articles that match your interests
    By following users and tags, you can keep up with information across the technical fields you are interested in
  2. You can efficiently read useful information later
    By "stocking" the articles you like, you can find them again right away