11
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

簡単なディープラーニングのサンプルコード (2入力1出力/2クラス分類) with Keras

Last updated at Posted at 2016-07-14

KDD CUP 2015に参加しているわけですが、ディープラーニングとか使ってみよ!
と思った時に参考になるコードが見つからなかったので置いておきます。
 

ラベルは0と1で、どちらかを推定します。
必要なライブラリはKeras、sklearn、pandasあたりです。
pipで依存関係よろしく入った気がしますが記憶が定かではありません。
Chainerもトライしましたが、結局僕の知識では扱いきれませんでした....
 

参考にしたコードはこちら。
kaggle ottoというコンペの問題を解くコードで、多クラス分類のコードです。
@yag_aysに教えてもらいました。ありがとうございます。
https://github.com/fchollet/keras/blob/master/examples/kaggle_otto_nn.py

Kerasについてはこちらの記事でも詳しく解説されてます。

"ディープラーニングをKerasというフレームワーク上で行う"
http://sharply.hatenablog.com/entry/2016/05/03/202806#
 
 
で、今回のコードの例が以下。 

keras_deep_nn_example.py
#!/usr/bin/python

from __future__ import absolute_import

import numpy as np
import pandas as pd

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import PReLU
from keras.utils import np_utils, generic_utils

from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler

def preprocess_data(X, scaler=None):
    if not scaler:
        scaler = StandardScaler()
        scaler.fit(X)
    X = scaler.transform(X)
    return X, scaler
    
def preprocess_labels(labels, encoder=None, categorical=True):
    if not encoder:
        encoder = LabelEncoder()
        encoder.fit(labels)
    y = encoder.transform(labels).astype(np.int32)
    if categorical:
        y = np_utils.to_categorical(y)
    return y, encoder

X_list = []
labels_list = []
# setup training data
for i in xrange(20000):
    # [1,2] => [0]
    X_list.append([1,2])
    labels_list.append(0)
    # [2,1] => [1]
    X_list.append([2,1])
    labels_list.append(1)

print("Loading data...")
X = np.array(X_list)
labels = np.array(labels_list)
X, scaler = preprocess_data(X)
Y, encoder = preprocess_labels(labels)

np.random.seed(1337) # for reproducibility

# input for predection
X_test_list = [[1,2],[2,1]]
X_test = np.array(X_test_list)
X_test, _ = preprocess_data(X_test_list, scaler)

nb_classes = Y.shape[1]
print(nb_classes, 'classes')

dims = X.shape[1]
print(dims, 'dims')

print("Building model...")

neuro_num = 16

# setup deep NN
model = Sequential()
model.add(Dense(dims, neuro_num, init='glorot_uniform'))
model.add(PReLU((neuro_num,)))
model.add(BatchNormalization((neuro_num,)))
model.add(Dropout(0.5))

model.add(Dense(neuro_num, neuro_num, init='glorot_uniform'))
model.add(PReLU((neuro_num,)))
model.add(BatchNormalization((neuro_num,)))
model.add(Dropout(0.5))

model.add(Dense(neuro_num, neuro_num, init='glorot_uniform'))
model.add(PReLU((neuro_num,)))
model.add(BatchNormalization((neuro_num,)))
model.add(Dropout(0.5))

model.add(Dense(neuro_num, nb_classes, init='glorot_uniform'))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy', optimizer="adam")

print("Training model...")
model.fit(X, Y, nb_epoch=20, batch_size=128, validation_split=0.15)

print("Prediction...")
proba = model.predict_proba(X_test)

# predicted result
print("probability of [label=0 label=1]")
print("  input: [1,2] => " + str(proba[0]))
print("  input: [2,1] => " + str(proba[1]))

https://github.com/ryogrid/kddcup2015/blob/master/train/keras_deep_nn_example.py
♯16/07/15 Kerasのバージョンが最新だと、これだと動かないようです。↓のコードを参考にネットワークのところだけ修正してみて下さい
https://github.com/ryogrid/fx_systrade/blob/e93103c804e992bc21b7afc8a0f707ec4ce37029/keras_trade.py

【実行結果】

(Kerasの出力は省略)
probability of [label=0 label=1]
 input: [1,2] => [ 9.99876779e-01 2.05630928e-04]
 input: [2,1] => [ 1.21604726e-04 9.99792734e-01]

 
やってることはコードを見てもらえば大体分かるかと思います。
普通は、もっとたくさんのパターンで学習させるし、予測する時も学習させた時のパターンそのままとかはあり得ないと思いますが、まあそこはご愛嬌ということで。
もし、これはディープラーニングじゃない、ということであればコメント欄でツッコミお願いします。
♯ 確かに CNN, RBN, Auto Encoder(&Decoder) とか使ってないしなあ...

あと、RからH2Oを叩いても簡単にできるみたいです。
「Rで一行でディープラーニング」
http://d.hatena.ne.jp/dichika/20140503/p1
けど、試してみたらR周りでトラブったりしたので、H2Oのウェブインタフェース (Flowと呼ぶらしい) で頑張ったほうが早いかもです。

#もっと簡単な方法もあるようです(11/8 追記)
http://qiita.com/rindai87/items/546991f5ecae0ef7cde3

11
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?