LoginSignup
29

More than 5 years have passed since last update.

KerasのSingle-LSTM文字生成サンプルコードを解説

Last updated at Posted at 2018-06-10

はじめに

ディープラーニングの学習がてらKerasのLSTMサンプルコードで遊んでみようと思ったのですが、
内容を読み解くのに意外と苦労したので、内容をまとめたものが皆さんの参考になればと残しておきます。

対象読者

ディープラーニング初心者向けです。

以下の方を対象としています。
* Pythonは分かる
* ディープラーニングは詳しくない
* LSTMネットワークの概要 は読んだ

詳しい方は突っ込み所があればコメント欄で教えてください。

サンプルコードについて

URL

こちらのサンプルコードです。
https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py

(追記)
リファクタリングしたものを用意しました。
こちらのコードを読んだほうが分かりやすい・・・はず。
https://github.com/YankeeDeltaBravo225/lstm_text_generation_comment/blob/master/lstm_text_generation_refactored.py

概要

文字列を与えられると、次の文字が何になるかを予測し、m回続けてそれっぽい文字列を生成します。
これにLSTMの中でも一番ベーシックなSingle LSTMと言うモデルを用います。
model.png
例えば、"this is the piano made by a German meister" という文字があるとします。
これを8文字で区切るとすると以下のようになります。

  1. 入力:"this is ", 結果:"t"
  2. 入力:"his is t",結果: "h"
  3. 入力:"is is th",結果: "e"

Learn_target.png

入力データをこの8文字、教師データを次の文字として学習させることで、
ネットワークは"this is " と来たら次の文字は"t"だと予測できるようになります。

もう少し細かく言うと、学習・予測させるのは文章の次に続く文字そのものではなく出現率です。

Teacher_data.png

予測の際は、これを元に乱数で次の文字を決めます。
上記の例で行くと、"this is"ときた場合、85%で"t"が、9.8%で"a"が出現します。

これでどうやって文字列を生成するかですが、"this is a" と来たら、
学習時の8文字に合わせるため左端の"t"を削って"his is a"とし、
同じように次の文字を予測します。

以上延々と繰り返して行くことで、それっぽい文字列を生成します。

コードの解説

コメント付きのコードを用意したので、これを元に説明して行きます。
https://github.com/YankeeDeltaBravo225/lstm_text_generation_comment/blob/master/lstm_text_generation_Ja_comments.py

コードは大別して以下の5つの部分から成ります。

  1. テキストのダウンロード・読み込み
  2. 各々の字のdict作成
  3. 学習用データに整形
  4. モデル作成・フィッティング
  5. 文章生成

1. テキストのダウンロード・読み込み

ここは特に説明は要らないと思います。
テキストファイルをダウンロードして読み込み、大文字を小文字に置換しているだけです。

# ニーチェの文集をダウンロードする
# path : ダウンロードした先のパス
path = get_file('nietzsche.txt', origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')

# text : 入力ファイル
with io.open(path, encoding='utf-8') as f:
    text = f.read().lower()
print('corpus length:', len(text))

ちなみに、僕は以下のように引数で指定したローカルファイルから読み込むようにしていました。

# text : 入力ファイル
with io.open(sys.argv[1], encoding='utf-8') as f:
    text = f.read().replace('\n', '')

2. 各々の字のdict作成

文字はそのままだと扱いづらいので、index番号で扱います。

名前がちょっと分かりづらいですが、 char_indices は
字からcharsのインデックス番号を引くだけのdict。

indices_char は逆にcharsのインデックス番号から
字を引くためのリストです。

# chars : 重複を排除した「字」のリスト
chars = sorted(list(set(text)))
print('total chars:', len(chars))

# char_indices : 「字」を上記charsのindex番号に変換するdict
char_indices = dict((c, i) for i, c in enumerate(chars))

# indices_char : 上記と逆にindex番号を「字」に変換するdict
indices_char = dict((i, c) for i, c in enumerate(chars))

3. 学習用データに整形

先ほどの例で上げたとおり、特定の長さで区切った文章と次の文章のセットを作成します。
これまでは説明の都合上、8文字区切りとしていましたがコードでは40文字区切りとしています。
この40文字1セットの文字列を1つのシーケンスとして学習にかけます。

# cut the text in semi-redundant sequences of maxlen characters
# maxlen : いくつの「字」を1つの「文」とするか
maxlen = 40

# step : 開始位置のスキップ数
step = 3

# sentences  : 「文」のリスト
sentences = []

# next_chars : 各「文」について、その次の「文」の最初の「字」
next_chars = []

for i in range(0, len(text) - maxlen, step):
    # 単純に長さで区切った部分文字列を一つの文という扱いで抽出
    sentences.append(text[i: i + maxlen])

    # 次の文の最初の文字を保存
    next_chars.append(text[i + maxlen])

# 上記の「文」の数をそのままLSTMのsequence数として用いる
print('nb sequences:', len(sentences))

print('Vectorization...')

# x : np.bool型 3次元配列 [文の数, 文の最大長, 字の種類] ⇒ 文中の各位置に各indexの文字が出現するか
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)

