1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

kerasで事前学習済みモデルをTimeDistributed()と併用する

Last updated at Posted at 2017-09-18

このポストでは、kerasの使い方でも少しアドバンストな、以下の要素をカバーします:

  • 動画の各時点の処理には、事前学習済みのモデルを使用
  • TimeDistributed()で、動画を処理
  • 事前学習済みモデルは多出力となっているものを1つにまとめることで、TimeDistributed()に接続可能なようにする

事前学習済みモデルの使い方

Kerasで学習済みモデルを提供してくれている方はたくさんいらっしゃいます。ここでは例として、
https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation
のモデルを使ってみます。事前学習済みモデルも同ページのリンクから取得出来、model/model.h5として保存しておきます。

モデルを含むモジュールの定義

まず、学習済みモデルのアーキテクチャをひとつのモジュール(deepPose.py)として定義します。

np_branch1 = 38
np_branch2 = 19
stages = 6
graph = None

def vgg_block(...):
    ...
def stage1_block(...):
    ...
def stageT_block(...):
    ...

def deepPose(input_shape=None,weights_path=None):
    global graph

    img_input = Input(shape=input_shape,name="img_input")

    # VGG
    stage0_out = vgg_block(img_input)

    # stage 1
    stage1_branch1_out = stage1_block(stage0_out, np_branch1, 1)
    stage1_branch2_out = stage1_block(stage0_out, np_branch2, 2)
    x = Concatenate()([stage1_branch1_out, stage1_branch2_out, stage0_out])

    # stage t >= 2
    for sn in range(2, stages + 1): 
        stageT_branch1_out = stageT_block(x, np_branch1, sn, 1)
        stageT_branch2_out = stageT_block(x, np_branch2, sn, 2)
        if (sn < stages):
            x = Concatenate()([stageT_branch1_out, stageT_branch2_out, stage0_out])

    model = Model(inputs=img_input, outputs=[stageT_branch1_out, stageT_branch2_out, stage0_out])

    if weights_path:
        print "pre-trained parameter being loaded from %s"%weights_path
        model.load_weights(weights_path)
        print "...done"
    graph = tf.get_default_graph() # もし、fit_generator()の中などのでバッチを作る際、一瞬グラフを切り替えないと行けないことがある。その際に利用する

    return model

特に注意するべきなのは以下の点です

全体としての構造

最終的に、Model()として定義を行い、それをreturnする必要があります。そのため、Inputを定め、最後にアウトプットとなるレイヤーまで作成を行います。

パラメータの読み込み

パラメータの読み込みは、モデル定義の中で行うのが良いと思います。これを、もっと後から読み込もうと思うと、おそらくモデルと保存されたパラメータの構造が異なってしまうため、うまく読み込めません。名前も変わってくるだろうと思いますので、モデルの中で読み込んでしまうことが推奨されます。なお、kerasで提供しているデフォルトのvggモジュールなども、同様にモデル定義の直後での読み込みになっています。

保存したグラフの使い方

中にコメントがありますが、get_default_graph()を保存しておく必要が出てくる場合があります。一番あるのは、fit_generator()の中でデータをロードする際、別のNNで前処理をかけようとする場合です。

上のサンプルですと、これは人間の動きを検出して特徴量へ変換することをしてくれるので、場合によっては、前処理に含めたいこともあるかと思います。そのような場合は、グラフを切り替えないとプログラムが止まる問題に突き当たります。対処法は、以下のように、as_default()で一時的にグラフを切り替えることです:

with graph.as_default():
    batchY = premodel.predict(np.reshape(batchX,(-1,self.sizeY,self.sizeX,self.nColor)))

モジュールの使い方

ここが一番、kerasの特徴的なところになります。

dp = deepPose(input_shape=input_shape,weights_path="model/model.h5")

inX1 = Input(shape=(self.sizeY,self.sizeX,self.nColor),name="inX1")
r1,r2,r3 = dp(inX1)

inX2 = Input(shape=(self.sizeY,self.sizeX,self.nColor),name="inX2")
l1,l2,l3 = dp(inX2)

Functional API特有の「2段階」での定義の仕方

上の例を見ますと、以下のようにしてテンソルが作られていることがわかります

  • 「関数」を定義する。この際に、設定項目やパラメータの読み込みに必要な情報を定義してしまう
  • 定義された「関数」を元に、テンソル変換を定義する

この2段階の考え方が、まさにfunctional APIのやり方になります。若干煩雑かもしれませんが、以下のようなメリットも生んでいます

同一「関数」の使い回しによる、パラメータ共有

上の例だと、同じdpというモデルが2回別のテンソルに適用されています。この2つの「関数」は、パラメータが共有されるように内部処理されるという性質があり、様々な場面で活躍します。例えば、Siamese NNなどを構築する際には、まさに上記のように定義を行うことで実現されます。

TimeDistributed()の使い方

TimeDistributed()は、上記のパラメータ共有した関数を、時間軸方向に適用するための関数です。例えば、

h = Conv3D(self.nActivities,(5,3,3), activation="relu",padding="same")(h)
h = TimeDistributed(GlobalMaxPooling2D())(h)
h = GlobalAveragePooling1D()(h)

などが代表的な使い方です。
今回のdeepPoseに関して使ってみると以下のようになります。


inX = Input(shape=(self.nLength,self.sizeY,self.sizeX,self.nColor),name="inX")
h = TimeDistributed(dp,input_shape=K.int_shape(inX))(inX)

などとして使うことができます。

尚、dpの出力は3つのテンソルです。その場合、上のようなコードはエラーを出して止まります。なぜなら、TimeDistributed()はあくまで1つのテンソルが出力されるのを期待するからです。そのためには、以下のように、もうひと手間をかけてテンソルの出力をマージしてあげることが必要になります。

新しい関数を作って、多出力をまとめる

すでに新しい関数を作る方法を学んでいますので、ここではその手法に従って、3つのテンソルを1つにまとめたものを定義しましょう。

inX = Input(shape=(self.nLength,self.sizeY,self.sizeX,self.nColor),name="inX")

def merged_dp(input_shape):
    inX1 = Input(shape=input_shape)
    dp = deepPose(input_shape=input_shape,weights_path="model/model.h5")
    r1,r2,r3 = dp(inX1)
    merged = Concatenate()([r1,r2,r3])
    return Model(inputs=inX1,outputs=merged)

mydp = merged_dp(input_shape = K.int_shape(inX)[2:])

h = TimeDistributed(mydp,input_shape=K.int_shape(inX))(inX)
```

以上のようにすることで共通のパラメータをコピーした```deepPose```を各時間に適用した後他のNNへと接続することができるようになります
1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?