#はじめに
化合物でDeepLearningを始めようと思い、手始めにDeepChemのGraphConvModelをハックし、Kerasで実装することにした。そこでまずは、Kerasで実装されているものをmodelオブジェクトのsummaryメソッドにより出力することとした。
環境
- DeepChem 2.3
#方法
GraphConvModelのクラス定義がされているファイルの624行目にmodel.summary()を入れ、適当なデータで予測モデルを作成してみる。
/envs/deepchem/lib/python3.7/site-packages/deepchem/models/graph_conv.py
print(model.summary())
#結果
こんな感じ。論文を読んで大体概要は把握しているが、DeepChemは多少論文と違う作りになっており、解析はこれから行う。
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 75)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
input_3 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
input_6 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_7 (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
input_8 (InputLayer) [(None, 3)] 0
__________________________________________________________________________________________________
input_9 (InputLayer) [(None, 4)] 0
__________________________________________________________________________________________________
input_10 (InputLayer) [(None, 5)] 0
__________________________________________________________________________________________________
input_11 (InputLayer) [(None, 6)] 0
__________________________________________________________________________________________________
input_12 (InputLayer) [(None, 7)] 0
__________________________________________________________________________________________________
input_13 (InputLayer) [(None, 8)] 0
__________________________________________________________________________________________________
input_14 (InputLayer) [(None, 9)] 0
__________________________________________________________________________________________________
input_15 (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
input_16 (InputLayer) [(None, 11)] 0
__________________________________________________________________________________________________
graph_conv (GraphConv) (None, 64) 102144 input_1[0][0]
input_2[0][0]
input_3[0][0]
input_6[0][0]
input_7[0][0]
input_8[0][0]
input_9[0][0]
input_10[0][0]
input_11[0][0]
input_12[0][0]
input_13[0][0]
input_14[0][0]
input_15[0][0]
input_16[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 64) 256 graph_conv[0][0]
__________________________________________________________________________________________________
graph_pool (GraphPool) (None, 64) 0 batch_normalization[0][0]
input_2[0][0]
input_3[0][0]
input_6[0][0]
input_7[0][0]
input_8[0][0]
input_9[0][0]
input_10[0][0]
input_11[0][0]
input_12[0][0]
input_13[0][0]
input_14[0][0]
input_15[0][0]
input_16[0][0]
__________________________________________________________________________________________________
graph_conv_1 (GraphConv) (None, 64) 87360 graph_pool[0][0]
input_2[0][0]
input_3[0][0]
input_6[0][0]
input_7[0][0]
input_8[0][0]
input_9[0][0]
input_10[0][0]
input_11[0][0]
input_12[0][0]
input_13[0][0]
input_14[0][0]
input_15[0][0]
input_16[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64) 256 graph_conv_1[0][0]
__________________________________________________________________________________________________
graph_pool_1 (GraphPool) (None, 64) 0 batch_normalization_1[0][0]
input_2[0][0]
input_3[0][0]
input_6[0][0]
input_7[0][0]
input_8[0][0]
input_9[0][0]
input_10[0][0]
input_11[0][0]
input_12[0][0]
input_13[0][0]
input_14[0][0]
input_15[0][0]
input_16[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 128) 8320 graph_pool_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128) 512 dense[0][0]
__________________________________________________________________________________________________
graph_gather (GraphGather) (64, 256) 0 batch_normalization_2[0][0]
input_2[0][0]
input_3[0][0]
input_6[0][0]
input_7[0][0]
input_8[0][0]
input_9[0][0]
input_10[0][0]
input_11[0][0]
input_12[0][0]
input_13[0][0]
input_14[0][0]
input_15[0][0]
input_16[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (64, 2) 514 graph_gather[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (64, 1, 2) 0 dense_1[0][0]
__________________________________________________________________________________________________
input_4 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
trim_graph_output (TrimGraphOut (None, 1, 2) 0 reshape[0][0]
input_4[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer) [(None,)] 0
__________________________________________________________________________________________________
softmax (Softmax) (None, 1, 2) 0 trim_graph_output[0][0]
==================================================================================================
Total params: 199,362
Trainable params: 198,850
Non-trainable params: 512
__________________________________________________________________________________________________