Help us understand the problem. What is going on with this article?

【Keras入門(6)】単純なRNNモデル定義(最終出力のみ使用)

More than 1 year has passed since last update.

前回記事「【Keras入門(5)】単純なRNNモデル定義」では、RNNを使って1つの入力値に対して次の値を予測しました。今回は、10個の入力値に対して1つの出力をするモデルにします。
これは、文章に対するネガポジ予測(ネガティブ/ポジティブ)や、文書の分類などに使えます。

以下のシリーズにしています。

- 【Keras入門(1)】単純なディープラーニングモデル定義
- 【Keras入門(2)】訓練モデル保存(KerasモデルとSavedModel)
- 【Keras入門(3)】TensorBoardで見える化
- 【Keras入門(4)】Kerasの評価関数(Metrics)
- 【Keras入門(5)】単純なRNNモデル定義
- 【Keras入門(6)】単純なRNNモデル定義(最終出力のみ使用) <- 本記事
- 【Keras入門(7)】単純なSeq2Seqモデル定義

使ったPythonパッケージ

Google Colaboratoryでインストール済の以下のパッケージとバージョンを使っています。KerasはTensorFlowに統合されているものを使っているので、ピュアなKerasは使っていません。Pythonは3.6です。

  • tensorflow: 1.14.0
  • Numpy: 1.16.4
  • matplotlib: 3.0.3

処理概要

等差数列が増加か減少かを判断します。

1st 2nd 3rd 4th 5th 6th 7th 8th 9th 10th 等差数列
0 1 2 3 4 5 6 7 8 9 増加(0)
0 -1 -2 -3 -4 -5 -6 -7 -8 -9 減少(0)

処理プログラム

プログラム全体はGitHubを参照ください。

1. ライブラリインポート

前回から追加して乱数を発生するrandomも読み込んでいます。

from random import randint

import numpy as np
import matplotlib.pyplot as plt

# TensorFlowに統合されたKerasを使用
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN

2. 前処理

等差数列の配列を増加・減少パターンを交互に作っています。

NUM_RNN = 10
NUM_DATA = 200

# 空の器を作成
x_train = np.empty((0, NUM_RNN))
y_train = np.empty((0, 1))

for i in range(NUM_DATA):
    num_random = randint(-20, 20)
    if i % 2 == 1:  # 奇数の場合
        x_train = np.append(x_train, np.linspace(num_random, num_random+NUM_RNN-1, num=NUM_RNN).reshape(1, NUM_RNN), axis=0)
        y_train = np.append(y_train, np.zeros(1).reshape(1, 1), axis=0)
    else: # 偶数の場合
        x_train = np.append(x_train, np.linspace(num_random, num_random-NUM_RNN+1, num=NUM_RNN).reshape(1, NUM_RNN), axis=0)
        y_train = np.append(y_train, np.ones(1).reshape(1, 1), axis=0)

x_train = x_train.reshape(NUM_DATA, NUM_RNN, 1)
y_train = y_train.reshape(NUM_DATA, 1)

3. モデル定義

今回のRNNモデルは前回と異なり、最終出力のみを使います。そのために下図のようなモデルです。
60.Keras_RNN_Overview01.JPG

参考として、前回のモデル(最終出力以外も使用)はこんなでした。
51.Keras_RNN_Overview01.JPG

SimpleRNN関数return_sequencesの値をFalseにして使わないようにします。また、最後の全結合層は1次元にして二値分類です。

NUM_DIM = 16  # 中間層の次元数

model = Sequential()

# return_sequenceがFalseなので最後のRNN層のみが出力を返す
model.add(SimpleRNN(NUM_DIM, batch_input_shape=(None, NUM_RNN, 1), return_sequences=False))
model.add(Dense(1, activation='sigmoid'))  #全結合層
model.compile(loss='binary_crossentropy', optimizer='adam')

model.summary()

summary関数で以下のようなモデルサマリを出してくれます。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 8)                 80        
_________________________________________________________________
dense (Dense)                (None, 1)                 9         
=================================================================
Total params: 89
Trainable params: 89
Non-trainable params: 0
_________________________________________________________________

4. 訓練実行

fit関数を使って訓練実行です。30epoch程度でそこそこいい精度が出ます。

history = model.fit(x_train, y_train, epochs=30, batch_size=8)
loss = history.history['loss']

plt.plot(np.arange(len(loss)), loss) # np.arangeはlossの連番数列を生成(今回はepoch数の0から29)
plt.show()

image.png

5. テスト

最後にテストです。

5.1. テスト実行

訓練データの最初の10件をテストデータとします。
predict関数を使ってテストデータから予測値を出力します。

# データ数(10回)ループ
for i in range(10):
    y_pred = model.predict(x_train[i].reshape(1, NUM_RNN, 1))
    print(y_pred[0], ':', x_train[i].reshape(NUM_RNN))

見た限り全問正解です(4件だけ抜粋)。

[0.9673256] : [ -6.  -7.  -8.  -9. -10. -11. -12. -13. -14. -15.]
[0.01308651] : [-6. -5. -4. -3. -2. -1.  0.  1.  2.  3.]
[0.9779032] : [12. 11. 10.  9.  8.  7.  6.  5.  4.  3.]
[0.01788531] : [-2. -1.  0.  1.  2.  3.  4.  5.  6.  7.]
FukuharaYohei
気の向いたままにいろいろと書きます。 仕事はSAP関連で、HANA、Fiori、SAPUI5、BusinessObjectsなどいろいろやっています。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away