LoginSignup
2
3

More than 5 years have passed since last update.

【dqn】LSTMで深層強化学習する♬~論理和・積(+,*)の同時学習

Last updated at Posted at 2018-07-24

前回は論理和と論理積を別々に学習できたと称したが、やはり同時に学習できて初めて論理和と論理積が学習できたと云える。
※少なくとも人はそれで初めて学習したと考える

ということでLSTMに多次元入力(もともと二次元だったが)、三次元入力にして学習してみた。
その結果、単独で学習した場合とほぼ同じ精度で学習できた。

やったこと

(1)Pandasでデータ結合
(2)多次元入力にして学習
(3)実行結果と考察

(1)Pandasでデータ結合

はっきり言って、indexの振り直しは0からは比較的検索しやすいけど、完全に振り直しはなかなか出てこなくて苦労しました。
【参考】
pandas 0.23.3 documentation » API Reference » pandas.RangeIndex
Pandas DataFrame RangeIndex
ということで、データ作成を以下のように変更した。
※今回の記事は以下のノウハウが一番価値ありそう

df1 = pd.DataFrame({"s": 1,"a": numpy.random.rand(500)//0.5, "b": numpy.random.rand(500)//0.5})  #//0.5
df1["c"]= ((df1["a"]  + df1["b"]-df1["a"]*df1["b"]).shift(LAG-1)).fillna(0) #論理和学習のため

df2 = pd.DataFrame({"s": 0,"a": numpy.random.rand(500)//0.5, "b": numpy.random.rand(500)//0.5})  #//0.5
df2["c"]= ((df2["a"]  * df2["b"]).shift(LAG-1)).fillna(0) #論理積学習のため
df2.index = pd.RangeIndex(start=500, stop=1000, step=1)

df=pd.concat([df1,df2],axis=0)  #結合を行う 

(2)多次元入力にして学習

コードは変更なしだが、モデルは以下のとおり

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_1 (Flatten)          (None, 3)                 0
_________________________________________________________________
dense_1 (Dense)              (None, 16)                64
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0
_________________________________________________________________
dense_2 (Dense)              (None, 16)                272
_________________________________________________________________
activation_2 (Activation)    (None, 16)                0
_________________________________________________________________
dense_3 (Dense)              (None, 16)                272
_________________________________________________________________
activation_3 (Activation)    (None, 16)                0
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 34
_________________________________________________________________
activation_4 (Activation)    (None, 2)                 0
=================================================================
Total params: 642
Trainable params: 642
Non-trainable params: 0
_________________________________________________________________

一方、LSTMでは、入力を次元に合わせて変更する。

model = Sequential()
model.add(Reshape(observation_space.shape,
        input_shape=(1,)
        +observation_space.shape))
model.add(LSTM(50,
        input_shape=(3, 1),
        return_sequences=False,
        dropout=0.0))
model.add(Dense(n_action))
model.add(Activation('linear'))
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
reshape_1 (Reshape)          (None, 1, 3)              0
_________________________________________________________________
lstm_1 (LSTM)                (None, 50)                10800
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 102
_________________________________________________________________
activation_1 (Activation)    (None, 2)                 0
=================================================================
Total params: 10,902
Trainable params: 10,902
Non-trainable params: 0
_________________________________________________________________

(3)実行結果と考察

学習前

       a    b   s    c  pred  profit
0    1.0  0.0   0  1.0   NaN     NaN
1    0.0  1.0   0  1.0   1.0     1.0
2    1.0  0.0   0  1.0   0.0     0.0
3    0.0  1.0   0  1.0   1.0     1.0
4    0.0  0.0   0  0.0   0.0     1.0
5    1.0  0.0   0  1.0   0.0     0.0
6    0.0  0.0   0  0.0   0.0     1.0
7    1.0  1.0   0  1.0   0.0     0.0
8    1.0  0.0   0  1.0   0.0     0.0
9    1.0  1.0   0  1.0   0.0     0.0
10   1.0  0.0   0  1.0   0.0     0.0
11   1.0  0.0   0  1.0   0.0     0.0
12   0.0  1.0   0  1.0   1.0     1.0
13   1.0  1.0   0  1.0   0.0     0.0
14   0.0  0.0   0  0.0   0.0     1.0
15   0.0  0.0   0  0.0   0.0     1.0
16   1.0  1.0   0  1.0   0.0     0.0
17   1.0  1.0   0  1.0   0.0     0.0
18   0.0  1.0   0  1.0   1.0     1.0
19   0.0  1.0   0  1.0   1.0     1.0
20   0.0  1.0   0  1.0   1.0     1.0
21   1.0  0.0   0  1.0   0.0     0.0
22   1.0  1.0   0  1.0   0.0     0.0
23   0.0  0.0   0  0.0   0.0     1.0
24   1.0  1.0   0  1.0   0.0     0.0
25   1.0  1.0   0  1.0   0.0     0.0
26   1.0  1.0   0  1.0   0.0     0.0
27   1.0  1.0   0  1.0   0.0     0.0
28   1.0  1.0   0  1.0   0.0     0.0
29   0.0  0.0   0  0.0   0.0     1.0
..   ...  ...  ..  ...   ...     ...
970  0.0  0.0  10  0.0   1.0     0.0
971  0.0  0.0  10  0.0   1.0     0.0
972  1.0  1.0  10  1.0   1.0     1.0
973  0.0  1.0  10  0.0   1.0     0.0
974  1.0  0.0  10  0.0   1.0     0.0
975  1.0  1.0  10  1.0   1.0     1.0
976  0.0  0.0  10  0.0   1.0     0.0
977  0.0  0.0  10  0.0   1.0     0.0
978  0.0  1.0  10  0.0   1.0     0.0
979  0.0  0.0  10  0.0   1.0     0.0
980  1.0  1.0  10  1.0   1.0     1.0
981  1.0  1.0  10  1.0   1.0     1.0
982  1.0  1.0  10  1.0   1.0     1.0
983  1.0  0.0  10  0.0   1.0     0.0
984  0.0  1.0  10  0.0   1.0     0.0
985  1.0  1.0  10  1.0   1.0     1.0
986  0.0  0.0  10  0.0   1.0     0.0
987  0.0  1.0  10  0.0   1.0     0.0
988  0.0  1.0  10  0.0   1.0     0.0
989  0.0  0.0  10  0.0   1.0     0.0
990  1.0  0.0  10  0.0   1.0     0.0
991  0.0  1.0  10  0.0   1.0     0.0
992  1.0  0.0  10  0.0   1.0     0.0
993  1.0  0.0  10  0.0   1.0     0.0
994  0.0  0.0  10  0.0   1.0     0.0
995  0.0  0.0  10  0.0   1.0     0.0
996  1.0  0.0  10  0.0   1.0     0.0
997  1.0  0.0  10  0.0   1.0     0.0
998  0.0  0.0  10  0.0   1.0     0.0
999  1.0  0.0  10  0.0   NaN     NaN

[1000 rows x 6 columns]
523.0 372.0

学習後

done, took 168.876 seconds
       a    b   s    c  pred  profit
0    1.0  0.0   0  1.0   NaN     NaN
1    0.0  1.0   0  1.0   1.0     1.0
2    1.0  0.0   0  1.0   1.0     1.0
3    0.0  1.0   0  1.0   1.0     1.0
4    0.0  0.0   0  0.0   0.0     1.0
5    1.0  0.0   0  1.0   1.0     1.0
6    0.0  0.0   0  0.0   0.0     1.0
7    1.0  1.0   0  1.0   1.0     1.0
8    1.0  0.0   0  1.0   1.0     1.0
9    1.0  1.0   0  1.0   1.0     1.0
10   1.0  0.0   0  1.0   1.0     1.0
11   1.0  0.0   0  1.0   1.0     1.0
12   0.0  1.0   0  1.0   1.0     1.0
13   1.0  1.0   0  1.0   1.0     1.0
14   0.0  0.0   0  0.0   0.0     1.0
15   0.0  0.0   0  0.0   0.0     1.0
16   1.0  1.0   0  1.0   1.0     1.0
17   1.0  1.0   0  1.0   1.0     1.0
18   0.0  1.0   0  1.0   1.0     1.0
19   0.0  1.0   0  1.0   1.0     1.0
20   0.0  1.0   0  1.0   1.0     1.0
21   1.0  0.0   0  1.0   0.0     0.0
22   1.0  1.0   0  1.0   1.0     1.0
23   0.0  0.0   0  0.0   0.0     1.0
24   1.0  1.0   0  1.0   1.0     1.0
25   1.0  1.0   0  1.0   0.0     0.0
26   1.0  1.0   0  1.0   1.0     1.0
27   1.0  1.0   0  1.0   1.0     1.0
28   1.0  1.0   0  1.0   1.0     1.0
29   0.0  0.0   0  0.0   0.0     1.0
..   ...  ...  ..  ...   ...     ...
970  0.0  0.0  10  0.0   0.0     1.0
971  0.0  0.0  10  0.0   1.0     0.0
972  1.0  1.0  10  1.0   1.0     1.0
973  0.0  1.0  10  0.0   0.0     1.0
974  1.0  0.0  10  0.0   0.0     1.0
975  1.0  1.0  10  1.0   1.0     1.0
976  0.0  0.0  10  0.0   0.0     1.0
977  0.0  0.0  10  0.0   0.0     1.0
978  0.0  1.0  10  0.0   0.0     1.0
979  0.0  0.0  10  0.0   0.0     1.0
980  1.0  1.0  10  1.0   1.0     1.0
981  1.0  1.0  10  1.0   1.0     1.0
982  1.0  1.0  10  1.0   1.0     1.0
983  1.0  0.0  10  0.0   0.0     1.0
984  0.0  1.0  10  0.0   0.0     1.0
985  1.0  1.0  10  1.0   1.0     1.0
986  0.0  0.0  10  0.0   0.0     1.0
987  0.0  1.0  10  0.0   0.0     1.0
988  0.0  1.0  10  0.0   0.0     1.0
989  0.0  0.0  10  0.0   0.0     1.0
990  1.0  0.0  10  0.0   0.0     1.0
991  0.0  1.0  10  0.0   0.0     1.0
992  1.0  0.0  10  0.0   0.0     1.0
993  1.0  0.0  10  0.0   1.0     0.0
994  0.0  0.0  10  0.0   0.0     1.0
995  0.0  0.0  10  0.0   0.0     1.0
996  1.0  0.0  10  0.0   0.0     1.0
997  1.0  0.0  10  0.0   0.0     1.0
998  0.0  0.0  10  0.0   0.0     1.0
999  1.0  0.0  10  0.0   NaN     NaN

[1000 rows x 6 columns]
523.0 941.0

つまり論理和と論理積だとa=1, b=0 のときそれぞれ1と0と異なる解になるけど、見事に学習しているのが分かる。
ちなみに、上記はsというパラメータを導入して、論理和0、論理積10として学習している。ちなみに、この値は異なれば何でもいい。

ということで識別されさえすればそれによって、a+bとa*bの演算をするわけで、まさしく演算を理解したように見える。

しかし、正解率を見るとほぼ95%程度であり、前回と同じ精度となった。
この不思議さはどこからきているのだろう。
ということで今回もアイデアを思いつかなかった。

コード全体は以下に置いた

dqn/dqn_lstm_game01.py

まとめ

・論理和と論理積を同時学習できた
・pandasにおける加算とindexの張り直しのやり方を理解した
・多次元でのLSTMの振る舞いが理解できた

・正解率が95%になるのが謎になりつつある
。。。根拠無いけど、これLSTMでのフィッティングでの原理的な限界しめしているのかも??
   と思い出した。。。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3