5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

CIFAR-10の分類を実質60行で実装in PyTorch

Last updated at Posted at 2020-02-08

#概要

まだディープラーニングを実装したことがなく、とりあえずディープラーニングを実装してみたい!という方のための記事です。

今回は、画像データの読み込みから訓練誤差・汎化誤差を簡単なグラフで出力するところまでを実質60行で実装しました。

「実質」というのは、コードを見やすくするための改行やコメント文を考慮しない場合であり、実際のコードは100行ぐらいになります。

記事の最後にコードの全文を載せておきます

#処理の流れ

大雑把に言うと、次のような流れになってます。

  • データの準備(cifar-10を使います)
  • モデル定義(resnet50を使います)
  • 損失関数・最適化手法の定義
  • 訓練・推論
  • 結果出力・モデル保存など

では、ソースコードを見てみましょう。

##importするもの

まずは今回用いるライブラリ等を先にimportしておきましょう。

test.py
import torch
import numpy as np
import torch.nn as nn
from torch import optim
import torch.nn.init as init
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset,DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt

これだけで10行もあるんですね(笑)

##データの準備

画像データはにはCIFAR-10という有名なデータセットを使います。

10というのは10クラスあるということで、50000枚(各クラス5000枚)の訓練画像と10000枚(各クラス1000枚)のテスト画像が用意されています。

以下のようにしてCIFAR-10を読み込むことができます。

test.py
#画像の読み込み
batch_size = 100

train_data = dsets.CIFAR10(root='./tmp/cifar-10', train=True, download=False, transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)]))

train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)

test_data = dsets.CIFAR10(root='./tmp/cifar-10', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]))

test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)

初めての場合はdownload=Trueとしてください。CIFAR-10のデータセットがダウンロードされます。

transform=transforms.Compose([...])と長いところがありますが、ここは画像データに色々と加工を施す項目です。

  • RandomHorizonalFlip : 画像を左右反転させる
  • ToTensor : テンソル変換
  • Normalize : 標準化
  • RandomErasing : データの一部分にノイズを付加する

と言った感じです。
計算しやすくしたり、Data Augmentation(データの水増し)により過学習を防ぐといった効果があります。

このほかにもtransformsは色々ありますので、こちらを参照ください。

##GPUが使えるか確認

モデル定義の前に、GPUを使って計算できるかどうか確認します。
GPUが使えた方が何倍も速くなるのでおすすめです。

test.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

GPUが使えたらcuda、CPUしか使えなかったらcpuと出力されます。

##モデル定義

ディープラーニングで肝心のモデルですが、今は自分でモデルを考えなくても優秀なモデルを扱うことができます。

今回はResnetというモデルを使用します。

test.py
model_ft = models.resnet50(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)
net = model_ft.to(device)

models.resnet50(pretrained=True)により、Resnetの訓練済みのモデルを使用することができるようになります。簡単ですねえ・・・。

ちなみに、pretrained=Falseとすると、まだ学習してないResnetを使用できますが、学習時間が長いのでTrueにすることをお勧めします。

2行目は、出力層を10にしています。40とか100とかにもできます。

3行目は、モデルをnetという変数に代入しています。.to(device)では、GPUが使えたらGPUで計算しますよーってやつです(雑)

##損失関数・最適化手法の定義

モデルが学習するためには、損失関数で正解との誤差を出してあげないといけません。
そして、その誤差を小さくするための手法が最適化です。

ここでは、損失関数にクロスエントロピーを、最適化アルゴリズムにSGDというものを用います。

test.py
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9,weight_decay=0.00005)

lrは学習率であり、これは大きすぎても小さすぎても良くありません。ここでは0.01としました。

weight_decayとは「重み減衰」と言い、これも過学習を防ぐ正則化手法の1つです。

##学習・推論

準備が整ったので、学習・推論フェーズに入っていきます。

test.py
loss,epoch_loss,count = 0,0,0
acc_list = []
loss_list = []
for i in range(50):
  
  #ここから学習
  net.train()
  
  for j,data in enumerate(train_loader,0):
    optimizer.zero_grad()

    #1:訓練データを読み込む
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #2:計算する
    outputs = net(inputs)

    #3:誤差を求める
    loss = criterion(outputs,labels)

    #4:誤差から学習する
    loss.backward()
    optimizer.step()

    epoch_loss += loss
    count += 1
    print('%d: %.3f'%(j+1,loss))

  print('%depoch:mean_loss=%.3f\n'%(i+1,epoch_loss/count))
  loss_list.append(epoch_loss/count)

  epoch_loss = 0
  count = 0
  correct = 0
  total = 0
  accuracy = 0.0

  #ここから推論
  net.eval()
 
  for j,data in enumerate(test_loader,0):

    #テストデータを用意する
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #計算する
    outputs = net(inputs)

    #予測値を求める
    _,predicted = torch.max(outputs.data,1)

    #精度を計算する
    correct += (predicted == labels).sum()
    total += batch_size

  accuracy = 100.*correct / total
  acc_list.append(accuracy)

  print('epoch:%d Accuracy(%d/%d):%f'%(i+1,correct,total,accuracy))
  torch.save(net.state_dict(),'Weight'+str(i+1))

