LoginSignup
1
3

More than 5 years have passed since last update.

GloVeをファインチューニングして記事分類

Last updated at Posted at 2018-11-28

事前学習済みのモデルGloVeを利用してEmbedding層の重みの初期値を設定する
タスクはNew York Times紙の記事分類

GloVeを使わずに分類した結果は過去記事に掲載
https://qiita.com/Phoeboooo/items/dbc8bb5308c9ec5af6a0

① GloVeに含まれている単語とベクトルを辞書にして整理する


embeddings_index = {}
GLOVE_DIR = os.path.join('Downloads', 'glove.6B')
f = open(os.path.join(GLOVE_DIR, 'glove.6B.300d.txt'), "r", encoding="utf-8" )
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))

Found 400000 word vectors.

② 使用する単語とGloVeのベクトルを結ぶ

・GloVeに含まれていない単語はすべて0のベクトルとして初期化する

・GloVeの場合、EMBEDDING_DIMは50,100,200,300の中から選べる
 すべて試して一番精度が出た300を使っている


EMBEDDING_DIM = 300
embedding_matrix = np.zeros((len(word_to_index), EMBEDDING_DIM))
for word, i in word_to_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        # words not found in embedding index will be all-zeros.
        embedding_matrix[i] = embedding_vector

③ モデル

weights=[embedding_matrix]
weights に先程の embedding_matrixを指定する
その他は同じ

*Embedding層にtrainable=Falseを追加することで、
 GloVeで定めたベクトルを学習によって変更されないようにすることができる
 今回はtrainable=Falseにするよりもデフォルトの方が結果が良かったため追加していない


# Build model
model = Sequential()
model.add(Embedding(len(word_to_index),
                    EMBEDDING_DIM,
                    weights=[embedding_matrix],
                    input_length=20))
model.add(Dropout(0.3))
model.add(LSTM(64, dropout=0.3, recurrent_dropout=0.3))
model.add(Dense(3))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])

model.summary()

Total params: 1,294,535
Trainable params: 1,294,535
Non-trainable params: 0

学習と結果


Train on 2400 samples, validate on 600 samples
Epoch 1/2
2400/2400 [==============================] - 10s 4ms/step - loss: 0.6011 - acc: 0.7633 - val_loss: 0.2866 - val_acc: 0.9100
Epoch 2/2
2400/2400 [==============================] - 8s 3ms/step - loss: 0.3088 - acc: 0.8892 - val_loss: 0.2497 - val_acc: 0.9217

val_loss: 0.2497 - val_acc: 0.9217

(参考)ファインチューニングなし
val_loss: 0.3778 - val_acc: 0.8733

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3