この記事を見て、キュウリの仕分けに挑戦してみた
噂の「TensorFlowでキュウリの仕分けを行うマシン」がMFT2016に展示されていたので実物を見てきた
GitHub
使用したデータはprototype_1のほう
展開したデータを'./cucumber-batches-py'に置く
とりあえず適当なモデルを組んで試してみたら70%台くらいの精度しか出なかったけど、モデルを複雑にしてもあまり変わらなかったのでデータを水増ししてみた
撮影条件は揃っているようだし、大きさや艶も仕分けに関係あるらしいので、やったのは移動・回転・反転のみ
使用しているのはKeras、AWSのg2.2xlargeインスタンスで実行
コード
cucumber.py
#coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Flatten, Dropout
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.advanced_activations import ELU
from keras.callbacks import Callback, EarlyStopping
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
try:
import cPickle as pickle
except:
import pickle
import numpy as np
import sys, os, time
class Evaluate(Callback):
    """Keras callback that prints test-set loss/accuracy every 10 epochs.

    NOTE(review): reads the module-level globals ``X_test`` / ``y_test``
    defined later in this script -- they must exist before training starts.
    """

    def on_epoch_end(self, epoch, logs=None):
        # Keras calls on_epoch_end(epoch, logs) with a 0-based epoch index.
        # The original signature `(self, logs={}, acc=None)` only worked by
        # accident (the epoch int was bound to `logs`) and used a mutable
        # default argument; both are fixed here.
        if (epoch + 1) % 10 == 0:
            # batch_size=495 == size of the test batch, so one evaluate pass.
            score = self.model.evaluate(X_test, y_test, batch_size=495)
            print('test - loss: {loss:.4f} - acc: {accuracy:.4f}'.format(
                loss=score[0], accuracy=score[1]))
def unpickle(file_name):
    """Deserialize one pickled batch file and return its contents.

    On Python 3, ``encoding='latin-1'`` is required to load numpy arrays
    that were pickled under Python 2. On Python 2 the keyword does not
    exist, so plain ``pickle.load`` is used.

    The original version returned an implicit ``None`` for any interpreter
    whose major version was neither 2 nor 3; Python 3+ is now the default
    branch, so the function always returns the unpickled object.
    """
    with open(file_name, 'rb') as f:
        if sys.version_info.major == 2:
            return pickle.load(f)
        return pickle.load(f, encoding='latin-1')
def load_data(path='./cucumber-batches-py', samples_per_batch=495,
              nb_train_batches=5):
    """Load the cucumber dataset stored in CIFAR-10-style pickle batches.

    Parameters
    ----------
    path : str
        Directory containing ``data_batch_1`` .. ``data_batch_N`` and
        ``test_batch``. Defaults to the location used by this script.
    samples_per_batch : int
        Number of images in each pickled batch file (495 for prototype_1).
    nb_train_batches : int
        Number of training batch files to read.

    Returns
    -------
    ((X_train, y_train), (X_test, y_test))
        Images as uint8 arrays of shape (n, 3, 32, 32) (channels-first);
        y_train is a uint8 array, y_test is returned as stored in the
        pickle (a plain list), matching the original behavior.
    """
    nb_train_samples = samples_per_batch * nb_train_batches
    X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.zeros((nb_train_samples,), dtype='uint8')
    for i in range(1, nb_train_batches + 1):
        batch = unpickle(os.path.join(path, 'data_batch_' + str(i)))
        lo = (i - 1) * samples_per_batch
        hi = i * samples_per_batch
        X_train[lo:hi] = batch['data'].reshape(samples_per_batch, 3, 32, 32)
        y_train[lo:hi] = batch['labels']
    test_batch = unpickle(os.path.join(path, 'test_batch'))
    X_test = test_batch['data'].reshape(samples_per_batch, 3, 32, 32)
    y_test = test_batch['labels']
    return (X_train, y_train), (X_test, y_test)
# --- Data preparation (module-level script section) ---

