Python
DeepLearning
Keras
LSTM
DQN

【dqn】LSTMで深層強化学習する♬~真理値表の学習

実は先の記事をアップしたら、下記の参考のような記事が関連記事としてリンクされていた。
DLで比較的難しい「論理回路の真理値表」を学習させていて、結果は失敗とのこと。
そもそも一番簡単な論理は学習できるよなぁ~
ということで先日のプログラムでやってみた。

【参考】
ディープラーニングで論理回路を学習、予測させてみた
【問題7】 真理値表からゲート回路を作る

やったこと

(1)真理値表
(2)コード
(3)実行結果
(4)ちょっと考察

(1)真理値表

真理値表って以下のようなものです。
※画像は参考②のものを使わせていただきました
ay_dr02_07_mon.gif
この真理値表は基本的な以下の真理値表の組み合わせで作成できます。
ay_dri207_fig01.gif
ay_dri207_fig02.gif
ay_dri207_fig03.gif

大切なことは、少なくともデジタル回路はこういう論理式から構成されているということです。

ということで、簡単そうなのでこの基本的は真理値表を学習できるのかというのをやってみました。

(2)コード

※コードは見出しからリンクしています
コードは前回のコードを以下の部分だけ変更しました。
一つ目はまず、真理値表っぽくするために正解の生成を以下のようにしました。

