Multi-class Classification with LSTM in Keras

Posted at 2022-02-26

Introduction

  • In this post, I implement multi-class classification with an LSTM in Keras, using the livedoor news corpus as the dataset.

Getting the data

# Download and extract the livedoor news corpus
! wget "https://www.rondhuit.com/download/ldcc-20140209.tar.gz"
!tar -zxvf ldcc-20140209.tar.gz

# Build the TSV file (one row per article: filename, article body, and one 0/1 column per category)
!echo -e "filename\tarticle"$(for category in $(basename -a `find ./text -type d` | grep -v text | sort); do echo -n "\t"; echo -n $category; done) > ./text/livedoor.tsv
!for filename in `basename -a ./text/dokujo-tsushin/dokujo-tsushin-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/dokujo-tsushin/$filename`; echo -e "\t1\t0\t0\t0\t0\t0\t0\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/it-life-hack/it-life-hack-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/it-life-hack/$filename`; echo -e "\t0\t1\t0\t0\t0\t0\t0\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/kaden-channel/kaden-channel-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/kaden-channel/$filename`; echo -e "\t0\t0\t1\t0\t0\t0\t0\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/livedoor-homme/livedoor-homme-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/livedoor-homme/$filename`; echo -e "\t0\t0\t0\t1\t0\t0\t0\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/movie-enter/movie-enter-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/movie-enter/$filename`; echo -e "\t0\t0\t0\t0\t1\t0\t0\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/peachy/peachy-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/peachy/$filename`; echo -e "\t0\t0\t0\t0\t0\t1\t0\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/smax/smax-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/smax/$filename`; echo -e "\t0\t0\t0\t0\t0\t0\t1\t0\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/sports-watch/sports-watch-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/sports-watch/$filename`; echo -e "\t0\t0\t0\t0\t0\t0\t0\t1\t0"; done >> ./text/livedoor.tsv
!for filename in `basename -a ./text/topic-news/topic-news-*`; do echo -n "$filename"; echo -ne "\t"; echo -n `sed -e '1,3d' ./text/topic-news/$filename`; echo -e "\t0\t0\t0\t0\t0\t0\t0\t0\t1"; done >> ./text/livedoor.tsv
  • Read the TSV file with pandas. As shown in the image below, this gives a DataFrame holding the cleaned article text and a one-hot representation of each article's news category.
import pandas as pd 
df = pd.read_csv('text/livedoor.tsv', sep='\t')

[Image: DataFrame with the article text and the one-hot news-category columns]
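
A quick sanity check of what was loaded (a minimal sketch; the column names come from the TSV header built above):

# The DataFrame should contain a filename column, the article text,
# and one 0/1 column per news category.
print(df.shape)
print(df.columns.tolist())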

  • Build one-hot vectors for the target labels
# Build one-hot vectors for the target labels
df2 = df.copy()
df2 = df2[['dokujo-tsushin', 'it-life-hack',
       'kaden-channel', 'livedoor-homme', 'movie-enter', 'peachy', 'smax',
       'sports-watch', 'topic-news']]

# collect each row's nine 0/1 values as a list (avoids chained-assignment issues with .iloc)
one_hot_labels = []
for index, rows in df2.iterrows():
    one_hot_labels.append(list(rows))
df['one-hot'] = one_hot_labels
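
The same column can also be built without an explicit loop (a shorter equivalent sketch, using the df2 frame selected above):

# Loop-free equivalent: each row of the nine category columns becomes a list.
df['one-hot'] = df2.values.tolist()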

Cleaning

import re, unicodedata

class CleaningData:
    def __init__(self, df, target_column):
        self.df = df
        self.target_column = target_column

    def cleaning(self):
        self.df[self.target_column] = self.df[self.target_column].map(self.remove_extra_spaces)
        self.df[self.target_column] = self.df[self.target_column].map(self.normalize_neologd)

        # drop rows whose text became empty during cleaning
        self.df = self.df[self.df[self.target_column] != '']
        self.df = self.df.reset_index()
        return self.df

    def unicode_normalize(self, cls, s):
        pt = re.compile('([{}]+)'.format(cls))

        def norm(c):
            return unicodedata.normalize('NFKC', c) if pt.match(c) else c

        s = ''.join(norm(x) for x in re.split(pt, s))
        s = re.sub('－', '-', s)  # unify full-width hyphen-minus with ASCII hyphen
        return s

    def remove_extra_spaces(self, s):
        s = re.sub('[ 　]+', ' ', s)  # collapse runs of half- and full-width spaces
        blocks = ''.join(('\u4E00-\u9FFF',  # CJK UNIFIED IDEOGRAPHS
                          '\u3040-\u309F',  # HIRAGANA
                          '\u30A0-\u30FF',  # KATAKANA
                          '\u3000-\u303F',  # CJK SYMBOLS AND PUNCTUATION
                          '\uFF00-\uFFEF'   # HALFWIDTH AND FULLWIDTH FORMS
                          ))
        basic_latin = '\u0000-\u007F'

        def remove_space_between(cls1, cls2, s):
            p = re.compile('([{}]) ([{}])'.format(cls1, cls2))
            while p.search(s):
                s = p.sub(r'\1\2', s)
            return s

        s = remove_space_between(blocks, blocks, s)
        s = remove_space_between(blocks, basic_latin, s)
        s = remove_space_between(basic_latin, blocks, s)
        return s

    def normalize_neologd(self, s):
        s = s.strip()
        s = self.unicode_normalize('０-９Ａ-Ｚａ-ｚ｡-ﾟ', s)

        def maketrans(f, t):
            return {ord(x): ord(y) for x, y in zip(f, t)}

        s = re.sub('[˗֊‐‑‒–⁃⁻₋−]+', '-', s)  # normalize hyphens
        s = re.sub('[﹣－ｰ—―─━ー]+', 'ー', s)  # normalize choonpus
        s = re.sub('[~∼∾〜〰～]+', '', s)  # remove tildes
        s = s.translate(
            maketrans('!"#$%&\'()*+,-./:;<=>?@[¥]^_`{|}~。、・「」',
                  '！”＃＄％＆’（）＊＋，－．／：；＜＝＞？＠［￥］＾＿｀｛｜｝〜。、・「」'))

        s = self.remove_extra_spaces(s)
        s = self.unicode_normalize('！”＃＄％＆’（）＊＋，－．／：；＜＞？＠［￥］＾＿｀｛｜｝〜', s)  # keep ＝,・,「,」
        s = re.sub('[’]', '\'', s)
        s = re.sub('[”]', '"', s)
        return s

    def remove_symbols(self, text):
        # strips decorative symbols; not called from cleaning() above
        text = re.sub(r'[◎〇△▲×◇□]', '', text)
        return text
  • Run the cleaning
cd = CleaningData(df, 'article')
df = cd.cleaning()
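
To see what the cleaning does, a single string can be passed through the same method (a small sketch with a made-up full-width sample; the output shown in the comment is approximate):

# Full-width alphanumerics are normalized and stray spaces between
# Japanese characters are removed.
print(cd.normalize_neologd('Ｋｅｒａｓ で ＬＳＴＭ を 使った 分類'))  # -> 'KerasでLSTMを使った分類'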

Morphological analysis

  • Install the MeCab morphological analyzer and the mecab-ipadic-NEologd dictionary
!apt-get -q -y install sudo file mecab libmecab-dev mecab-ipadic-utf8 git curl python-mecab
!git clone --depth 1 https://github.com/neologd/mecab-ipadic-neologd.git
!echo yes | mecab-ipadic-neologd/bin/install-mecab-ipadic-neologd -n
!pip install mecab-python3
!ln -s /etc/mecabrc /usr/local/etc/mecabrc
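
The dictionary path hard-coded in MECAB_PATH below is the usual location on Colab, but it can differ by environment; a quick way to check where the installer put it (a sketch using mecab-config, which ships with libmecab-dev):

import subprocess

# mecab-config --dicdir prints the base MeCab dictionary directory;
# the NEologd installer places its dictionary in a subdirectory of it.
dicdir = subprocess.run(['mecab-config', '--dicdir'], capture_output=True, text=True).stdout.strip()
print(dicdir + '/mecab-ipadic-neologd')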
  • Define the tokenizer (wakati) class
import MeCab

class Wakati:
    """ 形態素解析クラス """
    # クラス変数
    MECAB_PATH = "-d /usr/lib/x86_64-linux-gnu/mecab/dic/mecab-ipadic-neologd"

    def __init__(self, df, target_column):
        self.df = df
        self.target_column = target_column

    def wakati_document(self):
        self.df[self.target_column] = self.df[self.target_column].map(self.wakati_sentence)
        return self.df

    def wakati_sentence(self, text):
        tagger = MeCab.Tagger(Wakati.MECAB_PATH)
        words = []
        for c in tagger.parse(text).splitlines()[:-1]:
            # surface holds the token, feature holds the MeCab analysis result
            try:
                surface, feature = c.split('\t')
            except ValueError:
                continue
            pos = feature.split(',')[0]  # part of speech (not used here)
            words.append(surface)
        return ' '.join(words)
  • Run the morphological analysis
w = Wakati(df, 'article')
df = w.wakati_document()
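
As a quick check, the same method works on a single sentence (a sketch with a made-up example; the exact segmentation depends on the installed dictionary):

# Tokenize one sentence; tokens come back separated by single spaces.
print(w.wakati_sentence('彼女はペンパイナッポーアッポーペンと恋ダンスを踊った。'))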

Text vectorization

  • Since the model cannot learn from raw Japanese text, Keras's Tokenizer class is used to convert the words into integer IDs (vectorize the text).
  • To give every input document the same length, documents with fewer than maxlen words are zero-padded with pad_sequences.
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow import keras

class TensorFlowTokenizer:
    """ 単語にIDを振り、文章をベクトル化するクラス """
    def __init__(self, df, maxlen):
        self.df = df
        self.maxlen = maxlen

    def tokenizer(self):
        self.X_train, self.word_to_index, self.index_to_word = self.tokenize_column(list(self.df['article']))
        return self.X_train, self.word_to_index, self.index_to_word
        
    def tokenize_column(self, text_list):
        ''' Convert words to integer IDs and pad sequences to a common length '''
        keras_tokenizer = Tokenizer()
        keras_tokenizer.fit_on_texts(text_list)

        print('Number of learned words: {}'.format(len(keras_tokenizer.word_index)))
        word_to_index = keras_tokenizer.word_index
        index_to_word = {}
        for word, index in word_to_index.items():
            index_to_word[index] = word

        # convert each document into a sequence of word IDs
        text_vector = keras_tokenizer.texts_to_sequences(text_list)
        X_train = keras.preprocessing.sequence.pad_sequences(text_vector, maxlen=self.maxlen)

        return X_train, word_to_index, index_to_word
  • Vectorize the tokenized documents
    • Here the length of each document is fixed at 500.
tft = TensorFlowTokenizer(df, 500)
X, word_to_index, index_to_word = tft.tokenizer()
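
To verify the mapping, a vectorized document can be decoded back into words with index_to_word (a small sketch; ID 0 is the padding value and has no dictionary entry):

# Decode the first padded document back into words, skipping the 0 padding.
print(' '.join(index_to_word[i] for i in X[0] if i != 0))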

Splitting the dataset

  • Split the data into training, test, and validation sets with scikit-learn's train_test_split.
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array(list(df['one-hot']))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100, stratify=y)
X_test, X_valid, y_test, y_valid = train_test_split(X_test, y_test, test_size=0.5, random_state=100, stratify=y_test)
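
A quick look at the resulting split sizes and class balance (a minimal sketch):

# Roughly 80% train / 10% validation / 10% test; stratify keeps the
# category ratios similar in each split.
print(X_train.shape, X_valid.shape, X_test.shape)
print(y_train.sum(axis=0))  # articles per category in the training set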

Defining the model

from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding, Dropout

vocabulary_size = len(word_to_index) + 1 

model = Sequential()

model.add(Embedding(input_dim=vocabulary_size, output_dim=256, input_length=500))
model.add(LSTM(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(9, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()

Model: "sequential_24"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding_24 (Embedding)    (None, 500, 256)          25054464  
                                                                 
 lstm_24 (LSTM)              (None, 128)               197120    
                                                                 
 dropout_11 (Dropout)        (None, 128)               0         
                                                                 
 dense_20 (Dense)            (None, 9)                 1161      
                                                                 
=================================================================
Total params: 25,252,745
Trainable params: 25,252,745
Non-trainable params: 0
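
The parameter counts in the summary can be reproduced by hand (a small sketch; vocabulary_size is whatever len(word_to_index) + 1 happened to be in this run):

embedding_dim, lstm_units, num_classes = 256, 128, 9

embedding_params = vocabulary_size * embedding_dim                 # one 256-dim vector per word ID
lstm_params = 4 * ((embedding_dim + lstm_units + 1) * lstm_units)  # 4 gates x (input + recurrent + bias) weights
dense_params = lstm_units * num_classes + num_classes              # softmax layer weights + biases
print(embedding_params, lstm_params, dense_params)                 # 197120 and 1161 match the summary above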

Training

from tensorflow.keras.callbacks import EarlyStopping
callbacks = [EarlyStopping(monitor='val_loss',
                           patience=5,  # stop if there is no improvement for this many epochs
                           verbose=1,
                           mode='min')  # val_loss should decrease, so monitor for the minimum
            ]

history = model.fit(X_train, y_train, epochs=30, batch_size=32, validation_data=(X_valid, y_valid), callbacks=callbacks)
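
fit returns a History object, so the learning curves can be plotted to see where early stopping kicked in (a sketch assuming matplotlib, which is also used for the confusion matrix below):

import matplotlib.pyplot as plt

# Training vs. validation loss per epoch.
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.xlabel('epoch')
plt.ylabel('categorical cross-entropy')
plt.legend()
plt.show()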

Evaluation

  • The per-label metrics are shown below; overall accuracy was 92%.
from sklearn.metrics import classification_report

y_pred =  model.predict(X_test)
print(classification_report(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1)))

      precision    recall  f1-score   support

           0       0.99      0.80      0.89        87
           1       0.96      0.99      0.97        87
           2       0.97      0.95      0.96        87
           3       0.79      0.80      0.80        51
           4       0.88      0.97      0.92        87
           5       0.84      0.88      0.86        84
           6       1.00      0.99      0.99        87
           7       0.89      0.98      0.93        90
           8       0.99      0.88      0.93        77

    accuracy                           0.92       737
   macro avg       0.92      0.92      0.92       737
weighted avg       0.93      0.92      0.92       737
  • The confusion matrix looks like this. Label 3 (livedoor-homme) has noticeably fewer samples than the other categories, so evening out the imbalance in the training data would likely improve the results; one possible approach is sketched after the plot.
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

sns.heatmap(confusion_matrix(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1)), annot=True)
plt.xlabel("pred")
plt.ylabel('true')

[Image: confusion matrix heatmap (true label vs. predicted label)]
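
One way to compensate for the imbalance without collecting more data would be to weight the classes during training; a sketch using scikit-learn's compute_class_weight (whether this actually helps on this dataset is untested):

from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# Weight each class inversely to its frequency in the training set.
labels = np.argmax(y_train, axis=1)
weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
class_weights = dict(enumerate(weights))

# Pass the weights to fit via the class_weight argument:
# model.fit(X_train, y_train, epochs=30, batch_size=32,
#           validation_data=(X_valid, y_valid),
#           callbacks=callbacks, class_weight=class_weights)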

