LoginSignup
8

More than 5 years have passed since last update.

TensorflowでAutoEncoderをやってみた

Posted at

AutoEncoder(自己符号化器)とは、機械学習において、ニューラルネットワークを使用した次元圧縮のためのアルゴリズム。(From Wiki

今回はAutoEncoderを使って、画像の異常検出をやってみたいです。
使用する画像はWorkpilesさんのきゅうりの画像(勝手に使って申し訳ないですけど…)

687474703a2f2f776f726b70696c65732e636f6d2f776f726470726573732f77702d636f6e74656e742f75706c6f6164732f323031362f30322f637563756d6265725f636c617373696669636174696f6e2e6a7067.jpg
2L分類のきゅうり:正常
C分類のきゅうり:異常
にします。

モデル

Input:32x32x3の画像
Encoder:畳み込み層+プーリング*3
 32x32x3 → 16x16x32 → 8x8x16 → 4x4x8
 つまり、抽出された特徴は128です。
Decoder:4x4x8の特徴から32x32x3画像へ復元

ソースコード

Githubに保存しました。

encoder
W_conv1 = weight_variable([3, 3, DEPTH, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([3, 3, 32, 16])
b_conv2 = bias_variable([16])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_conv3 = weight_variable([3, 3, 16, 8])
b_conv3 = bias_variable([8])
h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)
h_pool3 = max_pool_2x2(h_conv3)
decoder
de_W_conv3 = weight_variable([3, 3, 16, 8])
de_b_conv3 = bias_variable([16])
de_h_conv3 = tf.nn.relu(de_conv2d(x, de_W_conv3, [batch_size, WIDTH//4, HEIGHT//4, 16], [1,2,2,1]) + de_b_conv3)

de_W_conv2 = weight_variable([3, 3, 32, 16])
de_b_conv2 = bias_variable([32])
de_h_conv2 = tf.nn.relu(de_conv2d(de_h_conv3, de_W_conv2, [batch_size, WIDTH//2, HEIGHT//2, 32], [1,2,2,1]) + de_b_conv2)

de_W_conv1 = weight_variable([3, 3, 3, 32])
de_b_conv1 = bias_variable([3])
de_h_conv1 = tf.nn.relu(de_conv2d(de_h_conv2, de_W_conv1, [batch_size, WIDTH, HEIGHT, 3], [1,2,2,1]) + de_b_conv1)

結果

Lossが0.003以下になると、学習を中止します。

きゅうりの分類が2L(正常)とC(異常)のデータ各50枚を使って検証してみました。
Lossが学習時Lossの95%信頼区間の1.2倍を超える場合は異常と認識します。

             precision    recall  f1-score   support

          0       0.74      0.78      0.76        50
          1       0.77      0.72      0.74        50

avg / total       0.75      0.75      0.75       100

[[39 11]
 [14 36]]

正解率は75%です。

最後に、生成した画像を見ましょう。
Origin: 0.jpg1.jpg2.jpg3.jpg4.jpg
Faked: 0fake.jpg1fake.jpg2fake.jpg3fake.jpg4fake.jpg

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8