# y : np.bool型 2次元配列 [文の数, 字の種類]              ⇒ 次の文の開始文字のindex
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)

# vector化は各「文」について実施
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

x, yがどんな形になるか分かりづらいので、1になる箇所をプロットしてみました。

  • x : 現在のシーケンス(sentence)
    bool_vector_x.png

  • y : 現在のシーケンス(sentence)の次に来る文字
    y bool_vector_y.png

なお、プロットに使用したコードは以下の通りです。

# import文
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 本文(3D)
# Plot three dimension bool vector
def plot_bool_vector_3d(vector, labels):
    # こちらは3Dでプロットするためのコード
    # http://ct-innovation01.hateblo.jp/entry/matplotlib05 で詳しく説明
    coordinates = [pos for pos, valid in np.ndenumerate(vector) if valid]

    xs, ys, zs = [], [], []
    for coordinate in coordinates:
        x, y, z = coordinate
        xs.append(x)
        ys.append(y)
        zs.append(z)

    # pyplotの使い方の話になるので、こちらのリファレンスも見ると良いかも
    # https://matplotlib.org/api/pyplot_summary.html
    ax = plt.figure().add_subplot(111, projection='3d')
    ax.scatter(xs, ys, zs)

    ax.set_xlabel(labels[0])
    ax.set_ylabel(labels[1])
    ax.set_zlabel(labels[2])

    plt.show()

# 本文(2D)
# Plot two dimension bool vector
def plot_bool_vector_2d(vector, labels):
    coordinates = [pos for pos, valid in np.ndenumerate(vector) if valid]

    xs, ys = [], []
    for coordinate in coordinates:
        x, y = coordinate
        xs.append(x)
        ys.append(y)

    ax = plt.figure().add_subplot(111)
    ax.scatter(xs, ys)

    ax.set_xlabel(labels[0])
    ax.set_ylabel(labels[1])

    plt.show()

# 使い方
plot_bool_vector_3d(x, ('sentence id', 'char pos in sentence', 'character id'))
plot_bool_vector_2d(y, ('sentence id', '1st character id in next sentence'))

4. モデル作成・フィッティング

Kerasにある既製品のLSTMを使用します。

# build the model: a single LSTM
print('Build model...')
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))

# 勾配法にRMSpropを用いる
# 以下参照
# https://qiita.com/tokkuman/items/1944c00415d129ca0ee9

optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

これにフィッティングをかける訳ですが、各エポックの終了時に
その時点のモデルを使ったテキスト生成処理が呼ばれるようにCallbackを設定します。

# 各epoch終了時のcallbackとして、上記のon_epoch_endを呼ぶ
# 参照
# https://keras.io/ja/callbacks/#lambdacallback
print_callback = LambdaCallback(on_epoch_end=on_epoch_end)

# フィッティング実施、各epoch完了時に先述の on_epoch_end が呼ばれる
model.fit(x, y,
          batch_size=128,
          epochs=60,
          callbacks=[print_callback])

5. 文章生成

長いので、on_epoch_end と、そこから呼ばれる sample の二つの関数それぞれについて
解説して行きます。

on_epoch_end

このモデルの特性上、はじめに maxlen (=40) 文字の文章が必要になりますが、
元の入力テキストからランダムに選択し、それを変数 sentence に格納します。

この文字列は元の入力テキストから選ぶ必要はないのですが、学習時に使用した
文字種のみで構成されていないといけません。

また、diversityと言う言葉が出てきますが、sampleの方で詳しく説明します。

def on_epoch_end(epoch, logs):
    # Function invoked at end of each epoch. Prints generated text.
    print()
    print('----- Generating text after Epoch: %d' % epoch)

    # モデルは40文字の「文」からその次の「字」を予測するものであるため、
    # その元となる40文字の「文」を入力テキストからランダムに選ぶ
    start_index = random.randint(0, len(text) - maxlen - 1)

    # diversityとは多様性を意味する言葉
    # この値が低いとモデルの予測で出現率が高いとされた「字」がそのまま選ばれ、
    # 高ければそうでない「字」が選ばれる確率が高まる
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print('----- diversity:', diversity)

        generated = ''

        # 元にする「文」を選択
        sentence = text[start_index: start_index + maxlen]
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

