画像分類の学習をしたかったのでTensorFlowのチュートリアルを読み解きながら進めてみました。
AWSで実行環境を立ち上げてサンプルを動かしてみます。
TensorFlow使うのも初めてだったので、各用語を調べながら進めたメモを残します。
学習手法について
ここではCNNという手法が使われています。
CNNは日本語でいうと畳み込みニューラルネットワークといいます。
解説記事がたくさん世の中にあるので詳細な説明は省きますが、理論はそこまで難しい数式が出てくるわけではないように見えたので、興味がある方は是非調べてみると面白いと思います。
CIFAR-10について
CIFAR-10とは、画像認識のベンチマークとして広く使われているデータセットです。
32*32の画像が60000枚入っていて、airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truckの10種類のラベル付けがされています。
このチュートリアルではこのデータセットを対象に画像分類を行います。
チュートリアルのコードについて
このチュートリアルのコードはここから入手することができます。
このサンプルコードは大きく分けて3つの部分に分けることができます。
- モデルの入力生成
- モデルによる予測の評価
- モデルの学習
1. モデルの入力生成
ここでは学習するための画像の前処理を行っています。
具体的には以下のようなことを行っています。
- 32*32の画像を24*24にランダムに切り取る
- 画像の明るさをそろえる
- 画像に反転や明るさの変更などの加工を行い、学習データを増やす
2. モデルによる予測の評価
モデルをもとに、予測結果を評価します。
3. モデルの学習
データをN個に分類するための学習モデルとして一般的である多クラス分類ロジスティック回帰を使用しています。
詳細な解説は私も初心者なので書けませんが、理解できた部分部分だけ書くと、
- ニューラルネットワークの出力をソフトマックス関数で変換したものを各クラスに属する確率とする
- 出力された各クラスの確率のベクトルと、正解のベクトルとの間の交差エントロピーを損失関数とし、これを最小化するように学習する
- 交差エントロピーは正の値であり、正解に近づくほど0に近くなる性質を持つ
- 損失関数を最小化させる重みを最急降下法で学習させる
間違いや補足があったらご指摘いただきたいです...
サンプルの実行
AWS上でEC2インスタンスを立ち上げてサンプルを実行してみます。
EC2インスタンスをAMIか作ることで、最初から必要なソフトウェアがインストールされた環境を作ることができます。
今回はDeep Learning AMI (Amazon Linux) Version 5.0を選択しました。
試しに動かしてみるだけなので、インスタンスタイプはt2.mediumを選択。
やったことは、
- 立ち上げたEC2インスタンスにSSHでログイン
source activate tensorflow_p36 #tensorflowの環境を選択、これをしないとtensorflowが使えない
git clone https://github.com/tensorflow/models.git
cd models/tutorials/image/cifar10
python cifar10_train.py
- 画像ダウンロードが走った後...
2018-03-14 13:50:02.431156: step 0, loss = 4.68 (97.2 examples/sec; 1.316 sec/batch)
2018-03-14 13:50:10.778302: step 10, loss = 4.62 (153.3 examples/sec; 0.835 sec/batch)
2018-03-14 13:50:19.006783: step 20, loss = 4.53 (155.6 examples/sec; 0.823 sec/batch)
2018-03-14 13:50:27.264905: step 30, loss = 4.42 (155.0 examples/sec; 0.826 sec/batch)
2018-03-14 13:50:35.519980: step 40, loss = 4.33 (155.1 examples/sec; 0.826 sec/batch)
2018-03-14 13:50:43.708251: step 50, loss = 4.34 (156.3 examples/sec; 0.819 sec/batch)
...
無事動きました。
しかし学習のログからは損失関数の値しかわかりません。
画像分類の正解率を知るためには、評価用のプログラムを動かします。
python cifar10_eval.py
2018-03-14 14:00:09.476609: precision @ 1 = 0.096
右に出ているのが正解率です。
今回は学習を短い時間しか動かしていないので9.6%という低い正解率になっていますが、学習途中の正解率を表示することができました。
ここまででtensorflowによるCNNを用いたCIFAR10の画像分類のチュートリアルを動かすことができました。
次は自分で用意したデータセットでの学習や、GCPを用いた画像分類に挑戦したいと思います。