少し長いですね。

net.train()からnet.eval()の手前までが学習フェーズで、net.eval()以降が推論フェーズです。

###学習フェーズ

上記のコードにおいて、学習フェーズは

  1. 訓練データを読み込む
  2. 計算する
  3. 予測した値と実際の値の誤差を求める
  4. 3の誤差から学習する

という流れになっています。

for文の1回のループでbatch_size(←「データの準備」で定義した)分だけ画像が読み込まれます。すべての訓練画像が読み込まれたらfor文は終了します。

(例)学習データが50000枚あり、バッチサイズが100であれば、ループ回数は500回

###推論フェーズ

上記のコードにおいて、推論フェーズは

  1. テストデータを読み込む
  2. 計算する
  3. 予測値を求める
  4. 精度(正答率)を求める

という流れになっています。

1と2は学習時と同じですね。

3では予測値を求めています。「この画像はクラス5だ!」みたいな感じです。
実は、2の計算で求めた値は読み込んだ画像が各クラスに分類される確率なんです。

例えば、1枚の画像を読み込んで、[0.01,0.04,0.95]と出力されたら、

クラス1である確率は0.01(1%)
クラス2である確率は0.04(4%)
クラス3である確率は0.95(95%)

ということになります。

そしてこの場合、予測値はクラス3ということになります。

torch.save(net.state_dict(),'Weight'+str(i+1))は、学習した重みを保存することができます。

##グラフ出力

学習・推論フェーズにacc_list, loss_listというリストが定義されていましたが、これは各エポックごとの訓練誤差と精度を格納するリストです。

これをグラフ化するには以下のようにします。

test.py
plt.plot(acc_list)
plt.show(acc_list)
plt.plot(loss_list)
plt.show(loss_list)

一番簡単なグラフ出力方法です。

ちなみに、このコードを実行した際に出力されたグラフは次のようになります。
精度
acc_short.png
訓練誤差
loss_short.png

8,9エポックあたりに精度が一瞬ガクッと下がってますね。
精度は88%弱。

#最後に

今回はディープラーニングの一連の流れを60行で実装するという記事でした。

もちろん精度を上げるにはもっと工夫が必要ですが、とりあえず基盤を実装したいという方は是非参考にしてみてください。

最後に、全体のコードを載せておきます。

test.py
import torch
import numpy as np
import torch.nn as nn
from torch import optim
import torch.nn.init as init
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset,DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt

#画像の読み込み
batch_size = 100
train_data = dsets.CIFAR10(root='./tmp/cifar-10', train=True, download=False, transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)]))
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_data = dsets.CIFAR10(root='./tmp/cifar-10', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]))
test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model_ft = models.resnet50(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)
net = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9,weight_decay=0.00005)

loss, epoch_loss, count = 0, 0, 0
acc_list = []
loss_list = []

#訓練・推論
for i in range(50):
  
  #ここから学習
  net.train()
  
  for j,data in enumerate(train_loader,0):
    optimizer.zero_grad()

    #1:訓練データを読み込む
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #2:計算する
    outputs = net(inputs)

    #3:誤差を求める
    loss = criterion(outputs,labels)

    #4:誤差から学習する
    loss.backward()
    optimizer.step()

    epoch_loss += loss
    count += 1
    print('%d: %.3f'%(j+1,loss))

  print('%depoch:mean_loss=%.3f\n'%(i+1,epoch_loss/count))
  loss_list.append(epoch_loss/count)

  epoch_loss, count = 0, 0
  correct,total = 0, 0
  accuracy = 0.0

  #ここから推論
  net.eval()
 
  for j,data in enumerate(test_loader,0):

    #テストデータを用意する
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #計算する
    outputs = net(inputs)

    #予測値を求める
    _,predicted = torch.max(outputs.data,1)

    #精度を計算する
    correct += (predicted == labels).sum()
    total += batch_size

  accuracy = 100.*correct / total
  acc_list.append(accuracy)

  print('epoch:%d Accuracy(%d/%d):%f'%(i+1,correct,total,accuracy))
  torch.save(net.state_dict(),'Weight'+str(i+1))

plt.plot(acc_list)
plt.show(acc_list)
plt.plot(loss_list)
plt.show(loss_list)
5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?