Help us understand the problem. What is going on with this article?

(Chainer) DNC(Differentiable Neural Computers)で文字列の学習&生成

More than 3 years have passed since last update.

はじめに

文字列を学習して、それっぽくを生成させて楽しみたい。
せっかくなので、興味があったDNCでやりたい。

DNCが何かについては

などを参照してください。

DNCの実装

DNC (Differentiable Neural Computers) の概要 + Chainer による実装
に載っているコードをベースに編集して、GPUバッチ学習に対応させました。
ベースにできるコードがあったのでだいぶ助かりました。ありがとうございます。 :grinning:

個人的にはまあ許容できるくらいの速度で動いてます。
具体的な計測はしていないですが、chainernumpy.sort, numpy.cumprod, numpy.prod相当のものがあればもっと早くなると思います

コードはこちらにあります。
python3でしか動作を確認していませんがpython2でも動くかもしれません。

ちゃんと実装できてる?

計算部分を弄った分ちゃんと実装できてるか不安なので、DNC (Differentiable Neural Computers) の概要 + Chainer による実装
で行われていたものと同じタスクをやらせてみます。

どんなタスクかというと、

実行ステップ 1 2 3 4 5 6 7 8
入力 1 1 2 3 - - - -
欲しい出力 x x x x 1 1 2 3

のように遅れて最初の入力を返させます。(-はゼロ, xは出力がなんでもいいことを示しています)

ところで、DNCのコントローラーとしてLSTMを使うと、LSTMの力だけでタスクをクリアしてしまうかもしれないので、コントローラーとしてLinearを使いました。(内部に状態を持たないので、外部メモリを使わざるをえないはずです)

このタスクができたとしても実装が正しいという保証にはなりませんが、できなきゃヤバ

結果

( 1990 ) 0.000777187338099 0.0
[[-26.15471077  -3.17894959  -8.98376179   9.1276083   -0.657956  ]] [3] True
[[-29.14633179  -0.67815143 -10.29454613   4.49074841   3.79639697]] [4] False
[[-28.97230339   0.72816759  -9.82902908   7.07141733  -0.72419703]] [3] True
[[-27.09978485   0.47589365  -8.69920349   7.42882442  -1.40964496]] [3] True
[[-24.46948624   4.50180721  -7.12348366   0.55667412  -1.48180509]] [1] True
.
.
.
[[-20.07484818  -0.60636777  -6.53328705  -0.06233668   4.94069195]] [4] True
[[-22.54798508   4.72276878  -6.0871973   -1.23646426  -0.35858893]] [1] True
[[-21.5928688    0.5505662   -6.83123302   5.42372036  -0.77451634]] [3] True
[[-18.02187729   0.08152789  -7.14690161   5.81850386  -0.7215544 ]] [3] True
( 1997 ) 0.00210291147232 0.0
[[-22.53721428  -3.10748935 -13.75279045  10.75553894   2.53855157]] [3] True
[[-20.70213318   1.15888417  13.07026482  -1.97864366  -8.50465393]] [2] True
[[-18.76399612  -0.41023731  -6.24387026  -2.41328239   7.38532257]] [4] True
[[-14.35320473   1.11414146   9.84414101  -4.74714947  -3.62668037]] [2] True
( 1998 ) 5.68032264709e-05 0.0
[[ -3.94404173   3.23077869  10.6953516   -6.58970022  -7.45262957]] [2] True
[[-7.08518839  1.42143321  8.55715466 -5.73273754 -7.28167248]] [2] True
[[-6.47711086  0.12313008  7.89064693 -4.47937965 -5.46987915]] [2] True
[[-6.35165119 -0.89073217  6.81142473 -3.43888187 -4.04115105]] [2] True

2000回学習した付近を見たところ間違えてないようでした。(Trueがいっぱい表示されてる)
軽くミスっていところもありますがとりあえずOKということで

文字列生成プログラム

ChainerのptbデモにDNCをぶち込む感じで。
でもちょっと変えます。

動くものはGitHubに上げました。

ptbデモの不満なところ

ptbデモにTrainerと可変長データがあまり噛み合ってない箇所がありました。
しかしTrainerを使ったほうがいろいろ楽な気がするので対処していきます。

1. プログレスバーの表示がおかしい

  • イテレータ(ParallelSequentialIterator)のepoch_detailでちゃんとした値を出すようにする
  • 試していないですが、1.18.0で治った :grinning:

2. ロスの表示おかしくない?

BPTTUpdaterでは決められた回数ロスを集めてからバックプロップしているので、ちゃんとTrainerが取得できてるか不安。
そもそもTrainerはどうやってロスの値を取得しているのでしょうか。
使っている場合は、ですがchainer.links.Classifierがやってくれているようです

