0
2

More than 3 years have passed since last update.

Pytorch で MNIST をやってみた時のメモ

Last updated at Posted at 2020-05-06

はじめに

昨今の機械学習や Deep Learning の話題に少しでもついていけるように勉強するために、PyTorch で MNIST をやってみました。
この記事はメモなので、Pytorch で MNIST をやるために必要な知識をすべて網羅するものではなく、また、ちょっとわき道に逸れたような内容もあります。

参考ページ

これ以外に参考にしたページは、本文の関連する箇所にリンクを置いていきます。

目標

基本的には下記の1つ目の記事にあるコードを写経をしつつ理解することが目標です。
ただし、写経だけだと寂しいので、ネットワークの構成や損失関数などを PyTorch 公式サンプルのものに差し替えました。

  1. PyTorchでMNIST by @fukuit :: https://qiita.com/fukuit/items/215ef75113d97560e599
  2. Basic MNIST Example (公式のサンプルコード) :: https://github.com/pytorch/examples/blob/master/mnist/main.py

書いたコードは https://github.com/fukai-t/pytorch-mnist に置いています。

環境構築

今回は Docker コンテナで環境を作りました。
- python:3 をベースにし、 pip で torchtorchvision をインストール。
- torch は pytorch 本体、 torchvision は機械学習用の画像や動画のデータセットを扱うためのライブラリ(だぶん)?
- Dockerfile は https://github.com/fukai-t/pytorch-mnist/blob/master/Dockerfile に置いています。

画像の取得

MNIST の画像は torchvision を使ってダウンロードします。

データの前処理を理解する

データの前処理で出てくるコンポーネントに関しては,以下の説明がわかりやすかったので引用させていただきます。
引用元: https://qiita.com/takurooo/items/e4c91c5d78059f92e76d#transformsdatasetdataloader%E3%81%AE%E5%BD%B9%E5%89%B2

- transforms
  - データの前処理を担当するモジュール
- Dataset
  - データとそれに対応するラベルを1組返すモジュール
  - データを返すときにtransformsを使って前処理したものを返す。
- DataLoader
  - データセットからデータをバッチサイズに固めて返すモジュール

DatasetDataLoader はなんとなく何するかイメージがつきやすいのですが, transforms はちょっとイメージが湧きづらかったのでもう少し調べてみました。

参考ページ: (Official) TORCHVISION.TRANSFORMS https://pytorch.org/docs/stable/torchvision/transforms.html

Normalize (torchvision.transforms.Normalize())について
- 日本語で言うところの標準化という処理になるらしい。
- ただし、標準化という言葉が指すものは文脈によってまちまちらしい(英語圏でも同じ)
- torchvision.transforms.Normalize() の標準化は平均を0, 分散と標準偏差を1にするようにデータを変換する処理。
- 6-2. データを標準化してみよう: https://bellcurve.jp/statistics/course/19647.html
- 今回は入力データのチャンネルは一つなので、平均も分散も要素数1のタプルを指定している。複数チャンネルあれば、チャンネル数分の要素を持つタプルをそれぞれ指定するらしい。
- (ここ では 平均と分散が 0.5 になっているものの、これは間違い?)

DataLoader の num_worker とは?
- 複数のプロセスからデータをロードする場合は、そのプロセス数を指定するとのこと。並列で処理する際なんかに指定するものらしい。
- References
- Multi-process data loading: https://pytorch.org/docs/stable/data.html#multi-process-data-l
- DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

参考ページ
- Official tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
- PyTorch transforms/Dataset/DataLoaderの基本動作を確認する: https://qiita.com/takurooo/items/e4c91c5d78059f92e76d

ネットワークの定義

Official と同じネットワークを定義してみる。

Dropout とは?
- torch.nn.Dropout: https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout
- 【ニューラルネットワーク】Dropout(ドロップアウト)についてまとめる: https://qiita.com/shu_marubo/items/70b20c3a6c172aaeb8de
- 学習時に確率的に(引数で指定した確率となるようランダムに)ネットワーク上のノードを無効化する
- maxpooling … フィルタ内の最大値を出力する操作
- 各層でのノード数の変化やパラメーターの数はソースコードのコメントに書いてみました。

