9
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch による MNISTデータセットの分類モデル作成

Last updated at Posted at 2020-01-23

環境構築

Windows で python3.6 + OpenCV3 + PyTorch による画像データの機械学習を行うための環境構築設定を以下に記載します。


1. WinPython のインストールと実行

今回は「WinPython64-3.6.7.0Qt5」を使用するため、以下の URL から「WinPython64-3.6.7.0Qt5.exe」をダウンロードしてください。
https://ja.osdn.net/projects/sfnet_winpython/releases/

ダウンロードしたら「C:\study」配下に配置して実行してください。
(フォルダがない場合は作成)

実行が完了したら「C:\study\WPy-3670」配下の「Jupyter Lab.exe」を実行して
「Notebook」の「Python 3」をクリックして下さい。


2. Numpy のアップグレード

OpenCV を使用する際に Numpy のバージョンが低いとエラーが出るため、Numpy をアップグレードします。
以下のコマンドを Jupyter Lab 上で実行してください。

!pip install numpy --upgrade

「Successfully installed ~」と表示されればアップグレード成功です。
(既にアップグレード済みの場合は「Requirement already~」と表示されます)
※ Jupyter Lab (Notebook) のセル上で UNIXコマンドを扱いたい場合は、頭に「!」を付けます。


3. OpenCV のインストール

OpenCV を WinPython にインストールします。

以下の URL から「opencv_python-3.4.3+contrib-cp36-cp36m-win_amd64.whl」をダウンロードしてください。
https://www.lfd.uci.edu/~gohlke/pythonlibs/#opencv

ダウンロードしてきたファイルを「C:\study\WPy-3670\notebooks」に配置して、以下のコマンドを実行してください。


!pip install opencv_python-3.4.3+contrib-cp36-cp36m-win_amd64.whl

インストール出来たらバージョンを確認してみます。

import cv2
print(cv2.__version__)

4. PyTorch のインストール

PyTorch を WinPython のインストールします。

以下のコマンドを実行して下さい。

!pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-win_amd64.whl

torchvicioon (PyTorch の画像向けパッケージ)もインストールします。
以下のコマンドを実行して下さい。

!pip install torchvision

PyTorch と torchvision もバージョンを確認します。

import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)

環境構築は以上です。


PyTorch による MNISTデータセットの分類モデル作成

MNIST というメジャーなデータセットを使用して、画像データの機械学習を行います。
(参考サイト: https://www.madopro.net/entry/pytorch_mnist)

【目的】

  1. 機械学習での画像の扱い方をざっくりと理解する
  2. PyTorch の使い方をざっくりと理解する
  3. ニューラルネットワーク、深層学習についてざっくり理解する

ニューラルネットワーク とは

深層学習 とは

MNIST とは

MNIST (Mixed National Institute of Standards and Technology database) とは
出書きの0~9までの数字画像のオープンデータセットで、機械学習(特にニューラルネットワーク)のチュートリアルとしてよく使用されています。
(http://yann.lecun.com/exdb/mnist/)

あと、ニューラルネットワークって書くと長いので以降は「NN」と表記します。

学習の流れ

  1. MNISTデータセットを取得して内容を確認・視覚化する
  2. PyTorch で NNのモデルを作成する
  3. 作成した NNモデルに MNISTデータを学習させる
  4. 学習させた NNモデルに手書きの画像を識別して貰う

1. MNIST データの取得

MNIST データセットを取得してきます。
PyTorch の torchvision には有名なデータセットや画像処理関数がまとめられてパッケージが存在するため、今回はそれを利用してデータを取得してきます。

C:\study\data」フォルダを作成した後に、以下を実行してください。

# 実行に影響のない警告を非表示
import warnings
warnings.filterwarnings('ignore')

# 必要な機能をインポート
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# # MNISTデータセットを取得する関数
def load_mnist_data(trn, dl, bsize, sf):
    mnist_data = MNIST('C:\study\data\mnist', train=trn, download=dl, transform=transforms.ToTensor())
    data_loader = DataLoader(mnist_data,
                             batch_size=bsize,
                             shuffle=sf)
    
    return data_loader
# MNISTデータを取得
train = True # 訓練用データを取得
dl_flag = True # 初回はデータをダウンロードするために True にする
batch_size = 4 # 取得してきたデータの件数
sf_flag = False # データをシャッフルするか否か。今回はシャッフル無し
data_loader = load_mnist_data(train, dl_flag, batch_size, sf_flag)

data_iter = iter(data_loader)
images, labels = data_iter.next()

# とりあえず先頭の内容を見てみる
print(images[0])
print(labels[0])
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
          0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
          0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
          0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
          0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
          0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
          0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
          0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
          0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
          0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
          0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
          0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
          0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
          0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
          0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
          0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
          0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
          0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
          0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]])
