LoginSignup
2
1

More than 5 years have passed since last update.

画像の分類 Pytorch tutorial

Posted at

画像の分類 Pytorch

TL;DR

TensorFlowは応用でやってる人には難しすぎるしkerasは凝った実装をしようとすると逆にめんどくさくなるという話を聞き、今流行ってそうなPytorchでも勉強するかという話です。Cyfar10の公式tutorialをGoogleColabで動かします。

Google ColaboratoryでPytorchを動かす設定など

#2019/03/14現在
!pip3 install torch torchvision
#各種ライブラリのimport
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

#deviceをGPUに設定、ランタイムをGPUに変更しておく
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#deviceを確認
print(device)
# cuda:0

データセットの読み込み

cyfar10
- 3*32*32
- plane,car, bird, cat, deer, dog, frog, horse, ship, truckの10クラスの画像

処理の手順
1. torchvision.dataset.CIFAR10を使ってcifar10をダウンロード
2. torch.utils.data.DataLoaderを使ってデータセットを作成
3. その際にtransform.Composeを使ってTensorに変換し[-1,1]間で正規化

transoform.Composeを使えばData Augmentationできそう

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

モデルの作成

普通にモデル作るだけ。
__init__でモデルを作ってforwardで順伝播の時の処理を書く。10クラス分類なので最後は10次元で出力。
GPU使いたいのでGPUにモデル送っとくことは忘れずに

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
#model to gpu
net.to(device)

トレーニング

損失関数 CrossEntropyLoss
最適化法 MomentumSGD (lr=0.001, momentum=0.9)
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
[1,  2000] loss: 2.139
[1,  4000] loss: 1.817
[1,  6000] loss: 1.641
[1,  8000] loss: 1.573
[1, 10000] loss: 1.533
[1, 12000] loss: 1.486
[2,  2000] loss: 1.405
[2,  4000] loss: 1.416
[2,  6000] loss: 1.375
[2,  8000] loss: 1.342
[2, 10000] loss: 1.317
[2, 12000] loss: 1.290
Finished Training

評価

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
Accuracy of the network on the 10000 test images: 55 %

何も工夫してないしこんなものですか、tutorialの値とも大して変わってないので

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1