df = pandas.DataFrame({"a": numpy.random.rand(1000)//0.5, "b": numpy.random.rand(1000)//0.5})  #//0.5
df["c"]= ((df["a"] * df["b"]).shift(LAG-1)).fillna(0) #論理学習のため

一行目で0,1を出力しています。そして、二行目で演算して真理値を定義しています。
ここでは乗算なので、上記のX=A*Bに相当します。以前は二つ前のa,bを見て二つ後のcを演算していましたが、今回は同じ行を学習するようにしています。
あと、profitの計算を以下のとおり直接的なものに変更しました。

def calc_profit(action, df, index):
    if action == 0: 
        if df["c"][index] ==0:
            return 1
        else:
            return 0 
    elif action == 1: 
        if df["c"][index] ==1:
            return 1
        else:
            return 0 
    return 0

学習モデルは以下のとおり

model = Sequential()
model.add(Reshape(observation_space.shape,
                  input_shape=(1,)+observation_space.shape))
model.add(LSTM(32, input_shape=(1, 2), 
          return_sequences=False,
          dropout=0.0))
model.add(Dense(n_action))
model.add(Activation('linear'))
print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
reshape_1 (Reshape)          (None, 1, 2)              0
_________________________________________________________________
lstm_1 (LSTM)                (None, 32)                4480
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 66
_________________________________________________________________
activation_1 (Activation)    (None, 2)                 0
=================================================================
Total params: 4,546
Trainable params: 4,546
Non-trainable params: 0
_________________________________________________________________

(3)実行結果

実行結果は以下のとおりになりました。
これは学習前の再現性を見た結果です。ほぼ500程度のスコアでこれは乱数で0,1を当てようとしたときのスコアとして妥当です。

       a    b    c  pred  profit
0    1.0  0.0  0.0   NaN     NaN
1    0.0  1.0  0.0   1.0     0.0
2    1.0  0.0  0.0   0.0     1.0
3    0.0  0.0  0.0   0.0     1.0
4    1.0  0.0  0.0   0.0     1.0
5    1.0  0.0  0.0   0.0     1.0
6    0.0  1.0  0.0   1.0     0.0
7    0.0  1.0  0.0   1.0     0.0
8    0.0  0.0  0.0   0.0     1.0
9    0.0  0.0  0.0   0.0     1.0
10   1.0  0.0  0.0   0.0     1.0
11   1.0  1.0  1.0   0.0     0.0
12   1.0  1.0  1.0   0.0     0.0
13   0.0  1.0  0.0   1.0     0.0
14   1.0  1.0  1.0   0.0     0.0
15   1.0  0.0  0.0   0.0     1.0
16   0.0  1.0  0.0   1.0     0.0
17   0.0  1.0  0.0   1.0     0.0
18   1.0  0.0  0.0   0.0     1.0
19   1.0  1.0  1.0   0.0     0.0
20   1.0  0.0  0.0   0.0     1.0
21   1.0  0.0  0.0   0.0     1.0
22   1.0  0.0  0.0   0.0     1.0
23   0.0  1.0  0.0   1.0     0.0
24   1.0  0.0  0.0   0.0     1.0
25   1.0  1.0  1.0   0.0     0.0
26   1.0  1.0  1.0   0.0     0.0
27   1.0  0.0  0.0   0.0     1.0
28   0.0  1.0  0.0   1.0     0.0
29   0.0  0.0  0.0   0.0     1.0
..   ...  ...  ...   ...     ...
970  1.0  0.0  0.0   0.0     1.0
971  1.0  0.0  0.0   0.0     1.0
972  0.0  0.0  0.0   0.0     1.0
973  0.0  0.0  0.0   0.0     1.0
974  1.0  1.0  1.0   0.0     0.0
975  0.0  0.0  0.0   0.0     1.0
976  0.0  1.0  0.0   1.0     0.0
977  1.0  1.0  1.0   0.0     0.0
978  1.0  1.0  1.0   0.0     0.0
979  1.0  0.0  0.0   0.0     1.0
980  0.0  0.0  0.0   0.0     1.0
981  0.0  1.0  0.0   1.0     0.0
982  0.0  1.0  0.0   1.0     0.0
983  0.0  0.0  0.0   0.0     1.0
984  1.0  1.0  1.0   0.0     0.0
985  1.0  0.0  0.0   0.0     1.0
986  0.0  0.0  0.0   0.0     1.0
987  1.0  0.0  0.0   0.0     1.0
988  1.0  1.0  1.0   0.0     0.0
989  1.0  1.0  1.0   0.0     0.0
990  1.0  1.0  1.0   0.0     0.0
991  1.0  1.0  1.0   0.0     0.0
992  0.0  1.0  0.0   1.0     0.0
993  0.0  0.0  0.0   0.0     1.0
994  1.0  1.0  1.0   0.0     0.0
995  0.0  0.0  0.0   0.0     1.0
996  0.0  0.0  0.0   0.0     1.0
997  0.0  1.0  0.0   1.0     0.0
998  0.0  1.0  0.0   1.0     0.0
999  0.0  1.0  0.0   NaN     NaN

[1000 rows x 5 columns]
244.0 520.0

以下は1000回回したときの結果です。結果は950点くらいになり、かなり向上しています。しかし、あと学習回数を増やしても(100000回)このあたりの得点(960点辺り)以上の値にはなりませんでした。

done, took 12.939 seconds
       a    b    c  pred  profit
0    1.0  0.0  0.0   NaN     NaN
1    0.0  1.0  0.0   0.0     1.0
2    1.0  0.0  0.0   0.0     1.0
3    0.0  0.0  0.0   0.0     1.0
4    1.0  0.0  0.0   0.0     1.0
5    1.0  0.0  0.0   0.0     1.0
6    0.0  1.0  0.0   0.0     1.0
7    0.0  1.0  0.0   0.0     1.0
8    0.0  0.0  0.0   0.0     1.0
9    0.0  0.0  0.0   0.0     1.0
10   1.0  0.0  0.0   0.0     1.0
11   1.0  1.0  1.0   1.0     1.0
12   1.0  1.0  1.0   1.0     1.0
13   0.0  1.0  0.0   0.0     1.0
14   1.0  1.0  1.0   1.0     1.0
15   1.0  0.0  0.0   0.0     1.0
16   0.0  1.0  0.0   0.0     1.0
17   0.0  1.0  0.0   0.0     1.0
18   1.0  0.0  0.0   0.0     1.0
19   1.0  1.0  1.0   1.0     1.0
20   1.0  0.0  0.0   0.0     1.0
21   1.0  0.0  0.0   0.0     1.0
22   1.0  0.0  0.0   0.0     1.0
23   0.0  1.0  0.0   0.0     1.0
24   1.0  0.0  0.0   0.0     1.0
25   1.0  1.0  1.0   1.0     1.0
26   1.0  1.0  1.0   1.0     1.0
27   1.0  0.0  0.0   0.0     1.0
28   0.0  1.0  0.0   0.0     1.0
29   0.0  0.0  0.0   0.0     1.0
..   ...  ...  ...   ...     ...
970  1.0  0.0  0.0   0.0     1.0
971  1.0  0.0  0.0   0.0     1.0
972  0.0  0.0  0.0   0.0     1.0
973  0.0  0.0  0.0   0.0     1.0
974  1.0  1.0  1.0   1.0     1.0
975  0.0  0.0  0.0   0.0     1.0
976  0.0  1.0  0.0   0.0     1.0
977  1.0  1.0  1.0   1.0     1.0
978  1.0  1.0  1.0   1.0     1.0
979  1.0  0.0  0.0   0.0     1.0
980  0.0  0.0  0.0   1.0     0.0
981  0.0  1.0  0.0   0.0     1.0
982  0.0  1.0  0.0   0.0     1.0
983  0.0  0.0  0.0   0.0     1.0
984  1.0  1.0  1.0   1.0     1.0
985  1.0  0.0  0.0   0.0     1.0
986  0.0  0.0  0.0   0.0     1.0
987  1.0  0.0  0.0   0.0     1.0
988  1.0  1.0  1.0   1.0     1.0
989  1.0  1.0  1.0   1.0     1.0
990  1.0  1.0  1.0   1.0     1.0
991  1.0  1.0  1.0   1.0     1.0
992  0.0  1.0  0.0   0.0     1.0
993  0.0  0.0  0.0   0.0     1.0
994  1.0  1.0  1.0   1.0     1.0
995  0.0  0.0  0.0   1.0     0.0
996  0.0  0.0  0.0   0.0     1.0
997  0.0  1.0  0.0   0.0     1.0
998  0.0  1.0  0.0   0.0     1.0
999  0.0  1.0  0.0   NaN     NaN

[1000 rows x 5 columns]
244.0 942.0

コードの解生成部分を以下のとおり変更して、X=A+Bを作ってみた

df = pandas.DataFrame({"a": numpy.random.rand(1000)//0.5, "b": numpy.random.rand(1000)//0.5})  #//0.5
df["c"]= ((df["a"]  + df["b"]-df["a"]*df["b"]).shift(LAG-1)).fillna(0) #論理学習のため

そして、当然だけど、同じように学習した。。。

done, took 166.969 seconds
       a    b    c  pred  profit
0    0.0  0.0  0.0   NaN     NaN
1    1.0  1.0  1.0   1.0     1.0
2    0.0  0.0  0.0   0.0     1.0
3    0.0  0.0  0.0   0.0     1.0
4    0.0  0.0  0.0   0.0     1.0
5    1.0  1.0  1.0   1.0     1.0
6    1.0  0.0  1.0   1.0     1.0
7    1.0  0.0  1.0   1.0     1.0
8    1.0  0.0  1.0   1.0     1.0
9    0.0  0.0  0.0   0.0     1.0
10   0.0  1.0  1.0   1.0     1.0
11   1.0  1.0  1.0   1.0     1.0
12   0.0  1.0  1.0   1.0     1.0
13   0.0  0.0  0.0   0.0     1.0
14   0.0  1.0  1.0   1.0     1.0
15   0.0  0.0  0.0   0.0     1.0
16   1.0  1.0  1.0   0.0     0.0
17   0.0  1.0  1.0   1.0     1.0
18   1.0  1.0  1.0   1.0     1.0
19   0.0  0.0  0.0   0.0     1.0
20   0.0  1.0  1.0   1.0     1.0
21   0.0  1.0  1.0   1.0     1.0
22   0.0  1.0  1.0   1.0     1.0
23   1.0  0.0  1.0   1.0     1.0
24   0.0  0.0  0.0   0.0     1.0
25   0.0  0.0  0.0   0.0     1.0
26   0.0  0.0  0.0   0.0     1.0
27   1.0  1.0  1.0   1.0     1.0
28   1.0  1.0  1.0   1.0     1.0
29   1.0  1.0  1.0   1.0     1.0
..   ...  ...  ...   ...     ...
970  1.0  1.0  1.0   0.0     0.0
971  0.0  1.0  1.0   1.0     1.0
972  1.0  1.0  1.0   1.0     1.0
973  0.0  0.0  0.0   0.0     1.0
974  0.0  0.0  0.0   0.0     1.0
975  0.0  0.0  0.0   0.0     1.0
976  1.0  1.0  1.0   1.0     1.0
977  1.0  1.0  1.0   1.0     1.0
978  0.0  0.0  0.0   0.0     1.0
979  0.0  0.0  0.0   0.0     1.0
980  0.0  1.0  1.0   1.0     1.0
981  0.0  1.0  1.0   1.0     1.0
982  1.0  1.0  1.0   1.0     1.0
983  1.0  0.0  1.0   1.0     1.0
984  0.0  0.0  0.0   0.0     1.0
985  1.0  1.0  1.0   1.0     1.0
986  0.0  0.0  0.0   0.0     1.0
987  0.0  0.0  0.0   0.0     1.0
988  1.0  0.0  1.0   1.0     1.0
989  1.0  1.0  1.0   0.0     0.0
990  1.0  0.0  1.0   1.0     1.0
991  0.0  0.0  0.0   0.0     1.0
992  0.0  1.0  1.0   1.0     1.0
993  1.0  1.0  1.0   1.0     1.0
994  0.0  1.0  1.0   1.0     1.0
995  1.0  1.0  1.0   0.0     0.0
996  0.0  1.0  1.0   1.0     1.0
997  1.0  1.0  1.0   1.0     1.0
998  0.0  0.0  0.0   0.0     1.0
999  1.0  0.0  1.0   NaN     NaN

[1000 rows x 5 columns]
736.0 942.0

(4)ちょっと考察

これやっていて、少なくとも問題が二つ発生します。

一つは、なぜ学習後の正解率が95%で、100%にならないのか?
もう一つは、もっと深刻な話かもしれませんが、最初の真理値表を学習したモデルは二つ目の真理値表は正解できません。また、逆も同じです。

第一の疑問は、今回考えても迷宮入りだと思っています。なのでここでは考えないことにします。
第二の疑問は、当たり前と言ってしまえばな疑問です。
しかし、この手のフィッティングではありがちな問題だと思います。一種の過学習みたいなものですが、情報が不足しているとしか言いようのない問題です。
つまり、もともと事前情報としてX=A・Bです。とかX=A+Bとかという情報があって、そのうえで判断するというものです。
つまり、この演算子込みで覚える必要があったということです。
その情報がなければこの手の問題の正解は確率的なものとなるのだと思います。
例えば、出現順序に情報が隠されている場合は単純なLSTMで正解できるようになるかもしれません。

というわけど、重要なことは単純にフィッティングできるからと言ってそれを状況を替えても同じように有効だという主張は時に破綻することがあるだろうと考えておく必要があるということです。

それにしても第一の疑問は何なんでしょうね??

まとめ

今回は記事にするかどうか瀬戸際な内容ですが、ご容赦ください。
・真理値表の学習をdqnしてみた
・そもそも演算子込みで学習しないと意味がないことを理解した

・正解率は95%以上にはなかなか上がらないが理由は不明である
 ちなみに、以下のとおり単純なMLPモデルでも同じような値に収束しており、これはLSTMモデルの問題ではなく、なんとなく一般的なある限界を示しているような気がする。

model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(n_action))
model.add(Activation('linear'))
done, took 190.631 seconds
       a    b    c  pred  profit
0    0.0  0.0  0.0   NaN     NaN
1    0.0  1.0  1.0   1.0     1.0
2    1.0  1.0  1.0   1.0     1.0
3    0.0  0.0  0.0   0.0     1.0
4    0.0  1.0  1.0   1.0     1.0
5    0.0  0.0  0.0   0.0     1.0
6    0.0  0.0  0.0   0.0     1.0
7    0.0  1.0  1.0   1.0     1.0
8    0.0  0.0  0.0   1.0     0.0
9    0.0  1.0  1.0   1.0     1.0
10   1.0  0.0  1.0   1.0     1.0
11   0.0  0.0  0.0   0.0     1.0
12   0.0  1.0  1.0   1.0     1.0
13   0.0  1.0  1.0   1.0     1.0
14   0.0  0.0  0.0   0.0     1.0
15   0.0  0.0  0.0   0.0     1.0
16   0.0  0.0  0.0   0.0     1.0
17   0.0  0.0  0.0   0.0     1.0
18   1.0  0.0  1.0   1.0     1.0
19   1.0  1.0  1.0   1.0     1.0
20   0.0  1.0  1.0   1.0     1.0
21   1.0  1.0  1.0   1.0     1.0
22   0.0  0.0  0.0   0.0     1.0
23   0.0  1.0  1.0   1.0     1.0
24   1.0  1.0  1.0   1.0     1.0
25   1.0  1.0  1.0   1.0     1.0
26   1.0  0.0  1.0   1.0     1.0
27   0.0  0.0  0.0   0.0     1.0
28   0.0  0.0  0.0   0.0     1.0
29   0.0  0.0  0.0   0.0     1.0
..   ...  ...  ...   ...     ...
970  0.0  1.0  1.0   1.0     1.0
971  1.0  1.0  1.0   1.0     1.0
972  0.0  1.0  1.0   1.0     1.0
973  0.0  1.0  1.0   1.0     1.0
974  1.0  0.0  1.0   0.0     0.0
975  0.0  0.0  0.0   0.0     1.0
976  0.0  0.0  0.0   0.0     1.0
977  1.0  1.0  1.0   1.0     1.0
978  0.0  1.0  1.0   1.0     1.0
979  1.0  0.0  1.0   1.0     1.0
980  1.0  0.0  1.0   1.0     1.0
981  1.0  1.0  1.0   1.0     1.0
982  0.0  1.0  1.0   1.0     1.0
983  1.0  0.0  1.0   1.0     1.0
984  0.0  0.0  0.0   0.0     1.0
985  0.0  0.0  0.0   0.0     1.0
986  0.0  1.0  1.0   1.0     1.0
987  1.0  0.0  1.0   1.0     1.0
988  1.0  1.0  1.0   1.0     1.0
989  0.0  1.0  1.0   1.0     1.0
990  1.0  0.0  1.0   1.0     1.0
991  0.0  1.0  1.0   1.0     1.0
992  1.0  0.0  1.0   1.0     1.0
993  1.0  1.0  1.0   1.0     1.0
994  1.0  0.0  1.0   1.0     1.0
995  0.0  1.0  1.0   1.0     1.0
996  0.0  1.0  1.0   1.0     1.0
997  1.0  0.0  1.0   1.0     1.0
998  1.0  0.0  1.0   1.0     1.0
999  1.0  1.0  1.0   NaN     NaN

[1000 rows x 5 columns]
762.0 944.0