LoginSignup
36
34

More than 5 years have passed since last update.

単純なRNNを使ってズンドコキヨシを学習する

Posted at

すでにLSTMを使ってズンドコキヨシを実装している方がいますが、この程度であれば単純なRNNだけでも学習できるはずだと思い、RNNの理解を兼ねてChainerで実装してみました。

書籍「深層学習 (機械学習プロフェッショナルシリーズ)」(ISBN-13: 978-4061529021)の7.5によれば、RNNが記憶できるのは過去10時刻分程度とあります。「ズン」が4回、「ドコ」が1回出現するパターンを覚えれば良いので、この範囲に収まるはずです。

考え方

できるだけシンプルな構造を考えました。入力は2つ($x_1, x_2$)とします。中間層を一つ持ち、ユニット数を10とします。出力も2つ($y_1, y_2$)とします。
入力は以下のように定義します。

ズン → x_0 = 0, x_1 = 1 \\
ドコ → x_0 = 1, x_1 = 0

出力は以下のように定義します。単純な分類問題とします。

キヨシ不成立 → y_0 = 1, y_1 = 0 \\
キヨシ成立 → y_0 = 0, y_1 = 1

モデルの定義

コードによる定義は以下の通りになります。

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L

class RNN(chainer.Chain):
    def __init__(self):
        super(RNN, self).__init__(
            w1 = L.Linear(2, 10),
            h1 = L.Linear(10, 10),
            o = L.Linear(10, 2)
            )
    def reset_state(self):
        self.last_z = chainer.Variable(np.zeros((1,10), dtype=np.float32))
    def __call__(self, x):
        z = F.relu(self.w1(x) + self.h1(self.last_z))
        self.last_z = z
        y = F.relu(self.o(z))
        return y

rnn = RNN()
rnn.reset_state()
model = L.Classifier(rnn)
optimizer = optimizers.Adam() # Adamを使う
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.GradientClipping(10.0)) #勾配の上限を設定

サイズを1固定のミニバッチ前提とし、初期値(last_z)をゼロとします。一度計算した隠れ層の値をself.last_zに保持しつつ、その結果を出力層に与えています。

学習

適度にランダムな系列データを生成し、それを学習させます。

ans_zundoko = [0, 0, 0, 0, 1] # 正解列
src_x_ary = [0, 0, 0, 1] # 乱数生成用配列 0 の出現率を1より高くする

def zd_gen(): # ジェネレータ
    x_ary = [0, 0, 0, 0, 1]
    y_ary = [0, 0, 0, 0, 1]
    while True:
        x = x_ary.pop(0)
        y = y_ary.pop(0)
        x = [0, 1] if x == 0 else [1, 0]
        yield x, y
        new_x = src_x_ary[np.random.randint(0, 4)] # 0〜2であれば0, 3であれば1とする
        x_ary.append(new_x)
        y_ary.append(1 if x_ary == ans_zundoko else 0) # x_aryが[0, 0, 0, 0, 1]の時だけ1

bprop_len = 40 # BPPTの打ち切り時刻
iter = 300 * 100 * 2 # 学習回数
loss = 0
i = 0
for xx, yy in zd_gen():
    x = chainer.Variable(np.asarray([xx], dtype=np.float32))
    t = chainer.Variable(np.asarray([yy], dtype=np.int32))
    loss += model(x, t)
    i += 1
    if i % bprop_len == 0:
        model.zerograds()
        loss.backward()
        loss.unchain_backward()
        optimizer.update()
        print("iter %d, loss %f, x %d, y %d" % (i, loss.data, xx[0], yy))
        loss = 0
    if i > iter:
        break

学習は時間がかかる上に、初期値によってうまくいったりいかなかったりします。最終的に損失が0.1を切るぐらいの値にならないとちゃんと動作しません。

オプティマイザをSGDに変えたり、bprop_lenを変えたりしてみても結果が変わってきます。ここで設定している値は手元でなんとなくうまく行ったケースを使っています。

評価

学習完了したモデルを評価します。入力列をランダムに生成しても良いですが、簡単にするため静的な評価データを用意してみました。

# ズン ズン ズン ズン ドコ ドコ ドコ ズン
x_data = [[0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [0,1]]

rnn.reset_state()
for xx in x_data:
    print('ズン' if xx[1] == 1 else 'ドコ')
    x = chainer.Variable(np.asarray([xx], dtype=np.float32))
    y = model.predictor(x)
    z = F.softmax(y, use_cudnn=False)
    if z.data[0].argmax() == 1: # 値の大きい方の配列添字が1の場合キヨシ成立
        print('キヨシ')

参考出力

うまく行った場合の出力を参考までに示しておきます。

iter 59520, loss 0.037670, x 1, y 0
iter 59560, loss 0.051628, x 0, y 0
iter 59600, loss 0.037519, x 0, y 0
iter 59640, loss 0.041894, x 0, y 0
iter 59680, loss 0.059143, x 0, y 0
iter 59720, loss 0.062305, x 0, y 0
iter 59760, loss 0.055293, x 0, y 0
iter 59800, loss 0.060964, x 1, y 1
iter 59840, loss 0.057446, x 1, y 0
iter 59880, loss 0.034730, x 1, y 0
iter 59920, loss 0.054435, x 0, y 0
iter 59960, loss 0.039648, x 0, y 0
iter 60000, loss 0.036578, x 0, y 0
ズン
ズン
ズン
ズン
ドコ
キヨシ
ドコ
ドコ
ズン

感想

自分の中で消化しきれていなかったRNNがこれでようやく理解できた気がします。最初、中間層のユニット数が少なかったりBPTTの打ち切り時刻を短くしすぎたりしてまったく思うように動かなかったのですが、いろいろと調整してようやく動作するようになりました。
世間にはLSTMと単語埋め込み表現を使った実例は多くあるのですが、もっと最小化した問題でやってみようと思い、ようやく実現できて満足です。

しかし、まだ運に頼る部分があるのが困りものです。

36
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
34