LoginSignup
42
40

More than 5 years have passed since last update.

メモリを操作するRNNでソートアルゴリズム(可変長&順序フラグあり)を機械学習できたよっ!

Posted at

微分可能な神経機械?

Google DeepMindがNatureに投稿した論文「Hybrid computing using a neural network with dynamic external memory」が、なんだかヤバそうな香りがします。
公式の紹介記事「Differentiable neural computers」では、プラトンの記憶論から話が始まりますし、論文では脳の記憶を司る海馬に喩えていたりして、なかなか格調高いです。単なるニューラルネットワークの性能改善に留まらず、哲学や神経科学の観点からも理想の人工知能に一歩近づくことができたよ、これは新しいコンピュータの在り方の発明なのではないか、という気概が感じられます。

仕組みとしては流行りのAttentionという概念が入っていて、メモリを表す行列と、それを選択的に操作しながらベクトルを入出力するコントローラがあります。コントローラはRNN(再帰的ニューラルネットワーク)になっていて、メモリ操作も微分可能です。
そのため、いかにもチューリングマシンぽく、データ構造を持つ一般的なアルゴリズムを機械学習できるということのだと思います。たぶん。論文中の実験では、グラフの経路探索を高精度で学習できたそうです。なにそれ凄い。

と衝撃を受けつつ、まだ論文の数式ちゃんと読めてないのですが、さっそく「DNC (Differentiable Neural Computers) の概要 + Chainer による実装」と実装した方がいらっしゃるので、遊んでみようと思います。
元コードでは、エコーを学習していたので、発展的にソートを学習させてみましょう。

結果

それでは、2万サンプル学習させた後にテストした結果が、こちらになります。はい、全問正解ですね。

# 20000

in  [7, 4] ASC
ans [4, 7]
out [4, 7] 1.0

in  [1, 6, 3] ASC
ans [1, 3, 6]
out [1, 3, 6] 1.0

in  [2, 6, 1, 2] ASC
ans [1, 2, 2, 6]
out [1, 2, 2, 6] 1.0

in  [1, 0, 6, 1, 7] DESC
ans [7, 6, 1, 1, 0]
out [7, 6, 1, 1, 0] 1.0

in  [5, 1, 2, 1, 4, 6] ASC
ans [1, 1, 2, 4, 5, 6]
out [1, 1, 2, 4, 5, 6] 1.0

in  [5, 4, 3, 5, 5, 8, 5] DESC
ans [8, 5, 5, 5, 5, 4, 3]
out [8, 5, 5, 5, 5, 4, 3] 1.0

in  [9, 3, 6, 9, 5, 2, 9, 5] ASC
ans [2, 3, 5, 5, 6, 9, 9, 9]
out [2, 3, 5, 5, 6, 9, 9, 9] 1.0

in  [0, 6, 5, 8, 4, 6, 0, 8, 0] ASC
ans [0, 0, 0, 4, 5, 6, 6, 8, 8]
out [0, 0, 0, 4, 5, 6, 6, 8, 8] 1.0

「in」が入力列と、ソートする順序です。「ASC」が昇順、「DESC」が降順。
「ans」が正解で、「out」が出力列、それから正解との一致率です。
この結果から、DNCがソートを学習、しかも順序フラグがあって列の長さが可変なソートを学習できたことが分かります。

そして訓練データには、列の長さが2から8までのものしか与えてないのですが、テストでは長さが9の列もソートできているため、列の長さについての汎用化にも成功しているっぽいんですよね。これは驚きました。

パラメータ&モデル

DNCの設定パラメータは、こんな感じにしてみました。

パラメータ
入力ベクトルのサイズ 12
出力ベクトルのサイズ 10
メモリの数 20
メモリ一つのサイズ 12
リードヘッダの数 2

コントローラのRNNは元コードのままで、1層がLSTM Layer(in:36 out:83)、2層がFully Connected Layer(in:83 out:83)、のためディープではないです。
このあたり何のチューニングもしていないので、色々と改善できると思います。むしろ何もチューニングせず、一発であっさり学習できちゃったのでビビりました。

エンコード&デコード

in  [1, 6, 3] ASC
ans [1, 3, 6]
out [1, 3, 6] 1.0

