Comparing "LSTM" and "CNN" for text and time-series classification

Posted at 2018-11-18

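・Text and time-series data can be learned not only with an RNN but also with a CNN (Conv1D)
・I trained an LSTM and a CNN on New York Times article classification and compared them
・To keep the comparison roughly fair, I set the two models up with about the same number of parameters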
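To make the Conv1D idea concrete, here is a minimal pure-Python sketch of a single 1D convolution filter sliding over a sequence of vectors. The function name `conv1d_valid` and the toy numbers are made up for illustration; this is not the Keras implementation, only the same basic operation (no padding, stride 1, ReLU).

```python
def conv1d_valid(sequence, kernel, bias=0.0):
    """Apply one filter to a sequence of vectors.

    sequence: list of T vectors, each of length C (e.g. word embeddings)
    kernel:   list of K vectors, each of length C (the filter weights)
    Returns T - K + 1 ReLU activations, one per window position.
    """
    k = len(kernel)
    outputs = []
    for start in range(len(sequence) - k + 1):
        window = sequence[start:start + k]
        # Weighted sum over every position and channel in the window
        total = bias + sum(
            w * x
            for kernel_vec, input_vec in zip(kernel, window)
            for w, x in zip(kernel_vec, input_vec)
        )
        outputs.append(max(0.0, total))  # ReLU
    return outputs


# Toy sequence: 4 time steps, 2 channels; the filter looks at 2 steps at a time
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
kernel = [[1.0, -1.0], [0.5, 0.5]]

features = conv1d_valid(seq, kernel)
pooled = max(features)  # what global max pooling keeps for this filter
print(features, pooled)
```

Each filter responds to a short local pattern (here, 2 neighbouring steps), and global max pooling keeps only the strongest response, wherever in the sequence it occurs.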

I tried classifying New York Times articles
https://qiita.com/Phoeboooo/items/dbc8bb5308c9ec5af6a0

Data

politics : 1000
science : 1000
sports : 1000
total : 3000
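The preprocessing is not shown in this post, so the following is only a sketch of one way to get from 3000 labelled articles to model-ready inputs. It assumes left padding/truncation to 20 token ids (to match `input_length=20` in the models), one-hot labels (which `categorical_crossentropy` expects), and an 80/20 split (which would give the 2400/600 split seen in the training log). The helper names and the placeholder token ids are hypothetical.

```python
import random

CLASSES = ["politics", "science", "sports"]
PER_CLASS = 1000  # from the class counts above
MAX_LEN = 20      # matches input_length=20 in the models


def pad_or_truncate(ids, max_len=MAX_LEN, pad_id=0):
    """Keep the last max_len ids and left-pad shorter sequences with pad_id."""
    ids = ids[-max_len:]
    return [pad_id] * (max_len - len(ids)) + ids


def one_hot(index, n_classes=len(CLASSES)):
    """Encode a class index as a one-hot vector."""
    return [1.0 if i == index else 0.0 for i in range(n_classes)]


def split(samples, val_fraction=0.2, seed=0):
    """Shuffle and split into (train, validation)."""
    shuffled = samples[:]
    random.Random(seed).shuffle(shuffled)
    n_val = round(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]


# Placeholder data: real token ids would come from a tokenizer
samples = [
    (pad_or_truncate([1, 2, 3]), one_hot(class_index))
    for class_index in range(len(CLASSES))
    for _ in range(PER_CLASS)
]
train, val = split(samples)
print(len(train), len(val))
```

In Keras, `pad_sequences` does the same padding job and also pads and truncates from the front by default.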

LSTM


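from keras.models import Sequential
from keras.layers import Embedding, Dropout, LSTM, Dense, Activation

model = Sequential()
# 4002-token vocabulary, 64-dim embeddings, input sequences of 20 tokens
model.add(Embedding(4002, 64, input_length=20))
model.add(Dropout(0.3))
model.add(LSTM(64, dropout=0.3, recurrent_dropout=0.3))
# One output per class (3 classes), turned into probabilities by softmax
model.add(Dense(3))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])

model.summary()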

Total params: 289,347
Trainable params: 289,347
Non-trainable params: 0
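The total above can be checked by hand. The sketch below uses the usual parameter-count formulas for an embedding, a standard LSTM layer (4 gates, each with input weights, recurrent weights and a bias) and a dense layer; the helper function names are my own, not part of Keras.

```python
def embedding_params(vocab_size, dim):
    # One dim-sized vector per token in the vocabulary
    return vocab_size * dim


def lstm_params(input_dim, units):
    # 4 gates, each with input weights, recurrent weights and a bias
    return 4 * (units * (input_dim + units) + units)


def dense_params(inputs, outputs):
    # Weight matrix plus one bias per output
    return inputs * outputs + outputs


lstm_breakdown = {
    "embedding": embedding_params(4002, 64),
    "lstm": lstm_params(64, 64),
    "dense": dense_params(64, 3),
}
lstm_total = sum(lstm_breakdown.values())
print(lstm_breakdown, lstm_total)
```

The sum matches the 289,347 reported by `model.summary()`. Dropout and the softmax activation add no parameters.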

Results


Train on 2400 samples, validate on 600 samples
Epoch 1/3
2400/2400 [==============================] - 7s 3ms/step - loss: 0.9500 - acc: 0.5138 - val_loss: 0.6780 - val_acc: 0.7317
Epoch 2/3
2400/2400 [==============================] - 3s 1ms/step - loss: 0.5800 - acc: 0.7554 - val_loss: 0.4645 - val_acc: 0.8533
Epoch 3/3
2400/2400 [==============================] - 3s 1ms/step - loss: 0.3229 - acc: 0.8875 - val_loss: 0.3778 - val_acc: 0.8733

Test score: 0.378, accuracy: 0.873

CNN




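from keras import layers
from keras.models import Sequential

model = Sequential()
model.add(layers.Embedding(4002, 64, input_length=20))
# Two Conv1D blocks with 32 filters, each looking at 2 neighbouring tokens
model.add(layers.Conv1D(32, 2, activation='relu'))
model.add(layers.MaxPooling1D(5))
model.add(layers.Conv1D(32, 2, activation='relu'))
# Keep the strongest response of each filter across the whole sequence
model.add(layers.GlobalMaxPooling1D())
model.add(layers.Dense(3))
model.add(layers.Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])

model.summary()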

Total params: 262,435
Trainable params: 262,435
Non-trainable params: 0
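As a sanity check on the summary above, the sketch below recomputes the parameter count and traces how the sequence length shrinks through the network. It assumes the Keras defaults used in the model definition: no padding ("valid"), stride 1 for Conv1D, and a pooling stride equal to the pool size. The helper function names are my own.

```python
def conv1d_params(kernel_size, in_channels, filters):
    # One kernel_size x in_channels weight block per filter, plus a bias
    return kernel_size * in_channels * filters + filters


def conv_len(length, kernel_size):
    # Output length of a "valid" convolution with stride 1
    return length - kernel_size + 1


def pool_len(length, pool_size):
    # Output length of "valid" pooling with stride == pool_size
    return (length - pool_size) // pool_size + 1


cnn_breakdown = {
    "embedding": 4002 * 64,
    "conv1d_1": conv1d_params(2, 64, 32),
    "conv1d_2": conv1d_params(2, 32, 32),
    "dense": 32 * 3 + 3,
}
cnn_total = sum(cnn_breakdown.values())

after_conv1 = conv_len(20, 2)          # 20 tokens -> 19 windows
after_pool = pool_len(after_conv1, 5)  # 19 -> 3
after_conv2 = conv_len(after_pool, 2)  # 3 -> 2, then global max pooling -> 1
print(cnn_breakdown, cnn_total, after_conv1, after_pool, after_conv2)
```

The sum matches the 262,435 in the summary. Note that the embedding layer accounts for 256,128 of those parameters, and the same is true of the LSTM model, so outside the embedding the two models are less similar in size (33,219 for the LSTM model versus 6,307 here).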

Results


Epoch 1/10
2400/2400 [==============================] - 1s 585us/step - loss: 1.0417 - acc: 0.4704 - val_loss: 0.8294 - val_acc: 0.6350
Epoch 2/10
2400/2400 [==============================] - 1s 331us/step - loss: 0.6708 - acc: 0.7533 - val_loss: 0.5984 - val_acc: 0.7617
Epoch 3/10
2400/2400 [==============================] - 1s 326us/step - loss: 0.3608 - acc: 0.8721 - val_loss: 0.5604 - val_acc: 0.7667
.
.
.
Epoch 8/10
2400/2400 [==============================] - 1s 357us/step - loss: 0.1527 - acc: 0.9262 - val_loss: 0.7415 - val_acc: 0.7583
Epoch 9/10
2400/2400 [==============================] - 1s 364us/step - loss: 0.1458 - acc: 0.9312 - val_loss: 0.7581 - val_acc: 0.7567
Epoch 10/10
2400/2400 [==============================] - 1s 357us/step - loss: 0.1420 - acc: 0.9292 - val_loss: 0.7744 - val_acc: 0.7600

Test score: 0.774, accuracy: 0.760
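In the log above, the validation loss is lowest around epoch 3 and then rises while the training loss keeps falling, which usually points to overfitting. The sketch below simply picks the best epoch by validation loss from the values copied out of the log. Epochs 4 to 7 are elided in this post, so only the epochs shown are considered.

```python
# (epoch, train_loss, val_loss, val_acc) copied from the CNN log above.
# Epochs 4-7 are not shown in the post, so they are missing here too.
cnn_log = [
    (1, 1.0417, 0.8294, 0.6350),
    (2, 0.6708, 0.5984, 0.7617),
    (3, 0.3608, 0.5604, 0.7667),
    (8, 0.1527, 0.7415, 0.7583),
    (9, 0.1458, 0.7581, 0.7567),
    (10, 0.1420, 0.7744, 0.7600),
]

# Best epoch among those shown = lowest validation loss
best = min(cnn_log, key=lambda row: row[2])
last = cnn_log[-1]

print("best shown epoch:", best[0], "val_loss:", best[2], "val_acc:", best[3])
print("last epoch:", last[0], "val_loss:", last[2], "val_acc:", last[3])
```

Among the epochs shown, stopping at epoch 3 would have given a lower validation loss than training for all 10 epochs, although the validation accuracy is almost the same. Keras provides an `EarlyStopping` callback that can automate this kind of stopping rule.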

Comparison

LSTM : 87.3%  
CNN : 76.0%

As the textbook advice suggests, the RNN (LSTM) gave better results than the CNN on this sequence data.
