LoginSignup
49
40

More than 5 years have passed since last update.

LSTMを使ってズンドコキヨシを学習する

Posted at

概要

LSTMを使ってズンドコキヨシを学習してみました。
Chainerを使って実装しています。
ちょっと前にあからさまに誤ったコードを投稿してしまったのですが、修正して再投稿します。

LSTMの説明は以下の投稿が詳しいです。

わかるLSTM ~ 最近の動向と共に

モデル

以下のようなモデルを構築します
* 入力として「ズン」または「ドコ」を受けつける
* 出力はNoneまたは「\キ・ヨ・シ!/」
* 「ズン」「ズン」「ズン」「ズン」「ドコ」をこの順で入力したら「\キ・ヨ・シ!/」を出力、それ以外の場合はNoneを出力するように学習する

コード

zundoko.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import chainer
from chainer import Variable, optimizers, functions as F, links as L

np.random.seed()
zun = 0
doko = 1
input_num = 2
input_words = ['ズン', 'ドコ']
none = 0
kiyoshi = 1
output_num = 2
output_words = [None, '\キ・ヨ・シ!/']
hidden_num = 8
update_iteration = 20

class Zundoko(chainer.Chain):
    def __init__(self):
        super(Zundoko, self).__init__(
            word=L.EmbedID(input_num, hidden_num),
            lstm=L.LSTM(hidden_num, hidden_num),
            linear=L.Linear(hidden_num, hidden_num),
            out=L.Linear(hidden_num, output_num),
        )

    def __call__(self, x, train=True):
        h1 = self.word(x)
        h2 = self.lstm(h1)
        h3 = F.relu(self.linear(h2))
        return self.out(h3)

    def reset_state(self):
        self.lstm.reset_state()

kiyoshi_list = [zun, zun, zun, zun, doko]
kiyoshi_pattern = 0
kiyoshi_mask = (1 << len(kiyoshi_list)) - 1
for token in kiyoshi_list:
    kiyoshi_pattern = (kiyoshi_pattern << 1) | token

zundoko = Zundoko()
for param in zundoko.params():
    data = param.data
    data[:] = np.random.uniform(-1, 1, data.shape)
optimizer = optimizers.Adam(alpha=0.01)
optimizer.setup(zundoko)

def forward(train=True):
    loss = 0
    acc = 0
    if train:
        batch_size = 20
    else:
        batch_size = 1
    recent_pattern = np.zeros((batch_size,), dtype=np.int32)
    zundoko.reset_state()
    for i in range(200):
        x = np.random.randint(0, input_num, batch_size).astype(np.int32)
        y_var = zundoko(Variable(x, volatile=not train), train=train)
        recent_pattern = ((recent_pattern << 1) | x) & kiyoshi_mask
        if i < len(kiyoshi_list):
            t = np.full((batch_size,), none, dtype=np.int32)
        else:
            t = np.where(recent_pattern == kiyoshi_pattern, kiyoshi, none).astype(np.int32)
        loss += F.softmax_cross_entropy(y_var, Variable(t, volatile=not train))
        acc += float(F.accuracy(y_var, Variable(t, volatile=not train)).data)
        if not train:
            print input_words[x[0]]
            y = np.argmax(y_var.data[0])
            if output_words[y] != None:
                print output_words[y]
                break
        if train and (i + 1) % update_iteration == 0:
            optimizer.zero_grads()
            loss.backward()
            loss.unchain_backward()
            optimizer.update()
            print 'train loss: {} accuracy: {}'.format(loss.data, acc / update_iteration)
            loss = 0
            acc = 0

for iteration in range(20):
    forward()

forward(train=False)

出力例

train loss: 18.4753189087 accuracy: 0.020000000298
train loss: 16.216506958 accuracy: 0.0325000006706
train loss: 15.0742883682 accuracy: 0.0350000008941
train loss: 13.9205350876 accuracy: 0.385000001639
train loss: 12.5977449417 accuracy: 0.96249999404
(中略)
train loss: 0.00433994689956 accuracy: 1.0
train loss: 0.00596862798557 accuracy: 1.0
train loss: 0.0027643663343 accuracy: 1.0
train loss: 0.011038181372 accuracy: 1.0
train loss: 0.00512072304264 accuracy: 1.0
ズン
ズン
ズン
ドコ
ドコ
ドコ
ズン
ズン
ズン
ドコ
ドコ
ドコ
ドコ
ドコ
ドコ
ズン
ドコ
ドコ
ズン
ドコ
ドコ
ズン
ドコ
ズン
ズン
ズン
ズン
ズン
ズン
ズン
ズン
ドコ
\キ・ヨ・シ!/

少しはまったところ

最初dropoutを使っていたのですが、そうすると学習がうまくいかず出力がほぼNoneのみになってしまいました。

49
40
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
49
40