sentence の次に来る文字を予測し、その結果を sentence の末尾に加えると共に先頭文字を削ります。
これを延々繰り返すことで、それっぽい文字列を生成します。

        # 上記のランダムで選ばれた「文」に続く400個の「字」をモデルから予測し出力する
        for i in range(400):

            # 現在の「文」の中のどの位置に何の「字」があるかのテーブルを
            # フィッティング時に入力したxベクトルと同じフォーマットで生成
            # 最初の次元は「文」のIDなので0固定
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            # 現在の「文」に続く「字」を予測する
            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            # 予測して得られた「字」を生成し、「文」に追加
            generated += next_char

            # モデル入力する「文」から最初の文字を削り、予測結果の「字」を追加
            # 例:sentence 「これはドイツ製」
            #     next_char 「の」
            #     ↓
            #     sentence 「れはドイツ製の」
            sentence = sentence[1:] + next_char

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()

sample

予測結果を元にして、文字を選びます。

# 各「字」の出現確率の配列(ndarray型)から、出力する文字を選ぶ
# 単純に一番確率の高いものを選ぶのではなく、出現率に従いランダムに選ぶ
#
# predsはモデルからの出力であり、多項分布の形になっているため、
# その総和は必ず 1.0 となる
#
#  preds       : モデルからの出力結果、float32型の多項分布が入ったndarray
#  temperature : 多様度、この値が高いほど preds 中の出現率が低いものが選ばれやすくなる
def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array

    # 64bit float型に変換
    preds = np.asarray(preds).astype('float64')

    # 確率の低く出た「字」が抽選で選ばれやすくなるようにゲタをはかせるため、
    # 自然対数を取った上、引数の値で割る
    # 参照
    # https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.log.html
    preds = np.log(preds) / temperature

    # 上記で確率の自然対数を取ったため、その逆変換である自然指数関数をとる
    # 参照
    # https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.exp.html
    exp_preds = np.exp(preds)

    # 多項分布の形に合わせるため、総和が1となるように全値を総和で割る
    preds = exp_preds / np.sum(exp_preds)

    # 多項分布に基づいた抽選を行う
    # 参照
    # https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.multinomial.html
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

on_epoch_end でdiversityという名前だった変数はここでは temperatureとなっています。
この数字は1の場合、予測結果そのままの確率で次の文字を選びますが、
大きい場合、出現率の低い文字にゲタを履かせます。

下の対数グラフの青線は予測結果そのまま、赤線はdiversity=1.2の時のものです。
出現率下位の文字が選ばれやすくなり、より様々な文章が生成されるようになります。

before_and_after_diversity.png

plotに使用したコードはこちら。

# import文
import matplotlib.pyplot as plt

# 本体
def visualize_preds(preds_comb_list, title):
    ax = plt.figure().add_subplot(1, 1, 1)

    for preds_comb in preds_comb_list:
        preds, color, label = preds_comb
        sorted = np.sort(preds)[::-1]
        ax.plot(sorted, color=color, linestyle='solid', label=label)

    plt.legend()
    plt.yscale("log")

    ax.set_title(title)
    ax.grid(True)

    plt.savefig(title + '.png')

#使い方
visualize_preds(
        [(preds_raw, 'blue', 'raw'), (preds_diversity, 'red', 'with diversity')],
        'before and after diversity'
    )

どんな文章が生成されるか

CNNの以下の記事を元にして文章を生成してみました。
https://edition.cnn.com/2018/06/07/politics/trump-g7-canada/index.html

コメントには最低100KBの文章が必要とありましたが、上記の記事のテキストは9KBしかありません。
しかし、ぱっと見ではそれっぽい文章が生成されています(ただし、変な単語が生成されてしまっていますが)。

over economic and other issues. Trump has Trump's uas Thumemees the conven inatweny begand denwattechnry hegrion 
Macron'ald theald leat wath cradeal at a formentely qulll quedwit war abd invoro has a wishe terepscous invint 
the meepart nunes baging his presill. he's s poping and Iprliner to ruride dengato-ound the Wresply ned 
Iopnsuimeling frradaad , pos hely begand dewink hew invorke heade pos and, propuretingsult wery the Walte tedd

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29