# Number of output classes fed to to_categorical / the final Dense layer.
nb_classes = 9
(X_train, y_train), (X_test, y_test) = load_data()
# Convert uint8 pixels to floats and scale into [0, 1].
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
# One-hot encode integer labels for categorical_crossentropy.
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
print('Data loaded')
# Augmentation limited to shift/rotate/flip: shifts of up to 4 px on a
# 32 px image, +/-10 degree rotation, horizontal mirroring.
# NOTE(review): cval is only used with fill_mode='constant', so cval=1.0
# appears to be a no-op here with fill_mode='nearest' -- confirm.
data_gen = ImageDataGenerator(horizontal_flip=True, width_shift_range=4.0/32, height_shift_range=4.0/32, rotation_range=10, fill_mode='nearest', cval=1.0)
data_gen.fit(X_train)
# Validation generator applies no augmentation.
# NOTE(review): .fit() is only needed when featurewise statistics are
# enabled, which they are not for either generator -- presumably harmless.
val_gen = ImageDataGenerator()
val_gen.fit(X_train)
# --- Network definition: VGG-style conv stages + 3 fully-connected layers ---
model = Sequential()

# Four stages of (Conv 3x3 -> ELU -> BatchNorm) x 4, each followed by a
# 2x2 max-pool; the filter count doubles every stage: 8, 16, 32, 64.
# Only the very first convolution declares the (3, 32, 32) input shape.
is_first_layer = True
for stage in range(4):
    nb_filters = 8 * 2 ** stage
    for _ in range(4):
        if is_first_layer:
            model.add(Convolution2D(nb_filter=nb_filters, nb_row=3, nb_col=3,
                                    border_mode='same', input_shape=(3, 32, 32)))
            is_first_layer = False
        else:
            model.add(Convolution2D(nb_filter=nb_filters, nb_row=3, nb_col=3,
                                    border_mode='same'))
        model.add(ELU())
        model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), border_mode='same'))

model.add(Flatten())

# Per-layer dropout rate chosen so that n stacked dropout layers keep about
# half of the units overall: (1 - drop)^n = 0.5.
n = 3
drop = 1.0 - np.power(0.5, 1.0 / n)
for _ in range(n):
    model.add(Dense(4096))
    model.add(ELU())
    model.add(Dropout(drop))
    model.add(BatchNormalization())

# Softmax classifier over the cucumber grade classes.
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# --- Training and final evaluation ---
evaluate = Evaluate()
# Stop when validation loss has not improved for 30 consecutive epochs.
# NOTE(review): validation_data below draws from X_train (un-augmented),
# not from a held-out split, so val_loss effectively monitors training
# data -- confirm this is intended.
stop = EarlyStopping(monitor='val_loss', patience=30)

# time.clock() measured CPU time on Unix and was removed in Python 3.8;
# time.time() gives the wall-clock duration the "s/epoch" report implies.
begin = time.time()
model.fit_generator(
    data_gen.flow(X_train, y_train, batch_size=128, shuffle=True),
    samples_per_epoch=12800,  # ~5x the 2475 training images per epoch
    nb_epoch=1000,
    verbose=2,
    callbacks=[evaluate, stop],
    validation_data=val_gen.flow(X_train, y_train, batch_size=495*2, shuffle=True),
    nb_val_samples=495*2)
print('Time elapsed: %.0f' % (time.time() - begin))

# Final held-out evaluation on the 495-image test batch in a single pass.
score = model.evaluate(X_test, y_test, batch_size=495)
print('test - loss: {loss:.4f} - acc: {accuracy:.4f}'.format(loss=score[0], accuracy=score[1]))
結果
Train Accuracy : 0.96
Test Accuracy : 0.89
Time Elapsed : 19480s(71s/epoch)
実際に32×32の画像で人間がどれくらいの精度を出せるのかベンチマークがないと、評価が難しい
データに偏りや不足があるかもしれないし、画像が小さすぎて必要な情報が欠落しているかもしれない