0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pytorch で MNIST をやってみた時のメモ

Last updated at Posted at 2020-05-06

はじめに

昨今の機械学習や Deep Learning の話題に少しでもついていけるように勉強するために、PyTorch で MNIST をやってみました。
この記事はメモなので、Pytorch で MNIST をやるために必要な知識をすべて網羅するものではなく、また、ちょっとわき道に逸れたような内容もあります。

参考ページ

これ以外に参考にしたページは、本文の関連する箇所にリンクを置いていきます。

目標

基本的には下記の1つ目の記事にあるコードを写経をしつつ理解することが目標です。
ただし、写経だけだと寂しいので、ネットワークの構成や損失関数などを PyTorch 公式サンプルのものに差し替えました。

  1. PyTorchでMNIST by @fukuit :: https://qiita.com/fukuit/items/215ef75113d97560e599
  2. Basic MNIST Example (公式のサンプルコード) :: https://github.com/pytorch/examples/blob/master/mnist/main.py

書いたコードは https://github.com/fukai-t/pytorch-mnist に置いています。

環境構築

今回は Docker コンテナで環境を作りました。

画像の取得

MNIST の画像は torchvision を使ってダウンロードします。

データの前処理を理解する

データの前処理で出てくるコンポーネントに関しては,以下の説明がわかりやすかったので引用させていただきます。
引用元: https://qiita.com/takurooo/items/e4c91c5d78059f92e76d#transformsdatasetdataloader%E3%81%AE%E5%BD%B9%E5%89%B2

- transforms
  - データの前処理を担当するモジュール
- Dataset
  - データとそれに対応するラベルを1組返すモジュール
  - データを返すときにtransformsを使って前処理したものを返す。
- DataLoader
  - データセットからデータをバッチサイズに固めて返すモジュール

DatasetDataLoader はなんとなく何するかイメージがつきやすいのですが, transforms はちょっとイメージが湧きづらかったのでもう少し調べてみました。

参考ページ: (Official) TORCHVISION.TRANSFORMS https://pytorch.org/docs/stable/torchvision/transforms.html

Normalize (torchvision.transforms.Normalize())について

  • 日本語で言うところの標準化という処理になるらしい。
    • ただし、標準化という言葉が指すものは文脈によってまちまちらしい(英語圏でも同じ)
  • torchvision.transforms.Normalize() の標準化は平均を0, 分散と標準偏差を1にするようにデータを変換する処理。
  • 今回は入力データのチャンネルは一つなので、平均も分散も要素数1のタプルを指定している。複数チャンネルあれば、チャンネル数分の要素を持つタプルをそれぞれ指定するらしい。
    • (ここ では 平均と分散が 0.5 になっているものの、これは間違い?)

DataLoader の num_worker とは?

参考ページ

ネットワークの定義

Official と同じネットワークを定義してみる。

Dropout とは?

学習処理 (loss, 重みの更新の定義)

Qiita にある記事と公式のサンプルコードとでは使っている損失関数と最適化アルゴリズムが違う。
コード全体としては、Qiita のものを参考にしつつ、損失関数と最適化アルゴリズムだけ差し替えています。
以下、参考にしたコードのそれぞれの該当箇所です。

nll_loss とは?

adadelta とは?

学習結果を確認

参考コードの該当箇所: https://qiita.com/fukuit/items/215ef75113d97560e599#%E5%AD%A6%E7%BF%92%E7%B5%90%E6%9E%9C%E3%81%AE%E7%A2%BA%E8%AA%8D

torch.max() は何?

  • tensor と次元の番号を渡して、最大値を集めた tensor と最大値があったインデックスの tensor が戻ってくる。

(labels == predicts).sum().item() なんて書き方ができるの!? –> 以下のように tensor 同士を比較すると、各要素を比較した結果を格納した tensor が返ってくるので可能。

>>> labels
tensor([3., 3., 3., 3.])
>>> predicts
tensor([3, 0, 3, 1])
>>> labels == predicts
tensor([ True, False,  True, False])
>>> (labels == predicts).sum()
tensor(2)
>>> (labels == predicts).sum().item()
2

実行結果

epoch 2 回で実行したみた結果は以下の通り。

# python mnist.py
[epoch #1, iter #100] loss: 0.645978
[epoch #1, iter #200] loss: 0.227659
[epoch #1, iter #300] loss: 0.180389
[epoch #1, iter #400] loss: 0.134909
[epoch #1, iter #500] loss: 0.126124
[epoch #1, iter #600] loss: 0.121547
[epoch #2, iter #100] loss: 0.092214
[epoch #2, iter #200] loss: 0.095481
[epoch #2, iter #300] loss: 0.091667
[epoch #2, iter #400] loss: 0.100502
[epoch #2, iter #500] loss: 0.085322
[epoch #2, iter #600] loss: 0.085219
Accuracy: 97.50 %%

おわりに

PyTorch で MNIST の例題を一通りやってみたときに調べたことを書きました。
本当は、自分の書いた数字の画像を推論してみたりしたいなと思いましたが、それは今後の課題ということにします。

0
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?