chainer\links\model\classifier.py
x = args[:-1]
t = args[-1]
self.y = None
self.loss = None
self.accuracy = None
self.y = self.predictor(*x)
self.loss = self.lossfun(self.y, t)
reporter.report({'loss': self.loss}, self) # これ

しかし、BPTTUpdaterでは複数回ロスを集めてからバックプロップしているので、最新以外のロスのreportが上書きされてる気がします。
ちょっと調べてみところ表示と実際のロスが合っていないようだったので(間違ってたらすいません)、
今回は、BPTTUpdaterupdate_coreを編集して、一回バッチを読んだら処理を返すようにしました。
この場合iteration が ウェイトの更新回数と一致しなくなりますが、もしその情報が必要なら自分でreportすれば良さそうです。

実験

時間がなかったので小規模なデータで学習します。
作ったツールで行っていきます。

データ

r/Flamewankerをご存知でしょうか?
このサブレディットではハースストーンというカードゲームのカードテキストをコンピュータで生成して楽しんでいます。
https://www.reddit.com/r/Flamewanker/wiki/index
ここからハースストーンの一個前の拡張までのカードテキストを入手して学習していきます

すいません。唐突に謎のデータが出てきたと思うかもしれませんが、このまま進みます

前処理

データはだいたいこんな感じです

Assassin's Blade @ Rogue |  | Weapon | C | 5 | 3/4 ||  &
Assassinate @ Rogue |  | Spell | B | 5 || Destroy an enemy minion. &
Backstab @ Rogue |  | Spell | B | 0 || Deal 2 damage to an undamaged minion. &
Blessing of Kings @ Paladin |  | Spell | C | 4 || Give a minion +4/+4. [i](+4 Attack/+4 Health)[/i] &
Blessing of Might @ Paladin |  | Spell | B | 1 || Give a minion +3 Attack. &
Bloodfen Raptor @ Neutral | Beast | Minion | B | 2 | 3/2 ||  &

一行ごとに一枚のカードの情報が入っているので、一行ごとに切り出してそれぞれ別のテキストファイルにします。

出来たらデータを一つのファイルにまとめます

$ python pack.py [テキストファイルがいっぱいあるフォルダ] -o flamewanker.pickle

pack.pyはディレクトリにあるテキストファイルを読み込んで、 学習データと辞書を作成します。

形態素解析とかはしてません。1文字ごとに分離します。

学習

$ python train_texts.py flamewanker.pickle -g 0 # -g オプションはgpuを使うかどうか

これでresultディレクトリに学習したモデルがエポックごとに保存されていきます。
デフォルトで40エポック学習するので全部終わるまでGTX 1070で100分位かかりました。

途中のエポックのモデルデータでも生成はできるので次のステップに行きましょう。

モデル

学習に使うモデルは今のところ固定です。(入力, 出力以外)

DNCのパラメータ

よくわからないので、根拠なしの適当です

論文中の記号 意味
X 入力ベクトルのサイズ 85
Y 出力ベクトルのサイズ 85
N メモリの数 64
W メモリ一つのサイズ 64
R リードヘッダの数 16
DNCのコントローラ
super(DeepLSTM, self).__init__(
    l1=L.LSTM(X, 500),
    l2=L.LSTM(500, 500),
    l3=L.Linear(500, Y),
)
def __call__(self, x):
    h1 = self.l1(F.dropout(x, train=self.train))
    h2 = self.l2(F.dropout(h1, train=self.train))
    y = self.l3(F.dropout(h2, train=self.train))
    return y
その他の情報

オプティマイザにはRMSpropを使いました。学習率はデフォルトです。
データから何割か抜き取ってバリデーションデータセットにして~みたいなことは(まだ)していません。

生成

引数に先程のflamewanker.pickleを指定する以外はptbデモと同じです

-p は最初にモデルに入れる文字
-s は乱数のシード値

$ python gentxt.py flamewanker.pickle -m result/model_epoch_37 -p F -s 5454

FnlooBg Kangn @ Neutral | | Minion | R | 8 | 4/5 || [b]Battlecry:[/b] Destroy your minion with [b]Taunt[/b]. [b]Taunt[/b] &

激弱カードですが、しっかり文法を守っていて良いですね。
大抵はもっと無茶苦茶な出力をします。(色々改善の余地はあると思います)

いい感じに生成できたら適当な画像を使ってカードにするのも楽しいです。
http://www.hearthcards.net/

2de5048f.png

hatoo@github
Rustのお仕事ください!
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした