
[LSTM] Classifying New York Times Articles

Posted at 2018-11-19

About

・Classify New York Times articles with an LSTM, using the headlines as the data

Imports


import numpy as np
from keras.utils import np_utils
from keras.layers import Activation, Dense, Dropout, Embedding, LSTM
from keras.models import Sequential
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split

Data and Labels

politics : 0
science : 1
sports : 2


DATA = [['Trump Calls China’s List of Trade Concessions ‘Not Acceptable’', 0],
        # ... more politics headlines, label 0 ...
        ['Where Will Science Take Us? To the Stars', 1],
        # ... more science headlines, label 1 ...
        ['Shohei Ohtani and Ronald Acuña Jr. Are Rookies of the Year', 2],
        # ... more sports headlines, label 2 ...
]

Although heavily simplified here, the data was assembled in the format shown above.

Organizing the Data

① Separate the headlines from the labels


# split (news and labels)
data_X = []
data_y = []
for data in DATA:
    data_X.append(data[0]) 
    data_y.append(data[1])

② Count how many samples there are in each class


c_0 = 0
c_1 = 0
c_2 = 0
for d in data_y:
    if d == 0:
        c_0 += 1
    elif d == 1:
        c_1 += 1
    else:
        c_2 += 1
        
print('politics : ' ,c_0)
print('science : ' ,c_1)
print('sports : ' ,c_2)
print('total : ' ,c_0 + c_1 + c_2)

politics : 1000
science : 1000
sports : 1000
total : 3000
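
The same tally can be written more compactly with collections.Counter (the same module the word-counting code below relies on); a minimal equivalent sketch:

import collections

# count label occurrences in one pass
counts = collections.Counter(data_y)
print('politics : ', counts[0])
print('science : ', counts[1])
print('sports : ', counts[2])
print('total : ', sum(counts.values()))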

③ One-hot encode the labels


def flatten(data):
    # recursively yield the elements of arbitrarily nested iterables
    for item in data:
        if hasattr(item, '__iter__'):
            for element in flatten(item):
                yield element
        else:
            yield item

data_y = list(flatten(data_y)) # flatten() is a generator, so its output must be consumed
data_y = np.array(data_y)
data_y = np_utils.to_categorical(data_y)
print('data_y_shape : ', data_y.shape)
print('data_y_shape : ', data_y.shape)

data_y_shape : (3000, 3)
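
For reference, np_utils.to_categorical maps integer labels to one-hot rows:

print(np_utils.to_categorical(np.array([0, 1, 2])))
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]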

Indexing the Words

The sentence2words(sentence) function processes each headline in order: lowercase → remove newlines → replace symbols with spaces → split on spaces (break the sentence into words) → remove words containing digits → remove stopwords.

・Only the 4,000 most frequent words (the vocabulary size) are indexed
・The padding index is 0
・Words that did not make it into the top 4,000 get index 1
・Because stopwords appear so frequently, the stopword index came out as 2


# sentence to words
import re
import collections

def sentence2words(sentence):
    
    stopwords =["i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
                "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself",
                "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these",
                "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do", 
                "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while", 
                "of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before", 
                "after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again",
                "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each",
                "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", 
                "too", "very",  "can", "will", "just",  "should", "now", "s", "t", "don", "didn", "aren", "isn", "can", "re", "ll", "ve"]
    
    sentence = sentence.lower() # lower
    sentence = sentence.replace("\n", "") # delete new lines
    sentence = re.sub(re.compile(r"[!-\?()' ‘’.,;/:-@[-`{-~]"), " ", sentence) # symbol to space
    sentence = sentence.split(" ") # split words on spaces
    sentence_words = []
    for word in sentence:
        if word == "": # skip empty tokens left by consecutive spaces
            continue
        if (re.compile(r"^.*[0-9]+.*$").fullmatch(word) is not None): # skip words containing digits
            continue
        if word in stopwords: # skip stopwords
            continue
        sentence_words.append(word)
    return sentence_words
 
# word to index
word_to_index = {}
index_to_word = {}
num_recs = 0
maxlen = 0
word_freqs = collections.Counter()
for sentence in data_X:
    sentence_words = sentence2words(sentence)
    maxlen = max(maxlen, len(sentence_words))
    for word in sentence_words:
        word_freqs[word] += 1
    num_recs += 1
    
max_features = 4000
vocab_size = min(max_features, len(word_freqs)) + 2 # 4000 words + PAD + UNK
word_to_index = {x[0]: i+2 for i, x in
              enumerate(word_freqs.most_common(max_features))}
word_to_index["PAD"] = 0
word_to_index["UNK"] = 1
index_to_word = {v: k for k, v in word_to_index.items()}
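
As a quick sanity check of the tokenizer and the mapping (the actual indices depend on the corpus frequencies, so the printed values are only illustrative):

sample = sentence2words("Where Will Science Take Us? To the Stars")
print(sample) # e.g. ['science', 'take', 'us', 'stars'] (stopwords and symbols removed)
print([word_to_index.get(w, word_to_index["UNK"]) for w in sample])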
  

Checking the results so far

① Number of headlines


# number of headlines
print('num_recs : ', num_recs)

num_recs : 3000

② Maximum number of words in a headline


# number of words in one headline
print('maxlen : ', maxlen)

maxlen : 25

③ Number of distinct words


print('word_freqs : ' ,len(word_freqs))

word_freqs : 6832

④ Number of indexed words


# 0 = padding
# 1 = unknown
# 2 = stop words
print('word_to_index : ', len(word_to_index))

word_to_index : 4002

Representing Headlines as Index Sequences


# headline to index
data_X_vec = []
for sentence in data_X:
    sentence_words = sentence2words(sentence)
    sentence_ids = []
    for word in sentence_words:
        if word not in word_to_index:
            sentence_ids.append(1)
            continue
        sentence_ids.append(word_to_index[word])
    data_X_vec.append(sentence_ids)

Pad the sequences so every headline has the same length.

The length is set to 20; since maxlen was 25, headlines longer than 20 words are truncated (with the Keras defaults, the leading words are dropped).


data_X = sequence.pad_sequences(data_X_vec, maxlen=20)
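
A minimal illustration of the padding behavior (with the Keras defaults, short sequences are left-padded with 0 and long ones lose their leading tokens):

print(sequence.pad_sequences([[5, 6, 7]], maxlen=5))           # [[0 0 5 6 7]]
print(sequence.pad_sequences([[1, 2, 3, 4, 5, 6]], maxlen=5))  # [[2 3 4 5 6]]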

Split into Training and Test Data


X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, test_size=0.2, random_state=0)
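
Note that train_test_split shuffles but does not stratify by default; with 1,000 headlines per class that is unlikely to matter much here, but passing stratify would keep the class proportions identical in both splits (a refinement, not part of the original post):

X_train, X_test, y_train, y_test = train_test_split(
    data_X, data_y, test_size=0.2, random_state=0,
    stratify=data_y.argmax(axis=1)) # stratify on the integer class labels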

Checking the data


print('X_train : ', X_train.shape)
print('X_test : ', X_test.shape)
print('y_train : ', y_train.shape)
print('y_test : ', y_test.shape)

X_train : (2400, 20)
X_test : (600, 20)
y_train : (2400, 3)
y_test : (600, 3)

Model


# Build model
model = Sequential()
model.add(Embedding(4002, 64, input_length=20)) # vocabulary: 4000 words + PAD + UNK
model.add(Dropout(0.3))
model.add(LSTM(64, dropout=0.3, recurrent_dropout=0.3))
model.add(Dense(3))                             # one output per class
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])

model.summary()

Total params: 289,347
Trainable params: 289,347
Non-trainable params: 0
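
The total can be verified by hand from the layer shapes:

emb   = 4002 * 64                  # embedding table: 256,128
lstm  = 4 * ((64 + 64) * 64 + 64)  # 4 gates x (input + recurrent weights + bias): 33,024
dense = 64 * 3 + 3                 # dense weights + biases: 195
print(emb + lstm + dense)          # 289,347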

Training


history = model.fit(X_train, y_train, batch_size=20,
                    epochs=3,
                    validation_data=(X_test, y_test))

Train on 2400 samples, validate on 600 samples
Epoch 1/3
2400/2400 [==============================] - 7s 3ms/step - loss: 0.9500 - acc: 0.5138 - val_loss: 0.6780 - val_acc: 0.7317
Epoch 2/3
2400/2400 [==============================] - 3s 1ms/step - loss: 0.5800 - acc: 0.7554 - val_loss: 0.4645 - val_acc: 0.8533
Epoch 3/3
2400/2400 [==============================] - 3s 1ms/step - loss: 0.3229 - acc: 0.8875 - val_loss: 0.3778 - val_acc: 0.8733
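
The returned history object records the per-epoch metrics, so the learning curves can be plotted directly (matplotlib assumed; with this Keras version the accuracy keys are 'acc' and 'val_acc', as in the log above):

import matplotlib.pyplot as plt

plt.plot(history.history['acc'], label='train acc')
plt.plot(history.history['val_acc'], label='val acc')
plt.xlabel('epoch')
plt.legend()
plt.show()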

Testing

・Check the accuracy
・Inspect the results for 10 randomly chosen samples

・The only mistake among the 10 randomly chosen samples:
predict : [1] label : 2 headline : illegal street sport takes country UNK
A headline containing the very word "sport" was classified as science (1) instead of sports (2).
It misses what looks like the easiest headline while getting all the others right; a mystery.


score, acc = model.evaluate(X_test, y_test, batch_size=10)
print("Test score: {:.3f}, accuracy: {:.3f}".format(score, acc))


        
for i in range(10):
    idx = np.random.randint(len(X_test))
    xtest = X_test[idx].reshape(1,20)
    
    ypred = model.predict(xtest)
    predicted_class_indices = np.argmax(ypred, axis=1)
    ylabel = y_test[idx]
    ylabel = np.argmax(ylabel)
    headline = " ".join([index_to_word[x] for x in xtest[0].tolist() if x != 0])
    print('predict :', predicted_class_indices, 'label :', ylabel,'headline :', headline)

600/600 [==============================] - 0s 463us/step
Test score: 0.378, accuracy: 0.873
predict : [2] label : 2 headline : zach britton yankees cover UNK strength
predict : [1] label : 2 headline : illegal street sport takes country UNK
predict : [0] label : 0 headline : kavanaugh may hold key vote first death penalty case
predict : [0] label : 0 headline : democrats looking ahead see future female
predict : [2] label : 2 headline : rafael nadal outlasts UNK thiem marathon u open UNK
predict : [0] label : 0 headline : politics republicans fret key battleground races
predict : [0] label : 0 headline : randy rainbow singing political satirist spends sundays
predict : [0] label : 0 headline : pence speech string together narrative chinese aggression
predict : [0] label : 0 headline : rude terrible person midterms trump renews attacks press
predict : [0] label : 0 headline : politics lisa lerer war sexes
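
To see where the remaining errors lie across all 600 test headlines, a confusion matrix is a quick follow-up (sklearn, which is already a dependency here):

from sklearn.metrics import confusion_matrix

y_pred = model.predict(X_test).argmax(axis=1)
y_true = y_test.argmax(axis=1)
print(confusion_matrix(y_true, y_pred)) # rows = true class, columns = predicted class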

Related Articles

What happens when the New York Times article classification task is trained with a CNN?
https://qiita.com/Phoeboooo/items/1beab63d257a6a89fb93

Fine-tuning GloVe turned out to improve the accuracy
https://qiita.com/Phoeboooo/items/bcbf51acd946d32416b6
