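Text and other sequence data can be learned with CNNs (Conv1D), not just RNNs
・Trained an LSTM and a CNN on New York Times article classification and compared the two
・As a rough way to make the comparison fair, both models were given about the same number of parameters
Previous post: I tried classifying New York Times articles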
https://qiita.com/Phoeboooo/items/dbc8bb5308c9ec5af6a0
Data
politics : 1000
science : 1000
sports : 1000
total : 3000
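The training logs further down report 2400 training samples and 600 validation samples, and both models use categorical_crossentropy, which expects one-hot labels. Below is a minimal sketch of how labels for the three classes could be one-hot encoded and split 80/20. It is not the preprocessing from the linked post: the shuffle, the fixed seed, and the helper name one_hot are all assumptions, and only placeholder class indices are used, not real article text.

```python
import random

CLASSES = ["politics", "science", "sports"]
PER_CLASS = 1000  # articles per class, as listed above


def one_hot(index, num_classes):
    """Return a list with 1.0 at `index` and 0.0 elsewhere."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


# Placeholder labels only: one class index per article, no article text.
labels = [i for i in range(len(CLASSES)) for _ in range(PER_CLASS)]

rng = random.Random(0)  # fixed seed is an assumption, for reproducibility
rng.shuffle(labels)

targets = [one_hot(i, len(CLASSES)) for i in labels]

# 80/20 split, matching the 2400/600 sizes shown in the training logs.
split = int(len(targets) * 0.8)
train_y, val_y = targets[:split], targets[split:]

print(len(train_y), len(val_y))  # 2400 600
```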
LSTM
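from keras.models import Sequential
from keras.layers import Embedding, Dropout, LSTM, Dense, Activation

model = Sequential()
# 4002 vocabulary indices, 64-dimensional embeddings, sequences of 20 tokens
model.add(Embedding(4002, 64, input_length=20))
model.add(Dropout(0.3))
model.add(LSTM(64, dropout=0.3, recurrent_dropout=0.3))
# Three output classes: politics, science, sports
model.add(Dense(3))
model.add(Activation("softmax"))
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])
model.summary()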
Total params: 289,347
Trainable params: 289,347
Non-trainable params: 0
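The total reported by model.summary() can be checked by hand. The sketch below recomputes it from the layer sizes in the model definition, using the standard parameter formulas for embedding, LSTM, and dense layers. The helper function names are illustrative and not part of Keras.

```python
# Hand count of the trainable parameters in the LSTM model.
# Dropout and Activation layers have no parameters, so they are omitted.


def embedding_params(vocab_size, dim):
    # One dim-sized vector per vocabulary index.
    return vocab_size * dim


def lstm_params(input_dim, units):
    # Four gates, each with an input kernel, a recurrent kernel, and a bias.
    return 4 * (input_dim * units + units * units + units)


def dense_params(input_dim, units):
    # Weight matrix plus bias vector.
    return input_dim * units + units


lstm_counts = {
    "embedding": embedding_params(4002, 64),  # 256128
    "lstm": lstm_params(64, 64),              # 33024
    "dense": dense_params(64, 3),             # 195
}
lstm_total = sum(lstm_counts.values())

print(lstm_counts)
print(lstm_total)  # 289347
```

The embedding layer accounts for 256,128 of the 289,347 parameters, so most of the model's capacity sits in the word vectors rather than in the LSTM itself.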
Results
Train on 2400 samples, validate on 600 samples
Epoch 1/3
2400/2400 [==============================] - 7s 3ms/step - loss: 0.9500 - acc: 0.5138 - val_loss: 0.6780 - val_acc: 0.7317
Epoch 2/3
2400/2400 [==============================] - 3s 1ms/step - loss: 0.5800 - acc: 0.7554 - val_loss: 0.4645 - val_acc: 0.8533
Epoch 3/3
2400/2400 [==============================] - 3s 1ms/step - loss: 0.3229 - acc: 0.8875 - val_loss: 0.3778 - val_acc: 0.8733
Test score: 0.378, accuracy: 0.873
CNN
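from keras import layers
from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential()
# Same embedding as the LSTM model: 4002 indices, 64 dimensions, 20 tokens
model.add(layers.Embedding(4002, 64, input_length=20))
# Two Conv1D layers with 32 filters and kernel size 2, pooled in between
model.add(layers.Conv1D(32, 2, activation='relu'))
model.add(layers.MaxPooling1D(5))
model.add(layers.Conv1D(32, 2, activation='relu'))
# Collapse the remaining time steps into one 32-dimensional vector
model.add(layers.GlobalMaxPooling1D())
model.add(Dense(3))
model.add(Activation("softmax"))
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])
model.summary()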
Total params: 262,435
Trainable params: 262,435
Non-trainable params: 0
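As a sanity check on the CNN's size, the sketch below recomputes the parameter total by hand and traces how the sequence length shrinks through the convolution and pooling layers. It assumes the Keras defaults the model definition relies on ('valid' padding, stride 1 for Conv1D, pooling stride equal to the pool size). The helper names are illustrative.

```python
# Hand count of parameters and sequence lengths for the CNN model.


def conv1d_params(in_channels, filters, kernel_size):
    # One kernel of shape (kernel_size, in_channels) per filter, plus a bias.
    return kernel_size * in_channels * filters + filters


def conv1d_length(length, kernel_size):
    # 'valid' padding with stride 1.
    return length - kernel_size + 1


def maxpool1d_length(length, pool_size):
    # 'valid' padding with stride equal to pool_size.
    return (length - pool_size) // pool_size + 1


cnn_counts = {
    "embedding": 4002 * 64,              # 256128
    "conv1": conv1d_params(64, 32, 2),   # 4128
    "conv2": conv1d_params(32, 32, 2),   # 2080
    "dense": 32 * 3 + 3,                 # 99
}
cnn_total = sum(cnn_counts.values())

# Sequence length after each layer, starting from 20 input tokens.
len_conv1 = conv1d_length(20, 2)            # 19
len_pool = maxpool1d_length(len_conv1, 5)   # 3
len_conv2 = conv1d_length(len_pool, 2)      # 2

print(cnn_total)                       # 262435
print(len_conv1, len_pool, len_conv2)  # 19 3 2
```

With only 20 input tokens, the second convolution sees just 3 time steps and produces 2, so GlobalMaxPooling1D picks the maximum over two positions per filter. As in the LSTM model, the embedding layer contributes 256,128 of the parameters.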
Results
Epoch 1/10
2400/2400 [==============================] - 1s 585us/step - loss: 1.0417 - acc: 0.4704 - val_loss: 0.8294 - val_acc: 0.6350
Epoch 2/10
2400/2400 [==============================] - 1s 331us/step - loss: 0.6708 - acc: 0.7533 - val_loss: 0.5984 - val_acc: 0.7617
Epoch 3/10
2400/2400 [==============================] - 1s 326us/step - loss: 0.3608 - acc: 0.8721 - val_loss: 0.5604 - val_acc: 0.7667
.
.
.
Epoch 8/10
2400/2400 [==============================] - 1s 357us/step - loss: 0.1527 - acc: 0.9262 - val_loss: 0.7415 - val_acc: 0.7583
Epoch 9/10
2400/2400 [==============================] - 1s 364us/step - loss: 0.1458 - acc: 0.9312 - val_loss: 0.7581 - val_acc: 0.7567
Epoch 10/10
2400/2400 [==============================] - 1s 357us/step - loss: 0.1420 - acc: 0.9292 - val_loss: 0.7744 - val_acc: 0.7600
Test score: 0.774, accuracy: 0.760
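In the CNN log, validation loss stops improving early while training loss keeps falling. The sketch below picks the epoch with the lowest validation loss among the epochs actually shown in the log. Epochs 4 to 7 are elided there, so they are left out here as well and the true best epoch may differ. The variable names are illustrative.

```python
# (epoch, train_loss, val_loss, val_acc) copied from the CNN log.
# Epochs 4-7 are elided in the log, so they are absent here too.
cnn_log = [
    (1, 1.0417, 0.8294, 0.6350),
    (2, 0.6708, 0.5984, 0.7617),
    (3, 0.3608, 0.5604, 0.7667),
    (8, 0.1527, 0.7415, 0.7583),
    (9, 0.1458, 0.7581, 0.7567),
    (10, 0.1420, 0.7744, 0.7600),
]

# Epoch with the lowest validation loss among the logged epochs.
best = min(cnn_log, key=lambda row: row[2])

# Gap between validation and training loss at the final epoch.
last = cnn_log[-1]
final_gap = round(last[2] - last[1], 4)

print(best[0], best[2], best[3])  # 3 0.5604 0.7667
print(final_gap)                  # 0.6324
```

Among the logged epochs, stopping at epoch 3 would have given slightly higher validation accuracy (0.7667) than the final epoch (0.7600), which is the kind of behaviour an early-stopping rule on val_loss is meant to catch.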
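Comparison
LSTM : 87.3%
CNN : 76.0%
In line with the usual advice, the RNN (LSTM) did better on this sequence data. The CNN trained for 10 epochs against the LSTM's 3 and was overfitting by the end (training accuracy 92.9% vs. validation accuracy 76.0%).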