KerasのSingle-LSTM文字生成サンプルコードを解説

More than 1 year has passed since last update.


はじめに

ディープラーニングの学習がてらKerasのLSTMサンプルコードで遊んでみようと思ったのですが、

内容を読み解くのに意外と苦労したので、内容をまとめたものが皆さんの参考になればと残しておきます。


対象読者

ディープラーニング初心者向けです。

以下の方を対象としています。

* Pythonは分かる

* ディープラーニングは詳しくない

* LSTMネットワークの概要 は読んだ

詳しい方は突っ込み所があればコメント欄で教えてください。


サンプルコードについて


URL

こちらのサンプルコードです。

https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py

(追記)

リファクタリングしたものを用意しました。

こちらのコードを読んだほうが分かりやすい・・・はず。

https://github.com/YankeeDeltaBravo225/lstm_text_generation_comment/blob/master/lstm_text_generation_refactored.py


概要

文字列を与えられると、次の文字が何になるかを予測し、m回続けてそれっぽい文字列を生成します。

これにLSTMの中でも一番ベーシックなSingle LSTMと言うモデルを用います。

model.png

例えば、"this is the piano made by a German meister" という文字があるとします。

これを8文字で区切るとすると以下のようになります。


  1. 入力:"this is ", 結果:"t"

  2. 入力:"his is t",結果: "h"

  3. 入力:"is is th",結果: "e"

Learn_target.png

入力データをこの8文字、教師データを次の文字として学習させることで、

ネットワークは"this is " と来たら次の文字は"t"だと予測できるようになります。

もう少し細かく言うと、学習・予測させるのは文章の次に続く文字そのものではなく出現率です。

Teacher_data.png

予測の際は、これを元に乱数で次の文字を決めます。

上記の例で行くと、"this is"ときた場合、85%で"t"が、9.8%で"a"が出現します。

これでどうやって文字列を生成するかですが、"this is a" と来たら、

学習時の8文字に合わせるため左端の"t"を削って"his is a"とし、

同じように次の文字を予測します。

以上延々と繰り返して行くことで、それっぽい文字列を生成します。


コードの解説

コメント付きのコードを用意したので、これを元に説明して行きます。

https://github.com/YankeeDeltaBravo225/lstm_text_generation_comment/blob/master/lstm_text_generation_Ja_comments.py

コードは大別して以下の5つの部分から成ります。


  1. テキストのダウンロード・読み込み

  2. 各々の字のdict作成

  3. 学習用データに整形

  4. モデル作成・フィッティング

  5. 文章生成


1. テキストのダウンロード・読み込み

ここは特に説明は要らないと思います。

テキストファイルをダウンロードして読み込み、大文字を小文字に置換しているだけです。

# ニーチェの文集をダウンロードする

# path : ダウンロードした先のパス
path = get_file('nietzsche.txt', origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')

# text : 入力ファイル
with io.open(path, encoding='utf-8') as f:
text = f.read().lower()
print('corpus length:', len(text))

ちなみに、僕は以下のように引数で指定したローカルファイルから読み込むようにしていました。

# text : 入力ファイル

with io.open(sys.argv[1], encoding='utf-8') as f:
text = f.read().replace('\n', '')


2. 各々の字のdict作成

文字はそのままだと扱いづらいので、index番号で扱います。

名前がちょっと分かりづらいですが、 char_indices は

字からcharsのインデックス番号を引くだけのdict。

indices_char は逆にcharsのインデックス番号から

字を引くためのリストです。

# chars : 重複を排除した「字」のリスト

chars = sorted(list(set(text)))
print('total chars:', len(chars))

# char_indices : 「字」を上記charsのindex番号に変換するdict
char_indices = dict((c, i) for i, c in enumerate(chars))

# indices_char : 上記と逆にindex番号を「字」に変換するdict
indices_char = dict((i, c) for i, c in enumerate(chars))


3. 学習用データに整形

先ほどの例で上げたとおり、特定の長さで区切った文章と次の文章のセットを作成します。

これまでは説明の都合上、8文字区切りとしていましたがコードでは40文字区切りとしています。

この40文字1セットの文字列を1つのシーケンスとして学習にかけます。

# cut the text in semi-redundant sequences of maxlen characters

# maxlen : いくつの「字」を1つの「文」とするか
maxlen = 40

# step : 開始位置のスキップ数
step = 3

# sentences : 「文」のリスト
sentences = []

# next_chars : 各「文」について、その次の「文」の最初の「字」
next_chars = []

for i in range(0, len(text) - maxlen, step):
# 単純に長さで区切った部分文字列を一つの文という扱いで抽出
sentences.append(text[i: i + maxlen])

# 次の文の最初の文字を保存
next_chars.append(text[i + maxlen])

# 上記の「文」の数をそのままLSTMのsequence数として用いる
print('nb sequences:', len(sentences))

print('Vectorization...')

# x : np.bool型 3次元配列 [文の数, 文の最大長, 字の種類] ⇒ 文中の各位置に各indexの文字が出現するか
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)

# y : np.bool型 2次元配列 [文の数, 字の種類] ⇒ 次の文の開始文字のindex
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)

# vector化は各「文」について実施
for i, sentence in enumerate(sentences):
for t, char in enumerate(sentence):
x[i, t, char_indices[char]] = 1
y[i, char_indices[next_chars[i]]] = 1

x, yがどんな形になるか分かりづらいので、1になる箇所をプロットしてみました。


  • x : 現在のシーケンス(sentence)

    bool_vector_x.png


  • y : 現在のシーケンス(sentence)の次に来る文字

    y bool_vector_y.png


なお、プロットに使用したコードは以下の通りです。

# import文

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 本文(3D)
# Plot three dimension bool vector
def plot_bool_vector_3d(vector, labels):
# こちらは3Dでプロットするためのコード
# http://ct-innovation01.hateblo.jp/entry/matplotlib05 で詳しく説明
coordinates = [pos for pos, valid in np.ndenumerate(vector) if valid]

xs, ys, zs = [], [], []
for coordinate in coordinates:
x, y, z = coordinate
xs.append(x)
ys.append(y)
zs.append(z)

# pyplotの使い方の話になるので、こちらのリファレンスも見ると良いかも
# https://matplotlib.org/api/pyplot_summary.html
ax = plt.figure().add_subplot(111, projection='3d')
ax.scatter(xs, ys, zs)

ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])
ax.set_zlabel(labels[2])

plt.show()

# 本文(2D)
# Plot two dimension bool vector
def plot_bool_vector_2d(vector, labels):
coordinates = [pos for pos, valid in np.ndenumerate(vector) if valid]

xs, ys = [], []
for coordinate in coordinates:
x, y = coordinate
xs.append(x)
ys.append(y)

ax = plt.figure().add_subplot(111)
ax.scatter(xs, ys)

ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])

plt.show()

# 使い方
plot_bool_vector_3d(x, ('sentence id', 'char pos in sentence', 'character id'))
plot_bool_vector_2d(y, ('sentence id', '1st character id in next sentence'))


4. モデル作成・フィッティング

Kerasにある既製品のLSTMを使用します。

# build the model: a single LSTM

print('Build model...')
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))

# 勾配法にRMSpropを用いる
# 以下参照
# https://qiita.com/tokkuman/items/1944c00415d129ca0ee9

optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

これにフィッティングをかける訳ですが、各エポックの終了時に

その時点のモデルを使ったテキスト生成処理が呼ばれるようにCallbackを設定します。

# 各epoch終了時のcallbackとして、上記のon_epoch_endを呼ぶ

# 参照
# https://keras.io/ja/callbacks/#lambdacallback
print_callback = LambdaCallback(on_epoch_end=on_epoch_end)

# フィッティング実施、各epoch完了時に先述の on_epoch_end が呼ばれる
model.fit(x, y,
batch_size=128,
epochs=60,
callbacks=[print_callback])


5. 文章生成

長いので、on_epoch_end と、そこから呼ばれる sample の二つの関数それぞれについて

解説して行きます。


on_epoch_end

このモデルの特性上、はじめに maxlen (=40) 文字の文章が必要になりますが、

元の入力テキストからランダムに選択し、それを変数 sentence に格納します。

この文字列は元の入力テキストから選ぶ必要はないのですが、学習時に使用した

文字種のみで構成されていないといけません。

また、diversityと言う言葉が出てきますが、sampleの方で詳しく説明します。

def on_epoch_end(epoch, logs):

# Function invoked at end of each epoch. Prints generated text.
print()
print('----- Generating text after Epoch: %d' % epoch)

# モデルは40文字の「文」からその次の「字」を予測するものであるため、
# その元となる40文字の「文」を入力テキストからランダムに選ぶ
start_index = random.randint(0, len(text) - maxlen - 1)

# diversityとは多様性を意味する言葉
# この値が低いとモデルの予測で出現率が高いとされた「字」がそのまま選ばれ、
# 高ければそうでない「字」が選ばれる確率が高まる
for diversity in [0.2, 0.5, 1.0, 1.2]:
print('----- diversity:', diversity)

generated = ''

# 元にする「文」を選択
sentence = text[start_index: start_index + maxlen]
generated += sentence
print('----- Generating with seed: "' + sentence + '"')
sys.stdout.write(generated)

sentence の次に来る文字を予測し、その結果を sentence の末尾に加えると共に先頭文字を削ります。

これを延々繰り返すことで、それっぽい文字列を生成します。

        # 上記のランダムで選ばれた「文」に続く400個の「字」をモデルから予測し出力する

for i in range(400):

# 現在の「文」の中のどの位置に何の「字」があるかのテーブルを
# フィッティング時に入力したxベクトルと同じフォーマットで生成
# 最初の次元は「文」のIDなので0固定
x_pred = np.zeros((1, maxlen, len(chars)))
for t, char in enumerate(sentence):
x_pred[0, t, char_indices[char]] = 1.

# 現在の「文」に続く「字」を予測する
preds = model.predict(x_pred, verbose=0)[0]
next_index = sample(preds, diversity)
next_char = indices_char[next_index]

# 予測して得られた「字」を生成し、「文」に追加
generated += next_char

# モデル入力する「文」から最初の文字を削り、予測結果の「字」を追加
# 例:sentence 「これはドイツ製」
# next_char 「の」
# ↓
# sentence 「れはドイツ製の」
sentence = sentence[1:] + next_char

sys.stdout.write(next_char)
sys.stdout.flush()
print()


sample

予測結果を元にして、文字を選びます。

# 各「字」の出現確率の配列(ndarray型)から、出力する文字を選ぶ

# 単純に一番確率の高いものを選ぶのではなく、出現率に従いランダムに選ぶ
#
# predsはモデルからの出力であり、多項分布の形になっているため、
# その総和は必ず 1.0 となる
#
# preds : モデルからの出力結果、float32型の多項分布が入ったndarray
# temperature : 多様度、この値が高いほど preds 中の出現率が低いものが選ばれやすくなる
def sample(preds, temperature=1.0):
# helper function to sample an index from a probability array

# 64bit float型に変換
preds = np.asarray(preds).astype('float64')

# 確率の低く出た「字」が抽選で選ばれやすくなるようにゲタをはかせるため、
# 自然対数を取った上、引数の値で割る
# 参照
# https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.log.html
preds = np.log(preds) / temperature

# 上記で確率の自然対数を取ったため、その逆変換である自然指数関数をとる
# 参照
# https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.exp.html
exp_preds = np.exp(preds)

# 多項分布の形に合わせるため、総和が1となるように全値を総和で割る
preds = exp_preds / np.sum(exp_preds)

# 多項分布に基づいた抽選を行う
# 参照
# https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.multinomial.html
probas = np.random.multinomial(1, preds, 1)
return np.argmax(probas)

on_epoch_end でdiversityという名前だった変数はここでは temperatureとなっています。

この数字は1の場合、予測結果そのままの確率で次の文字を選びますが、

大きい場合、出現率の低い文字にゲタを履かせます。

下の対数グラフの青線は予測結果そのまま、赤線はdiversity=1.2の時のものです。

出現率下位の文字が選ばれやすくなり、より様々な文章が生成されるようになります。

before_and_after_diversity.png

plotに使用したコードはこちら。

# import文

import matplotlib.pyplot as plt

# 本体
def visualize_preds(preds_comb_list, title):
ax = plt.figure().add_subplot(1, 1, 1)

for preds_comb in preds_comb_list:
preds, color, label = preds_comb
sorted = np.sort(preds)[::-1]
ax.plot(sorted, color=color, linestyle='solid', label=label)

plt.legend()
plt.yscale("log")

ax.set_title(title)
ax.grid(True)

plt.savefig(title + '.png')

#使い方
visualize_preds(
[(preds_raw, 'blue', 'raw'), (preds_diversity, 'red', 'with diversity')],
'before and after diversity'
)


どんな文章が生成されるか

CNNの以下の記事を元にして文章を生成してみました。

https://edition.cnn.com/2018/06/07/politics/trump-g7-canada/index.html

コメントには最低100KBの文章が必要とありましたが、上記の記事のテキストは9KBしかありません。

しかし、ぱっと見ではそれっぽい文章が生成されています(ただし、変な単語が生成されてしまっていますが)。

over economic and other issues. Trump has Trump's uas Thumemees the conven inatweny begand denwattechnry hegrion 

Macron'ald theald leat wath cradeal at a formentely qulll quedwit war abd invoro has a wishe terepscous invint
the meepart nunes baging his presill. he's s poping and Iprliner to ruride dengato-ound the Wresply ned
Iopnsuimeling frradaad , pos hely begand dewink hew invorke heade pos and, propuretingsult wery the Walte tedd