LoginSignup
13
14

More than 5 years have passed since last update.

Kerasメモ:時系列データのone-hotラベルをto_categoricalで生成

Last updated at Posted at 2017-10-06

動機

翻訳やチャットボットで盛り上がっているseq2seqをKerasで作ろうと思いました。

教師データとして入出力文章のペア(コーパス)を用意したら、コーパス内に登場する単語の一覧を取得し、一単語につき一つのラベルを与えてインデックス化します。
(今回は英語コーパスで取り組んだので、単語の分かち書きなどについては特に取り組んでいません。)

これを用いれば、単語の集まりである文章はラベルの集まりに変換できるので、seq2seqはラベルのシーケンスを入力し、ラベルのシーケンスを出力させるようなネットワーク構造になります。

すなわち、用意した語彙の中で、時刻毎にどのラベルを出力するべきかを学習する(クラス数が膨大な)他クラス分類問題になるのですが、このときKerasでは教師データの出力ラベルをone-hot vectorにする必要があります。
(Chainerなどでも一緒かな?)
one-hot vectorとは、例えばラベルが0~9の10通りであのとき、

label  0 1 2 3 4 5 6 7 8 9
1:    [0,1,0,0,0,0,0,0,0,0]
5:    [0,0,0,0,0,1,0,0,0,0]

といったように、ラベルのところだけ1で他は全て0という、クラス数の分だけの次元数を持つベクトルです。

これを作るためにKerasでは便利な関数to_categoricalがあるのですが、何も考えずに使ったら思った通りには動かなかったので、そのためのメモ。

to_categorical():ラベルをone hot vector化

from keras.utils import to_categorical
import numpy as np

data_num = 13
seq_len = 7
num_classes = 5

label = np.random.randint(5, size=[data_num, seq_len])
one_hot_label = to_categorical(label, num_classes=num_classes)

print(label.shape)          #(13, 7)
print(one_hot_label.shape)  #(91, 5)  ...(13, 7, 5)ではない

例として、

データ数(data_num): 13
シーケンス長(seq_len): 7
クラス数(num_classes): 5

としました。
num_classesは指定しない場合、クラス数を勝手に数えてくれるようなので、必ずしも必要ではないようです。
が、クラス数を明記した方が安全かなと思ったので書きました。

本当であれば、長さ5のベクトルが7つ集まってシーケンスとなり、13シーケンスあるので次元は(13, 7, 5)になって欲しかったのですが、長さ91のひとつながりのシーケンスになってしまいました。

というわけで、numpy.reshapeで所用のサイズにします。

from keras.utils import to_categorical
import numpy as np

data_num = 13
seq_len = 7
num_classes = 5

label = np.random.randint(num_classes, size=[data_num, seq_len])
one_hot_label = to_categorical(label,num_classes=num_classes).reshape(data_num, seq_len, num_classes)

print(label.shape)          #(13, 7)
print(one_hot_label.shape)  #(13, 7, 5)

結局はnumpyの話になってしまいしたが...

13
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
14