学習処理 (loss, 重みの更新の定義)

Qiita にある記事と公式のサンプルコードとでは使っている損失関数と最適化アルゴリズムが違う。
コード全体としては、Qiita のものを参考にしつつ、損失関数と最適化アルゴリズムだけ差し替えています。
以下、参考にしたコードのそれぞれの該当箇所です。
- https://qiita.com/fukuit/items/215ef75113d97560e599#loss%E9%96%A2%E6%95%B0%E3%81%A8%E6%9C%80%E9%81%A9%E5%8C%96%E3%81%AE%E8%A8%AD%E5%AE%9A%E3%82%92%E3%81%99%E3%82%8B
- https://github.com/pytorch/examples/blob/master/mnist/main.py#L43

nll_loss とは?
- NLL = Negative Log-Likelihood
- 尤度の対数の負値。
- 参考ページ: DeepLearningを0から実装できるようにするための連載 Vol.0 by @lucidfrontier45 https://qiita.com/lucidfrontier45/items/113ba14ea136a9e6055f#%E6%9C%80%E5%B0%A4%E6%8E%A8%E5%AE%9A

adadelta とは?
- 最適化アルゴリズムの一つ。細かいことはよくわかってないものの、今では Adadelta ではなく Adam というアルゴリズムがよく使われている?
以下参考にしたページです。
- CLASS torch.optim.Adadelta: https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta
- 勾配降下法の最適化アルゴリズムの備忘録 by @Hiroki11x: https://qiita.com/Hiroki11x/items/02a5212a66b3d12d2759#4-adadelta
- 勾配降下法の最適化アルゴリズムを概観する: https://postd.cc/optimizing-gradient-descent/#adadelta

学習結果を確認

参考コードの該当箇所: https://qiita.com/fukuit/items/215ef75113d97560e599#%E5%AD%A6%E7%BF%92%E7%B5%90%E6%9E%9C%E3%81%AE%E7%A2%BA%E8%AA%8D

torch.max() は何?
- tensor と次元の番号を渡して、最大値を集めた tensor と最大値があったインデックスの tensor が戻ってくる。
- 2 つ目の引数は、何番目の次元を削減するか。基本的には torch.max は階数を一つ減らすのみ。
- 2 階の tensor を渡すと 1 階の tensor が、 3 階 の tensor を渡すと 2 階の tensor がそれぞれ返ってくる。
- どの軸で見た時の最大値を返すかを指定するのが2つ目の引数の役割。
- 参考ページ: [PyTorch]torch.max()でちょっと迷ったこと by @Haaamaaaaa https://qiita.com/Haaamaaaaa/items/b9f47cba588b83ad34a7
- ドキュメント: https://pytorch.org/docs/stable/torch.html#torch.max

(labels == predicts).sum().item() なんて書き方ができるの!? –> 以下のように tensor 同士を比較すると、各要素を比較した結果を格納した tensor が返ってくるので可能。

>>> labels
tensor([3., 3., 3., 3.])
>>> predicts
tensor([3, 0, 3, 1])
>>> labels == predicts
tensor([ True, False,  True, False])
>>> (labels == predicts).sum()
tensor(2)
>>> (labels == predicts).sum().item()
2

実行結果

epoch 2 回で実行したみた結果は以下の通り。

# python mnist.py
[epoch #1, iter #100] loss: 0.645978
[epoch #1, iter #200] loss: 0.227659
[epoch #1, iter #300] loss: 0.180389
[epoch #1, iter #400] loss: 0.134909
[epoch #1, iter #500] loss: 0.126124
[epoch #1, iter #600] loss: 0.121547
[epoch #2, iter #100] loss: 0.092214
[epoch #2, iter #200] loss: 0.095481
[epoch #2, iter #300] loss: 0.091667
[epoch #2, iter #400] loss: 0.100502
[epoch #2, iter #500] loss: 0.085322
[epoch #2, iter #600] loss: 0.085219
Accuracy: 97.50 %%

おわりに

PyTorch で MNIST の例題を一通りやってみたときに調べたことを書きました。
本当は、自分の書いた数字の画像を推論してみたりしたいなと思いましたが、それは今後の課題ということにします。

0
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2