tensor(5)
images: 画像データ
labels: 画像データ示す数字(0~9)

が入っています。
ただ画像データがこのままだと面白くないので可視化します。


2. MNIST データの可視化

MINSTデータの画像データは、28行*28列 の行列で表現されています。
また、配列内の数値は色の濃淡を 0.000 ~ 1.000 で示しています。
(数値が高いほど濃く、低いほど淡い)
そのためグラフとして描画する事ができます。

# 可視化に Matplotlib というライブラリを取得
import matplotlib.pyplot as plt
%matplotlib inline

# matplotlib で1つ目のデータを可視化
for idx, _ in enumerate(labels):
    plt.figure(idx+1)    
    npimg = images[idx].numpy()
    npimg = npimg.reshape((28, 28))
    plt.imshow(npimg, cmap='Greens')
    plt.title('Label: {}'.format(labels[idx]))

2020-01-24 07.18.15 localhost 0b8351f5b7f1.png


3. ニューラルネットワークによる MNISTデータセットの学習

MNISTデータを可視化することで中身が確認出来たので、このデータを機械に学習して貰いましょう。

今回は比較的単純な「順伝播型ニューラルネットワーク(Feed Forward NN)」モデルを作成します。
「順伝播型」などの用語については今回は割愛するため、興味のある方は調べてみましょう。


3-1. モデルの作成

PyTorch を使用して、順伝播型ニューラルネットワークモデルを作成します。
モデルの設計は以下のようにします。

  • 入力層:28行*28列 784次元
  • 隠れ層:50次元
  • 出力層:10次元 -> 0~9の10種類の数字のため
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardNeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 50) # 入力層から隠れ層への設定
        self.l2 = nn.Linear(50, 10) # 隠れ層から出力層への設定
        
    def forward(self, x):
        # テンソルのリサイズ: (N, 1, 28, 28) --> (N, 784)
        # images のデータが 28行*28列のため、784次元データに変換
        x = x.view(-1, 28 * 28)
        x = self.l1(x)
        x = self.l2(x)
        return x
    
ffnn = FeedForwardNeuralNet()

3-2. コスト関数と学習方法を定義

