はじめに
文字列を学習して、それっぽくを生成させて楽しみたい。
せっかくなので、興味があったDNCでやりたい。
DNCが何かについては
などを参照してください。
DNCの実装
DNC (Differentiable Neural Computers) の概要 + Chainer による実装
に載っているコードをベースに編集して、GPUとバッチ学習に対応させました。
ベースにできるコードがあったのでだいぶ助かりました。ありがとうございます。
個人的にはまあ許容できるくらいの速度で動いてます。
具体的な計測はしていないですが、chainer
にnumpy.sort
, numpy.cumprod
, numpy.prod
相当のものがあればもっと早くなると思います
コードはこちらにあります。
python3でしか動作を確認していませんがpython2でも動くかもしれません。
ちゃんと実装できてる?
計算部分を弄った分ちゃんと実装できてるか不安なので、DNC (Differentiable Neural Computers) の概要 + Chainer による実装
で行われていたものと同じタスクをやらせてみます。
どんなタスクかというと、
実行ステップ | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
入力 | 1 | 1 | 2 | 3 | - | - | - | - |
欲しい出力 | x | x | x | x | 1 | 1 | 2 | 3 |
のように遅れて最初の入力を返させます。(-はゼロ, xは出力がなんでもいいことを示しています)
ところで、DNCのコントローラーとしてLSTM
を使うと、LSTM
の力だけでタスクをクリアしてしまうかもしれないので、コントローラーとしてLinear
を使いました。(内部に状態を持たないので、外部メモリを使わざるをえないはずです)
このタスクができたとしても実装が正しいという保証にはなりませんが、できなきゃヤバ
結果
( 1990 ) 0.000777187338099 0.0
[[-26.15471077 -3.17894959 -8.98376179 9.1276083 -0.657956 ]] [3] True
[[-29.14633179 -0.67815143 -10.29454613 4.49074841 3.79639697]] [4] False
[[-28.97230339 0.72816759 -9.82902908 7.07141733 -0.72419703]] [3] True
[[-27.09978485 0.47589365 -8.69920349 7.42882442 -1.40964496]] [3] True
[[-24.46948624 4.50180721 -7.12348366 0.55667412 -1.48180509]] [1] True
.
.
.
[[-20.07484818 -0.60636777 -6.53328705 -0.06233668 4.94069195]] [4] True
[[-22.54798508 4.72276878 -6.0871973 -1.23646426 -0.35858893]] [1] True
[[-21.5928688 0.5505662 -6.83123302 5.42372036 -0.77451634]] [3] True
[[-18.02187729 0.08152789 -7.14690161 5.81850386 -0.7215544 ]] [3] True
( 1997 ) 0.00210291147232 0.0
[[-22.53721428 -3.10748935 -13.75279045 10.75553894 2.53855157]] [3] True
[[-20.70213318 1.15888417 13.07026482 -1.97864366 -8.50465393]] [2] True
[[-18.76399612 -0.41023731 -6.24387026 -2.41328239 7.38532257]] [4] True
[[-14.35320473 1.11414146 9.84414101 -4.74714947 -3.62668037]] [2] True
( 1998 ) 5.68032264709e-05 0.0
[[ -3.94404173 3.23077869 10.6953516 -6.58970022 -7.45262957]] [2] True
[[-7.08518839 1.42143321 8.55715466 -5.73273754 -7.28167248]] [2] True
[[-6.47711086 0.12313008 7.89064693 -4.47937965 -5.46987915]] [2] True
[[-6.35165119 -0.89073217 6.81142473 -3.43888187 -4.04115105]] [2] True
2000回学習した付近を見たところ間違えてないようでした。(Trueがいっぱい表示されてる)
軽くミスっていところもありますがとりあえずOKということで
文字列生成プログラム
ChainerのptbデモにDNCをぶち込む感じで。
でもちょっと変えます。
動くものはGitHubに上げました。
ptbデモの不満なところ
ptbデモにTrainerと可変長データがあまり噛み合ってない箇所がありました。
しかしTrainerを使ったほうがいろいろ楽な気がするので対処していきます。
1. プログレスバーの表示がおかしい
- イテレータ(
ParallelSequentialIterator
)のepoch_detail
でちゃんとした値を出すようにする - 試していないですが、1.18.0で治った
2. ロスの表示おかしくない?
BPTTUpdater
では決められた回数ロスを集めてからバックプロップしているので、ちゃんとTrainerが取得できてるか不安。
そもそもTrainerはどうやってロスの値を取得しているのでしょうか。
使っている場合は、ですがchainer.links.Classifier
がやってくれているようです
x = args[:-1]
t = args[-1]
self.y = None
self.loss = None
self.accuracy = None
self.y = self.predictor(*x)
self.loss = self.lossfun(self.y, t)
reporter.report({'loss': self.loss}, self) # これ
しかし、BPTTUpdater
では複数回ロスを集めてからバックプロップしているので、最新以外のロスのreportが上書きされてる気がします。
ちょっと調べてみところ表示と実際のロスが合っていないようだったので(間違ってたらすいません)、
今回は、BPTTUpdater
のupdate_core
を編集して、一回バッチを読んだら処理を返すようにしました。
この場合iteration が ウェイトの更新回数と一致しなくなりますが、もしその情報が必要なら自分でreportすれば良さそうです。
実験
時間がなかったので小規模なデータで学習します。
作ったツールで行っていきます。
データ
r/Flamewankerをご存知でしょうか?
このサブレディットではハースストーンというカードゲームのカードテキストをコンピュータで生成して楽しんでいます。
https://www.reddit.com/r/Flamewanker/wiki/index
ここからハースストーンの一個前の拡張までのカードテキストを入手して学習していきます
すいません。唐突に謎のデータが出てきたと思うかもしれませんが、このまま進みます
前処理
データはだいたいこんな感じです
Assassin's Blade @ Rogue | | Weapon | C | 5 | 3/4 || &
Assassinate @ Rogue | | Spell | B | 5 || Destroy an enemy minion. &
Backstab @ Rogue | | Spell | B | 0 || Deal 2 damage to an undamaged minion. &
Blessing of Kings @ Paladin | | Spell | C | 4 || Give a minion +4/+4. [i](+4 Attack/+4 Health)[/i] &
Blessing of Might @ Paladin | | Spell | B | 1 || Give a minion +3 Attack. &
Bloodfen Raptor @ Neutral | Beast | Minion | B | 2 | 3/2 || &
一行ごとに一枚のカードの情報が入っているので、一行ごとに切り出してそれぞれ別のテキストファイルにします。
出来たらデータを一つのファイルにまとめます
$ python pack.py [テキストファイルがいっぱいあるフォルダ] -o flamewanker.pickle
pack.pyはディレクトリにあるテキストファイルを読み込んで、 学習データと辞書を作成します。
**形態素解析とかはしてません。**1文字ごとに分離します。
学習
$ python train_texts.py flamewanker.pickle -g 0 # -g オプションはgpuを使うかどうか
これでresult
ディレクトリに学習したモデルがエポックごとに保存されていきます。
デフォルトで40エポック学習するので全部終わるまでGTX 1070で100分位かかりました。
途中のエポックのモデルデータでも生成はできるので次のステップに行きましょう。
モデル
学習に使うモデルは今のところ固定です。(入力, 出力以外)
DNCのパラメータ
よくわからないので、根拠なしの適当です
論文中の記号 | 意味 | 値 |
---|---|---|
X | 入力ベクトルのサイズ | 85 |
Y | 出力ベクトルのサイズ | 85 |
N | メモリの数 | 64 |
W | メモリ一つのサイズ | 64 |
R | リードヘッダの数 | 16 |
DNCのコントローラ
super(DeepLSTM, self).__init__(
l1=L.LSTM(X, 500),
l2=L.LSTM(500, 500),
l3=L.Linear(500, Y),
)
def __call__(self, x):
h1 = self.l1(F.dropout(x, train=self.train))
h2 = self.l2(F.dropout(h1, train=self.train))
y = self.l3(F.dropout(h2, train=self.train))
return y
その他の情報
オプティマイザにはRMSpropを使いました。学習率はデフォルトです。
データから何割か抜き取ってバリデーションデータセットにして~みたいなことは(まだ)していません。
生成
引数に先程のflamewanker.pickle
を指定する以外はptbデモと同じです
-p は最初にモデルに入れる文字
-s は乱数のシード値
$ python gentxt.py flamewanker.pickle -m result/model_epoch_37 -p F -s 5454
例
FnlooBg Kangn @ Neutral | | Minion | R | 8 | 4/5 || [b]Battlecry:[/b] Destroy your minion with [b]Taunt[/b]. [b]Taunt[/b] &
激弱カードですが、しっかり文法を守っていて良いですね。
大抵はもっと無茶苦茶な出力をします。(色々改善の余地はあると思います)
いい感じに生成できたら適当な画像を使ってカードにするのも楽しいです。
http://www.hearthcards.net/