Python
初心者
機械学習
ツール
ディープラーニング

もうディープラーニングの可視化で迷わない!dlt (Deep Learning Tools)をどうぞ!

はじめに

この記事では、深層学習の勉強に役立つパッケージである、dlt (Deep Learning Tools)を紹介したいと思います。本来はKerasを用いたCIFAR-10チュートリアルで公開していたものですが、初心者の方がこのパッケージを通して深層学習に触れられるようにと、新たにこの記事で紹介することにしました。dltに含まれているコードは決して高級なものではありませんが、ディープラーニングの勉強中に他のものに煩わされないために、役に立つと思います。

この記事の内容はこちらdltパッケージの使い方からも参照できます。

なお、この記事では抜粋版を掲載しています。詳しくは上記リポジトリをどうぞ。

機能拡張等いつでもプルリク受付中です!

バージョン

PyPI

動作環境

Python3, Numpy (1.1.13), scikit-learn (0.19.1), tensorflow (1.4.1), keras (2.1.2)

誰が作ったか?

もともとこのパッケージを作ったのは、RWTH Aachen工科大学のDavid Walz氏です。私は彼の「Deep Learning in Physics Research」という講義を取っており、その講義を通してDeep Learningを学びました。この講義で使われていたdliprは、大学のクラスター上で動くものでしたが、私は個人のコンピュータでも動くように修正し、dlt (Deep Learning Tools)として、ここで公開することにしました。

このノートについて

この記事では、それらをどう使うかを紹介していきます。以下では、Fashion MNIST を例に扱ってみます。ここでは詳しいディープラーニングの内容は解説しませんが、dltを使えばどんな結果を得ることができるかを知っていただくことを目的としているので、ディープラーニングについて知らなくても構いません。

どう使えばよいか?

- dltパッケージを使うために、

pip install dlt

を実行すれば使えるようになります。なお、私が公開しているリポジトリは以下のようになっています。

.
├── LICENSE.md
├── README.md
├── How_to_use_dlt_package
├── dltパッケージの使い方
└── dlt
    ├── __init__.py
    ├── cifar.py
    ├── fashion_mnist.py
    ├── mnist.py
    └── utils.py
  • cifar.py, fashion_mnist.py, mnist.pyでは、各データセットを読み込むようになっています。utils.pyで使える形式に落とし込む操作が含まれています。utils.pyには、分類タスクの結果を見やすくするツールや、confusion matrixを出力するものが含まれています。

  • dltパッケージの使い方, How_to_use_dlt_packageにはdltをどう使うかという簡単な実行例がまとめられています。(.htmlおよび.ipynb形式) この記事の内容とほとんど同じです。

ディープラーニング - Fashion-MNIST

基本的にはMNISTのサンプルコードを参考にします。

パッケージは以下のようにしてインポートします。

import dlt

★ データセットの読み込み

Fashion MNISTのデータセットを読み込むメソッドは以下のようにして読み込みます。

data = dlt.fashion_mnist.load_data()
# Downloading Fashion-MNIST dataset
  • dltでは以下のデータセットを用意しています: CIFAR-10, CIFAR-100, MNIST, Fashion-MNIST
  • 各ファイルに上記のload_dataに相当するメソッドが定義されています。

各データには以下のようにアクセスできます。Fashion MNISTはMNISTと同じデータ量(フォーマット)で提供されているので、MNISTと同じように扱えます。

X_train = data.train_images
y_train = data.train_labels
X_test = data.test_images
y_test = data.test_labels

★ データ分布の表示

代表的なデータセットを扱う上では気にする必要はありませんが、自分でデータセットを作る際、(分類タスクにおいて)ラベル毎のデータ分布を均一にする必要があります。そのチェックをするメソッドを実装しています。

dlt.utils.plot_distribution_data(Y=data.train_labels, #正解ラベルのデータセット
                                 dataset_name='y_train', # そのデータセットの名前
                                 classes=data.classes, # ラベル
                                 fname='dist_train.png' # 出力するファイルパス
                                )

# Mean Value: 6000
# Median Value: 6000.0
# Variance: 0
# Standard Deviation: 0.0

dist_train.png

別に面白くないグラフですが、正しく学習させても正解ラベルごとにデータの偏りが出ては正しく予測できないので、これはこれで役に立ちますよ。

★ サンプル画像の表示

dltにはサンプル画像を表示するメソッドも用意されています。どういう画像が学習画像となっているかわかりますね。

dlt.utils.plot_examples(data=data, 
                          num_examples=5, # 縦に何個表示するか (横はカテゴリーと一致)
                         fname='fashion_mnist_examples.png' # 画像を保存するためのファイルパス
                       )

fashion_mnist_examples.png

★ 学習の経過を表す損失関数と精度のグラフは

dlt.utils.plot_loss_and_accuracy(fit,  #model.fitのインスタンス
                                   fname='loss_and_accuracy_graph.png' #保存するファイル名とパス
                                  )

loss_and_accuracy_graph.png

★ テスト画像に対する分類精度

# predicted probabilities for the test set
preds = model.predict(X_test)
cls = model.predict_classes(X_test)

また分類タスクでは、各テスト画像に対してどの程度の精度で分類されたかを知りたいことがあります。その時どう出力すればいいか迷うことがあるのですが、dltでは以下のようにすれば、わかりやすい結果が得られます。

# とりあえず10枚
for i in range(10):
    dlt.utils.plot_prediction(
        Yp=preds[i], # 各クラスに対して予測されたラベル 
        X=data.test_images[i], # 各クラスを表す画像
        y=data.test_labels[i], # 正しいクラスのラベル
        classes=data.classes, # ラベル名
        top_n=False, # 上位いくつまで表示させるか. Falseならすべてのカテゴリーに対する精度を表示
        fname='test-%i.png' % i) # 保存するファイル名

test-0.png
test-1.png
test-2.png
test-3.png
test-4.png

オレンジ色の棒は正しいラベルの分類精度を表し、青色の棒は間違ったラベルの分類精度を表しています。

最後の画像を例とすると、これはShirtとほとんど90%以上の精度で分類していますが、5%程度でT-shirt/topと分類しています。

結果を全体的に見たいときは、以下のConfusion Matrixが便利です。

★ Confusion Matrix

confusion matrixについては以下のように出力できます。

dlt.utils.plot_confusion_matrix(data.test_labels, # 正しいテストラベル(one-hot vectorに変換する前)
                                  cls, # predict_classesを通したpreds
                                  data.classes, # クラス名
                                  title='confusion matrix', # 出力グラフのタイトル
                                  fname='confusion_matrix.png') # 出力パス

confusion_matrix.png

おわりに

ディープラーニングは実際にコードを実行しても、なんだかよくわかりにくいですが、そういうときは一旦可視化してみると、少しづつわかってくるものです。そういうとき、dltがお役に立てたら大変光栄です。よろしくお願いいたします。