これは実際には、次のような系列でDNCに入出力しています。
まず順序フラグを入れ、数列を入れおわると、ソート結果を吐きだす流れです。

実行ステップ 1 2 3 4 5 6 7
入力 ASC 1 6 3 - - -
出力 x x x x 1 3 6

正確には、入力はone-hotエンコーディングしています。「-」はゼロベクトルです。
出力は一番大きいベクトル要素のインデックスを採用します。「x」は無視します。
このあたりも、元コードのままです。

DNCとしては、ソートする数列の長さはゼロベクトルが来て初めて分かるようになっていて、もうそのステップでは最小もしくは最大の数を出力しないといけません。
学習前は、入力から出力まで猶予ステップをあげないとダメかなと予想していたのですが、とくに問題ありませんでした。かしこい。

ポエム

近年のニューラールネットワークの発展と応用は目覚ましく、それでもまだ画像とか自然言語といった特定領域でのパターン認識の化け物でしょ、真の人工知能にはまだまだ遠いよね。え、DCGANで二次元イラストが生成できた?(Chainerを使ってコンピュータにイラストを描かせる)、LSTMでシェークスピアっぽい文章が生成できた?(The Unreasonable Effectiveness of Recurrent Neural Networks)、うん遠目に見れば騙されちゃうかもね。すごいすごい。

……くらいには高を括っていたのですが、そろそろホントに知性を獲得しはじめてるんじゃないかな、という妄想が過ってきました。もし本当にDNCのようなアプローチで、データ構造とアルゴリズムの汎用的な機械学習ができるのだとしたら、プログラマの仕事けっこう奪われるのではという漠然とした畏れがあり、ああ人は自らの知性があると信じているところを自動化されると怯えるのか、産業革命で機械を打ち壊した職人のこと笑えないなぁ、っていう。。。
いやまぁ実用的には、今回のソートのようなものにDNCを持ちだすのは明らかにオーバースペックで、フリーランチ定理が教えるとおり、数理最適化などの知見を活かしてがっつり性能だせるタスクは、人がプログラムするのが一番良いと思いますが。タスクのジャンルによっては、「プログラムするな訓練せよ」の時代が来るのかもしれませんねっ。

コード

yos1up / DNC / main.py | GitHub

のファイル実行部分を、下記に差し替えました。
Chainer知らずの見様見真似コードなので、おかしなところがあったら、ご指摘くださいm(_ _)m

len_number = 10
X = len_number + 2
Y = len_number
N = 20
W = X
R = 2
mdl = DNC(X, Y, N, W, R)
opt = optimizers.Adam()
opt.setup(mdl)


def run(low, high, train=False):
    order = np.random.randint(0, 2)
    content = [np.random.randint(0, len_number) for _ in range(np.random.randint(low, high))]
    sorted_content = sorted(content, reverse=(order != 0))
    x_seq_list = map(lambda i: onehot(i, X), [len_number + order] + content) + [np.zeros(X).astype(np.float32)] * len(content)
    t_seq_list = ([None] * (1 + len(content))) + map(lambda i: onehot(i, Y), sorted_content)

    result = []
    loss = 0.0

    for (x_seq, t_seq) in zip(x_seq_list, t_seq_list):
        y = mdl(Variable(x_seq.reshape(1, X)))
        if t_seq is not None:
            t = Variable(t_seq.reshape(1, Y))
            if train:
                loss += (y - t) ** 2
            else:
                result.append(np.argmax(y.data))

    mdl.reset_state()
    if train:
        mdl.cleargrads()
        loss.grad = np.ones(loss.data.shape, dtype=np.float32)
        loss.backward()
        loss.unchain_backward()
        opt.update()
    else:
        print 'in ', content, 'ASC' if order == 0 else 'DESC'
        print 'ans', sorted_content
        print 'out', result, sum(1.0 if s == r else 0.0 for (s, r) in zip(sorted_content, result)) / len(result)


for i in range(100000):
    run(2, 9, train=True)
    if i % 100 == 0:
        print
        print '#', i
        for j in range(2, 10):
            print
            run(j, j + 1)

実行環境

$ python --version
Python 2.7.11

$ pip freeze
chainer==1.17.0
filelock==2.0.7
nose==1.3.7
numpy==1.11.2
protobuf==3.1.0.post1
six==1.10.0
42
40
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
42
40