ざっくり言うと、機械がどのようなロジックに基づいてデータを学習していくかを定義します。
詳しくは割愛します。
(参考: http://nnadl-ja.github.io/nnadl_site_ja/)

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# 最適化手法にSGDを選択
optimizer = optim.SGD(ffnn.parameters(), lr=0.01)

3-3. MNIST の訓練データを取得して学習させる

モデルの作成と学習方法の定義が出来たため、訓練用データを取得して学習させてみます。

# 訓練データの取得
train = True
trn_dl_flag = False
trn_batch_size = 4
# デフォルトのデータ順序が学習に与える影響を効力して、データをシャッフルして取得する
trn_sf_flag = True
train_loader = load_mnist_data(train, trn_dl_flag, trn_batch_size, trn_sf_flag)
# 学習実施
print('--------------------')
print('学習の進行による損失値の減少を表示')
print('--------------------')
# epoch(=訓練データを何回学習させるか)は3回
for epoch in range(3):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        
        # Variableに変換
        # 参照: https://www.hellocybernetics.tech/entry/2017/10/07/013937#torchautogradVariableの配列データを見る
        inputs, labels = Variable(inputs), Variable(labels)
        
        # 勾配情報をリセット
        optimizer.zero_grad()
        
        # 順伝播
        outputs = ffnn(inputs)
        
        # コスト関数を使ってロスを計算する
        loss = criterion(outputs, labels)
        
        # 逆伝播
        # 参照: https://qiita.com/43x2/items/50b55623c890564f1893
        loss.backward()
        
        # パラメータの更新
        optimizer.step()
        
        running_loss += loss.data[0]
        
        if i % 5000 == 4999:
            print('%d %d loss: %.3f' % (epoch + 1, i + 1, running_loss / 1000))
            running_loss = 0.0

print('Finished Training')
--------------------
学習の進行による損失値の減少を表示
--------------------
1 5000 loss: 2.727
1 10000 loss: 1.650
1 15000 loss: 1.618
2 5000 loss: 1.516
2 10000 loss: 1.488
2 15000 loss: 1.556
3 5000 loss: 1.430
3 10000 loss: 1.501
3 15000 loss: 1.444
Finished Training

学習が進むごとに「loss」というのが減少しています。
「loss」というのはざっくり言うと、機械が予測した結果と正解との誤差の事です。


3-4. テストデータで機械による学習の結果を検証

学習を行った機械による性能を評価するため、学習に使っていないデータを機械に予測させて精度を見てみます。

# テストデータを取得
train = False
tst_dl_flag = False
tst_batch_size = 4
tst_sf_flag = True
tst_loader = load_mnist_data(train, tst_dl_flag, tst_batch_size, tst_sf_flag)

import torch

correct = 0
total = 0
for data in tst_loader:
    inputs, labels = data
    outputs = ffnn(Variable(inputs))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy: {} / {} = {}'.format(correct, total, float(correct)/total))
Accuracy: 9182 / 10000 = 0.9182

精度は 92% くらいになります。
なかなかの高確率で機械は MNISTデータの識別が出来るようです。


3-5. 機械が予測したデータを可視化して確認

実際に機械が識別したデータを視覚化して見ます。

tst_iter = iter(tst_loader)
true_lst = []
false_lst = []
for _ in range(100):
    inputs, labels = tst_iter.next()
    outputs = ffnn(Variable(inputs))
    _, predicted = torch.max(outputs.data, 1)

    for idx in range(len(labels)):
        lst = [inputs[idx], labels[idx], predicted[idx]]
        if int(labels[idx]) == int(predicted[idx]):
            true_lst.append(lst)
        else:
            false_lst.append(lst)

# 予測が正解しているデータを表示
print('------------------')
print('機械の予測が当たっているデータ')
print('------------------')
for idx, tlst in enumerate(true_lst[:10]):
    plt.figure(idx+1)   
    plt.imshow(tlst[0].numpy().reshape(28, 28), cmap='Greens')
    plt.title('True: {}, Estim: {}'.format(tlst[1], tlst[2]))

機械の予測が当たっているデータ
2020-01-24 07.16.48 localhost f58b23919b73.png

# 予測が不正解のデータを表示
print('------------------')
print('機械の予測が外れているデータ')
print('------------------')
for idx, flst in enumerate(false_lst[:10]):
    plt.figure(idx+1)   
    plt.imshow(flst[0].numpy().reshape(28, 28), cmap='Greens')
    plt.title('True: {}, Estim: {}'.format(flst[1], flst[2]))

機械の予測が外れているデータ
2020-01-24 07.16.18 localhost 2398ffebfaa0.png

以上で 機械学習(ニューラルネットワーク)による MNISTデータの学習と予測、検証が出来ました。

9
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?