2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

CNNを利用して判別(Udemyコースの応用)

Last updated at Posted at 2019-05-06

1.はじめに

1.1 お断り

自身がUdemyを利用して学んだことをアウトプットするものです。
必ずしもUdemyがすべての人に有用だと主張するつもりはありません。
Udemyのご利用はご自身の判断でお願いします。
いいコースと思って利用していますが、それはあくまで私個人の主観でありすべての人に適用されるものではありません。
また、私は一利用者であり、Udemyやコースの関係者ではありません。

1.2 目的

Udemyのコースを利用して学習をしていましたがアウトプットとして利用するものです。

1.3 対象コース

ディープラーニング : Pythonでゼロから構築し学ぶ人工知能(AI)と深層学習の原理

1.3.1対象内容

セクション8 畳み込みニューラルネットワーク(CNN)

1.4 現状の確認

内容は講習で説明されているので概略のみ。

  • 入力はscikit-learnに収録されているdigits。8×8のサイズの0~9までの手書き文字1797個。
  • 入力の標準化を実施((入力-平均値)/標準偏差)、
  • 正解をon-hotにする。正解の0~9を配列の0~9番目の配列値=1.0に変更
  • 訓練データ、テストデータに分割。(3の倍数がテストデータ、倍数でないものが訓練データ)
  • 画像のサイズ、チャンネル、エポック(50)、バッチサイズ(8)

1.5 変更内容

1.5.1 入力を差し替えてみる。

入力をdigitsからMNINSTに差し替えてみます。

###1.5.1.1 読み込み部分

import keras
# -- 手書き文字データセットの読み込み --
(input_train, train_labels), (input_test, test_labels) = keras.datasets.mnist.load_data()

###1.5.1.2 標準化
(入力-平均値)/標準偏差 で計算します。
平均値、標準偏差はnumpyではaverage、stdメソッドでそれぞれ簡単に出せます。
digitsでは訓練データ、テストデータが分割されていない状態でしたが、
MNINSTでは分かれているため、一旦結合して算出します。

input_all = np.concatenate([input_train,input_test])
ave_input_all = np.average(input_all)
std_input_all = np.std(input_all)

input_train = (input_train - ave_input_all) / std_input_all
input_test = (input_test - ave_input_all) / std_input_all

###1.5.1.3 one-hot変換
学習データ、テストデータともにone-hot化します。
出力を1~10をそれぞれ10timesの出力に変換します。
元ソースではこの時点では学習・テスト分かれていなかったため。
行っている内容自体は変更ありません。

#トレーニングデータ
correct_train = np.zeros((input_train.shape[0], 10))
for i in range(input_train.shape[0]):
    correct_train[i, train_labels[i]] = 1.0

#テストデータ
correct_test = np.zeros((input_test.shape[0], 10))
for i in range(input_test.shape[0]):
    correct_test[i, test_labels[i]] = 1.0

###1.5.1.4 学習データ/テストデータ分離
digitは学習データ/テストデータの分離はされていませんが
mninstは分離されているためこの処理はコメント化します。

###1.5.1.5 設定値
縦横サイズを修正

img_h = 28  # 入力画像の高さ
img_w = 28  # 入力画像の幅

これで動くはずです。
ただしデータ数が1700=>70000、サイズが8×8=>28×28になったため
単純計算で500倍程度の計算量になります。
作業用PCでは2時間程度かかりました。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?