
Fine-tuning GloVe for article classification


Use the pre-trained GloVe model to initialize the weights of the Embedding layer.
The task is classifying New York Times articles.

The results of classifying without GloVe are in a previous article:
https://qiita.com/Phoeboooo/items/dbc8bb5308c9ec5af6a0

① Organize the words and vectors contained in GloVe into a dictionary

import os
import numpy as np

# Map each GloVe word to its 300-dimensional vector
embeddings_index = {}
GLOVE_DIR = os.path.join('Downloads', 'glove.6B')
with open(os.path.join(GLOVE_DIR, 'glove.6B.300d.txt'), "r", encoding="utf-8") as f:
    for line in f:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs

print('Found %s word vectors.' % len(embeddings_index))

Found 400000 word vectors.
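Once the dictionary is built, a word vector is just a lookup, and related words end up close together. As a minimal sketch (using tiny made-up 3-dimensional vectors in place of the real 300-dimensional GloVe entries), cosine similarity between two entries can be computed like this:

```python
import numpy as np

# Toy stand-ins for embeddings_index entries (real GloVe vectors are 300-d)
embeddings_index = {
    "king":  np.array([0.8, 0.6, 0.1], dtype="float32"),
    "queen": np.array([0.7, 0.7, 0.2], dtype="float32"),
    "apple": np.array([0.1, 0.2, 0.9], dtype="float32"),
}

def cosine_similarity(a, b):
    # Cosine of the angle between the two vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

sim_kq = cosine_similarity(embeddings_index["king"], embeddings_index["queen"])
sim_ka = cosine_similarity(embeddings_index["king"], embeddings_index["apple"])
print(sim_kq > sim_ka)  # related words should score higher
```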

② Link the words in our vocabulary to their GloVe vectors

・Words not contained in GloVe are initialized as all-zero vectors.

・With GloVe, EMBEDDING_DIM can be chosen from 50, 100, 200, or 300.
 I tried all of them and use 300, which gave the best accuracy.

EMBEDDING_DIM = 300
# Rows stay all-zeros for words not found in the embedding index
embedding_matrix = np.zeros((len(word_to_index), EMBEDDING_DIM))
for word, i in word_to_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
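Since every miss becomes a zero row, it can be worth checking how much of the vocabulary GloVe actually covers. A minimal sketch with toy data (`word_to_index` and `embeddings_index` here are tiny made-up stand-ins for the real ones):

```python
import numpy as np

EMBEDDING_DIM = 3  # 300 with the real GloVe vectors
word_to_index = {"the": 0, "cat": 1, "zxqv": 2}  # toy vocabulary
embeddings_index = {                             # toy GloVe dictionary
    "the": np.ones(EMBEDDING_DIM, dtype="float32"),
    "cat": np.ones(EMBEDDING_DIM, dtype="float32"),
}

embedding_matrix = np.zeros((len(word_to_index), EMBEDDING_DIM))
for word, i in word_to_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector

# A row is covered iff it is not all-zeros
covered = np.count_nonzero(np.any(embedding_matrix != 0, axis=1))
print('Coverage: %d / %d words' % (covered, len(word_to_index)))
```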

③ Model

weights=[embedding_matrix]
Pass the embedding_matrix built above as weights.
Everything else is the same as before.

*Adding trainable=False to the Embedding layer keeps the vectors set by GloVe
 from being changed during training.
 In this case the default (trainable) gave better results than trainable=False, so it is not added.

# Build model
model = Sequential()
model.add(Embedding(len(word_to_index),
                    EMBEDDING_DIM,
                    weights=[embedding_matrix],
                    input_length=20))
model.add(Dropout(0.3))
model.add(LSTM(64, dropout=0.3, recurrent_dropout=0.3))
model.add(Dense(3))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])

model.summary()

Total params: 1,294,535
Trainable params: 1,294,535
Non-trainable params: 0
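The model expects every input to be a sequence of exactly 20 word indices (input_length=20), so shorter articles must be padded and longer ones truncated. A minimal numpy sketch of that preprocessing (post-padding here is an assumption; Keras' own pad_sequences pads at the front by default):

```python
import numpy as np

MAXLEN = 20  # matches input_length in the model

def pad_or_truncate(seqs, maxlen=MAXLEN):
    # Pad with index 0 at the end, or cut off anything past maxlen
    out = np.zeros((len(seqs), maxlen), dtype="int32")
    for row, seq in enumerate(seqs):
        trimmed = seq[:maxlen]
        out[row, :len(trimmed)] = trimmed
    return out

batch = pad_or_truncate([[5, 12, 7], list(range(1, 31))])
print(batch.shape)  # (2, 20)
```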

Training and results

Train on 2400 samples, validate on 600 samples
Epoch 1/2
2400/2400 [==============================] - 10s 4ms/step - loss: 0.6011 - acc: 0.7633 - val_loss: 0.2866 - val_acc: 0.9100
Epoch 2/2
2400/2400 [==============================] - 8s 3ms/step - loss: 0.3088 - acc: 0.8892 - val_loss: 0.2497 - val_acc: 0.9217

val_loss: 0.2497 - val_acc: 0.9217

(Reference) Without fine-tuning:
val_loss: 0.3778 - val_acc: 0.8733
