3
4

More than 5 years have passed since last update.

只今勉強中!機械学習 多クラス分類の例:ニュース配信の分類

Posted at

多クラス分類

前回は 只今勉強中!二値分類のチュートリアル で入力ベクトルが2つのデータを排他クラスに分類する方法を勉強しました。
今回は入力ベクトルが3つ以上あるデータの分類問題に取り組みたいと思います。

ニュース配信の分類

Kerasライブラリに組み込まれている多クラス分類のチュートリアル
Reutersのニュース配信を46種類の相互排他なトピック(クラス)に分類する問題。
※相互排他:各データは1つのカテゴリに分類するケースで多クラス単一ラベル分類(single-label, multiclass classification)問題と呼ばれる。
多クラス多ラベル分類(multi-label, multiclass classification)問題というのもある。

サンプルデータを読み込む

tutorial.py
# Reuters データセットを読み込む
from keras.datasets import reuters

(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)

#サンプル出力
print(len(train_data))
>>> 8982
print(len(test_data))
>>> 2246
print(train_data[0])
>>> [1, 2, 2, 8, 43, 10, 447, 5, 25, 207, 270, 5, 3095, 111, 16, 369, 186, 90, 67, 7, 89, 5, 19, 102, 6, 19, 124, 15, 90, 67, 84, 22, 482, 26, 7, 48, 4, 49, 8, 864, 39, 209, 154, 6, 151, 6, 83, 11, 15, 22, 155, 11, 15, 7, 48, 9, 4579, 1005, 504, 6, 258, 6, 272, 11, 15, 22, 134, 44, 11, 15, 16, 8, 197, 1245, 90, 67, 52, 29, 209, 30, 32, 132, 6, 109, 15, 17, 12]

訓練データの中身

とりあえず訓練データの中身がどうなっているか確認してみる

tutorial.py
# 訓練データを単語にデコード
word_index = reuters.get_word_index()
# 単語リストの中身
print(word_index.items())
>>> dict_items([('mdbl', 10996), ('fawc', 16260), ('degussa', 12089), ('woods', 8803), ('hanging', 13796), ('localized', 20672), ('sation', 20673), ('chanthaburi' (以下省略...

# { value : key} 型でデータを取得
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])

# 0:パディング, 1:シーケンス, 2:不明の予約語
decoded_newswire = ' '.join([reverse_word_index.get(i - 3, '?') for i in train_data[0]])
# ニュース内容を表示
print(decoded_newswire)
>>> ? ? ? said as a result of its december acquisition of space co it expects earnings per share in 1987 of 1 15 to 1 30 dlrs per share up from 70 cts in 1986 the company said pretax net should rise to nine to 10 mln dlrs from six mln dlrs in 1986 and rental operation revenues to 19 to 22 mln dlrs from 12 5 mln dlrs it said cash flow per share this year should be 2 50 to three dlrs reuter 3
# ラベルの中身
print(train_labels[0])
>>> 3

機械学習演習

データの中身を確認できたところで多クラス単一ラベル分類問題に取り組みたいと思います。
まずは訓練データとテストデータを機械学習用に変換します。
データ構造が10000 * 10000の多クラス単一ラベル分類問題なので二次元テンソル(samples, features)モデルを使います。
ということでOne-Hot型に変換します

tutorial.py
# データの準備
import numpy as np

# One-Hot エンコーディング
def convert_to_one_hot(sequences, dimension=10000):
    """
    2次元テンソルデータをOne-Hotエンコーディングする関数
    Parameters
    ----------
    sequences : 行列データ
                2次元テンソルデータ※one-hotエンコーディング対象データ
    dimension : 要素数(ベクトル数)
    """
    # 0埋めの行列を作成
    results = np.zeros((len(sequences), dimension))

    for i, sequence in enumerate(sequences):
        # i行目 sequence列に1を立てる
        results[i, sequence] = 1.
        return results

訓練データをベクトルデータへ変換

One-Hotエンコーディング関数を使ってベクトルデータへ変換

tutorial.py
# データをOne-Hotエンコーディングします
x_train = convert_to_one_hot(train_data, 10000)
x_test  = convert_to_one_hot(test_data, 10000)

# ラベルをOne-Hotエンコーディング
convert_train_labels = convert_to_one_hot(train_labels, 46)
convert_test_labels  = convert_to_one_hot(test_labels, 46)

"""
tips
kerasパッケージでもカテゴリデータの作成はできる
from keras.utils.np_utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels  = to_categorical(test_labels)
"""

ニューラルネットワークの構築とモデルの定義

モデルを設計します
レイヤー層のタイプ、出入力数、活性化関数のタイプ、そして層の数を設計します

tutorial.py
# ニューラルネットワークの構築とモデルの定義
from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(46, activation='softmax'))

訓練の妥当性の検証のためのデータを小分けします

tutorial.py
# 検証データセットの設定

# データを訓練用と検証用に分割
x_val = x_train[:1000]
partial_x_train = x_train[1000:]

# ラベルを訓練用と検証用に分割
y_val = convert_train_labels[:1000]
partial_y_train = convert_train_labels[1000:]

モデルの実装

tutorial.py
from keras import optimizers
from keras import losses
from keras import metrics

# モデルを実装
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

# 訓練の実施
# モデルの訓練
history = model.fit(partial_x_train, partial_y_train, epochs=8, batch_size=512, validation_data=(x_val, y_val))

# テスト結果を取得
results = model.evaluate(x_test, convert_test_labels)
print(results)
# 訓練
Train on 7982 samples, validate on 1000 samples
Epoch 1/8
7982/7982 [==============================] - 1s 157us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 2/8
7982/7982 [==============================] - 1s 78us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 3/8
7982/7982 [==============================] - 1s 77us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 4/8
7982/7982 [==============================] - 1s 82us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 5/8
7982/7982 [==============================] - 1s 79us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 6/8
7982/7982 [==============================] - 1s 81us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 7/8
7982/7982 [==============================] - 1s 80us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
Epoch 8/8
7982/7982 [==============================] - 1s 83us/step - loss: 0.0000e+00 - acc: 1.0000 - val_loss: 0.0038 - val_acc: 0.9990
2246/2246 [==============================] - 0s 119us/step
# テストデータの予測値を出力 (loss, acc)
[0.0016826143884701274, 0.9995547640249333]

振り返り

演習をやってるのだけれど、訓練中の損失値、正答率、検証時の損失値、正答率がおかしい気がする・・・

まとめ

  • 最後の層について
    多クラス単一ラベル分類問題では最後の層の活性化関数としてソフトマックスを使用すると良い。 ※Softmax関数:N個の出力クラスに対する0 ~ 1の確率分布を出力。合計で1になる。 ※シグモイド関数:各クラスが0~1の間の値をとる。
  • ラベルのエンコーディング
    1 ラベルをOne-Hotエンコーディングを用いてエンコードし、損失関数にcategorical_crossentropyを使用する。
    2 ラベルを整数でエンコードした場合は、損失関数にsparse_categorical_crossentropyを使用する。
  • 中間層 分類するカテゴリの数が多い場合は中間層が小さすぎるとネットワークに情報ボトルネックが生じる可能性がある。
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4