3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchでXORを実装してみる

Last updated at Posted at 2020-11-04

#はじめに
Kerasでやりたいことをやろうとすると、結局tensorflowを使わざるを得ず、それならPyTorchの方がいいんじゃね?ということで早速、XORを実装してみた。

#環境

  • Python 3.6
  • pytorch 1.7.0

#ソース

import torch
import torch.nn as nn
import torch.optim as optim


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(2, 8)
        self.fc2 = torch.nn.Linear(8, 8)
        self.fc3 = torch.nn.Linear(8, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x


def main():

    import numpy as np
    x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y = np.array([[0], [1], [1], [0]])

    num_epochs = 10000

    # convert numpy array to tensor
    x_tensor = torch.from_numpy(x).float()
    y_tensor = torch.from_numpy(y).float()

    # crate instance
    net = Net()

    # set training mode
    net.train()

    # set training parameters
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = torch.nn.MSELoss()

    # start to train
    epoch_loss = []
    for epoch in range(num_epochs):
        print(epoch)
        # forward
        outputs = net(x_tensor)

        # calculate loss
        loss = criterion(outputs, y_tensor)

        # update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # save loss of this epoch
        epoch_loss.append(loss.data.numpy().tolist())

    print(net(torch.from_numpy(np.array([[0, 0]])).float()))
    print(net(torch.from_numpy(np.array([[1, 0]])).float()))
    print(net(torch.from_numpy(np.array([[0, 1]])).float()))
    print(net(torch.from_numpy(np.array([[1, 1]])).float()))

if __name__ == "__main__":
    main()

結果

tensor([[0.0511]], grad_fn=<SigmoidBackward>)
tensor([[0.9363]], grad_fn=<SigmoidBackward>)
tensor([[0.9498]], grad_fn=<SigmoidBackward>)
tensor([[0.0666]], grad_fn=<SigmoidBackward>)

お、いい感じだ。

#感想
まだ、ほんの触り程度だが、KerasやTensorflowに比べると、ブラックボックス感がなくPythonからシームレスに使える感じがすごくいい。
例えば、モデルの中にprint文を入れても, そのまま出力される。実行中の可視化なんかもすごくやりやすそう。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?