More than 5 years have passed since last update.

Depth Estimation with a CNN (CNN SLAM #3)


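In this post I describe how CNN SLAM uses a CNN to estimate depth images.

The paper also uses the CNN for semantic segmentation, but here I focus only on depth estimation. Either way, the basic network structure is almost identical, so this doesn't matter much for understanding CNN SLAM.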

The answer is on GitHub

A true godsend.

Just clone the FCRN-DepthPrediction repository and you can run depth estimation!

Recipe
$ git clone https://github.com/iro-cp/FCRN-DepthPrediction.git
$ cd FCRN-DepthPrediction/tensorflow
$ python predict.py <trained model file> <RGB image to estimate depth for>

TensorFlow is required, so install it. I used Python 3.5.3 and TensorFlow 1.2.1.

The trained model is linked from the README of the GitHub repository above. Click the link labeled "Tensorflow Model" to download it, so there is no need to train anything yourself.

You provide the image you want to estimate depth for yourself. For example, given an image like the following...

...you get output like this.

I don't have the ground truth for this image, so I can't evaluate the result quantitatively, but the overall shape looks reasonable.
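Without ground truth there is nothing to compare against, but if you do have ground-truth depth (for example from an RGB-D dataset), a few metrics that are commonly reported in the depth-estimation literature are easy to compute. This is a minimal NumPy sketch; the function name `depth_metrics` and the convention that a depth of 0 marks a missing pixel are my own assumptions, not part of the FCRN repository.

```python
import numpy as np


def depth_metrics(pred, gt, eps=1e-6):
    """Common depth-estimation error metrics (a sketch, not the repo's code).

    pred, gt: arrays of the same shape holding depth values.
    Pixels where gt <= 0 are treated as invalid and ignored (an assumption;
    datasets mark missing depth in different ways).
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    valid = gt > 0
    p, g = np.maximum(pred[valid], eps), gt[valid]

    abs_rel = np.mean(np.abs(p - g) / g)       # mean absolute relative error
    rmse = np.sqrt(np.mean((p - g) ** 2))      # root mean squared error
    ratio = np.maximum(p / g, g / p)
    delta1 = np.mean(ratio < 1.25)             # fraction of pixels within 25% of gt
    return {"abs_rel": abs_rel, "rmse": rmse, "delta1": delta1}


# Usage with made-up numbers (not real results)
gt = np.array([[1.0, 2.0], [4.0, 0.0]])        # 0.0 = missing depth
pred = np.array([[1.1, 2.0], [3.0, 5.0]])
print(depth_metrics(pred, gt))
```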

What if you want to train the network?

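If you also want to train it, you need to pass True as the trainable argument of the Network class in network.py, and add new code that computes the difference from the ground-truth image and backpropagates it.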
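Conceptually, one training step is: predict, take the difference from the ground truth, reduce it to a loss, and push the gradient back to the parameters. The NumPy sketch below runs that loop on a deliberately tiny, hypothetical stand-in model (a per-pixel linear function `pred = w * x + b`, unrelated to FCRN) so that the gradient can be written by hand; in the real network, TensorFlow's `minimize` computes these gradients automatically.

```python
import numpy as np


def l1_loss_and_grads(w, b, x, y):
    """L1 loss of a toy linear 'depth predictor' pred = w * x + b, plus its
    (sub)gradients with respect to w and b."""
    pred = w * x + b
    diff = pred - y                          # difference from the ground truth
    loss = np.mean(np.abs(diff))             # L1 loss, as in the article's code
    g = np.sign(diff) / diff.size            # d loss / d pred (subgradient)
    return loss, np.sum(g * x), np.sum(g)    # chain rule back to w and b


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=(8, 8))       # hypothetical input "image"
y = 3.0 * x + 0.5                            # hypothetical ground-truth depth

w, b, lr = 0.0, 0.0, 0.1
losses = []
for step in range(200):
    loss, gw, gb = l1_loss_and_grads(w, b, x, y)
    losses.append(loss)
    w, b = w - lr * gw, b - lr * gb          # plain gradient-descent update
print(losses[0], losses[-1])
```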

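Personally, I had never seen the coding style used in fcrn.py before, so I wasn't sure how to modify it. I therefore rewrote it in a style that is simpler (at least for me), shown below.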
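For comparison, fcrn.py builds the graph in a chained style, where each layer method reads the previous output implicitly and returns the network object (`self.feed('data').conv(...).batch_normalization(...)`). The toy sketch below contrasts that with the explicit style used in my rewrite. The class and its `conv`/`bn` methods are hypothetical stand-ins that only record strings; they are not the repository's layers.

```python
class ChainedNet:
    """Toy imitation of a 'feed(...).op(...).op(...)' network builder.

    Each op reads the current terminal, stores its output under a name and
    returns self, so calls can be chained.
    """

    def __init__(self, inputs):
        self.layers = dict(inputs)   # layer name -> output
        self.terminal = None

    def feed(self, name):
        self.terminal = self.layers[name]
        return self

    def _op(self, kind, name):
        out = f"{kind}({self.terminal})"
        self.layers[name] = out
        self.terminal = out
        return self

    def conv(self, name):
        return self._op("conv", name)

    def bn(self, name):
        return self._op("bn", name)


# Chained style: the input of each layer is implicit.
net = ChainedNet({"data": "x"})
net.feed("data").conv(name="conv1").bn(name="bn_conv1")
chained = net.layers["bn_conv1"]


# Explicit style: every layer receives its input as an argument.
def conv(t):
    return f"conv({t})"


def bn(t):
    return f"bn({t})"


conv1 = conv("x")
bn_conv1 = bn(conv1)

print(chained, bn_conv1)   # both describe the same graph
```

The explicit style costs more variable names, but every connection (including the shortcut branches of the residual blocks) is visible at the call site.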

network.py
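    # setup() is a member method of the Network class.
    # The training branch at the end uses `tf`, so network.py is assumed to have `import tensorflow as tf`.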

    def setup(self, trainable):
        xs = self.inputs['data']

        conv1 = self.conv(xs, 7, 7, 64, 2, 2, relu=False, name='conv1')
        bn_conv1 = self.batch_normalization(conv1, relu=True, name='bn_conv1')
        pool1 = self.max_pool(bn_conv1, 3, 3, 2, 2, name='pool1')
        res2a_branch1 = self.conv(pool1, 1, 1, 256, 1, 1, biased=False, relu=False, name='res2a_branch1')
        bn2a_branch1 = self.batch_normalization(res2a_branch1, name='bn2a_branch1')

        res2a_branch2a = self.conv(pool1, 1, 1, 64, 1, 1, biased=False, relu=False, name='res2a_branch2a')
        bn2a_branch2a = self.batch_normalization(res2a_branch2a, relu=True, name='bn2a_branch2a')
        res2a_branch2b = self.conv(bn2a_branch2a, 3, 3, 64, 1, 1, biased=False, relu=False, name='res2a_branch2b')
        bn2a_branch2b = self.batch_normalization(res2a_branch2b, relu=True, name='bn2a_branch2b')
        res2a_branch2c = self.conv(bn2a_branch2b, 1, 1, 256, 1, 1, biased=False, relu=False, name='res2a_branch2c')
        bn2a_branch2c = self.batch_normalization(res2a_branch2c, name='bn2a_branch2c')

        res2a = self.add((bn2a_branch1, bn2a_branch2c), name='res2a')
        res2a_relu = self.relu(res2a, name='res2a_relu')
        res2b_branch2a = self.conv(res2a_relu, 1, 1, 64, 1, 1, biased=False, relu=False, name='res2b_branch2a')
        bn2b_branch2a = self.batch_normalization(res2b_branch2a, relu=True, name='bn2b_branch2a')
        res2b_branch2b = self.conv(bn2b_branch2a, 3, 3, 64, 1, 1, biased=False, relu=False, name='res2b_branch2b')
        bn2b_branch2b = self.batch_normalization(res2b_branch2b, relu=True, name='bn2b_branch2b')
        res2b_branch2c = self.conv(bn2b_branch2b, 1, 1, 256, 1, 1, biased=False, relu=False, name='res2b_branch2c')
        bn2b_branch2c = self.batch_normalization(res2b_branch2c, name='bn2b_branch2c')

        res2b = self.add((res2a_relu, bn2b_branch2c), name='res2b')
        res2b_relu = self.relu(res2b, name='res2b_relu')
        res2c_branch2a = self.conv(res2b_relu, 1, 1, 64, 1, 1, biased=False, relu=False, name='res2c_branch2a')
        bn2c_branch2a = self.batch_normalization(res2c_branch2a, relu=True, name='bn2c_branch2a')
        res2c_branch2b = self.conv(bn2c_branch2a, 3, 3, 64, 1, 1, biased=False, relu=False, name='res2c_branch2b')
        bn2c_branch2b = self.batch_normalization(res2c_branch2b, relu=True, name='bn2c_branch2b')
        res2c_branch2c = self.conv(bn2c_branch2b, 1, 1, 256, 1, 1, biased=False, relu=False, name='res2c_branch2c')
        bn2c_branch2c = self.batch_normalization(res2c_branch2c, name='bn2c_branch2c')

        res2c = self.add((res2b_relu, bn2c_branch2c), name='res2c')
        res2c_relu = self.relu(res2c, name='res2c_relu')
        res3a_branch1 = self.conv(res2c_relu, 1, 1, 512, 2, 2, biased=False, relu=False, name='res3a_branch1')
        bn3a_branch1 = self.batch_normalization(res3a_branch1, name='bn3a_branch1')

        res3a_branch2a = self.conv(res2c_relu, 1, 1, 128, 2, 2, biased=False, relu=False, name='res3a_branch2a')
        bn3a_branch2a = self.batch_normalization(res3a_branch2a, relu=True, name='bn3a_branch2a')
        res3a_branch2b = self.conv(bn3a_branch2a, 3, 3, 128, 1, 1, biased=False, relu=False, name='res3a_branch2b')
        bn3a_branch2b = self.batch_normalization(res3a_branch2b, relu=True, name='bn3a_branch2b')
        res3a_branch2c = self.conv(bn3a_branch2b, 1, 1, 512, 1, 1, biased=False, relu=False, name='res3a_branch2c')
        bn3a_branch2c = self.batch_normalization(res3a_branch2c, name='bn3a_branch2c')

        res3a = self.add((bn3a_branch1, bn3a_branch2c), name='res3a')
        res3a_relu = self.relu(res3a, name='res3a_relu')
        res3b_branch2a = self.conv(res3a_relu, 1, 1, 128, 1, 1, biased=False, relu=False, name='res3b_branch2a')
        bn3b_branch2a = self.batch_normalization(res3b_branch2a, relu=True, name='bn3b_branch2a')
        res3b_branch2b = self.conv(bn3b_branch2a, 3, 3, 128, 1, 1, biased=False, relu=False, name='res3b_branch2b')
        bn3b_branch2b = self.batch_normalization(res3b_branch2b, relu=True, name='bn3b_branch2b')
        res3b_branch2c = self.conv(bn3b_branch2b, 1, 1, 512, 1, 1, biased=False, relu=False, name='res3b_branch2c')
        bn3b_branch2c = self.batch_normalization(res3b_branch2c, name='bn3b_branch2c')

        res3b = self.add((res3a_relu, bn3b_branch2c), name='res3b')
        res3b_relu = self.relu(res3b, name='res3b_relu')
        res3c_branch2a = self.conv(res3b_relu, 1, 1, 128, 1, 1, biased=False, relu=False, name='res3c_branch2a')
        bn3c_branch2a = self.batch_normalization(res3c_branch2a, relu=True, name='bn3c_branch2a')
        res3c_branch2b = self.conv(bn3c_branch2a, 3, 3, 128, 1, 1, biased=False, relu=False, name='res3c_branch2b')
        bn3c_branch2b = self.batch_normalization(res3c_branch2b, relu=True, name='bn3c_branch2b')
        res3c_branch2c = self.conv(bn3c_branch2b, 1, 1, 512, 1, 1, biased=False, relu=False, name='res3c_branch2c')
        bn3c_branch2c = self.batch_normalization(res3c_branch2c, name='bn3c_branch2c')

        res3c = self.add((res3b_relu, bn3c_branch2c), name='res3c')
        res3c_relu = self.relu(res3c, name='res3c_relu')
        res3d_branch2a = self.conv(res3c_relu, 1, 1, 128, 1, 1, biased=False, relu=False, name='res3d_branch2a')
        bn3d_branch2a = self.batch_normalization(res3d_branch2a, relu=True, name='bn3d_branch2a')
        res3d_branch2b = self.conv(bn3d_branch2a, 3, 3, 128, 1, 1, biased=False, relu=False, name='res3d_branch2b')
        bn3d_branch2b = self.batch_normalization(res3d_branch2b, relu=True, name='bn3d_branch2b')
        res3d_branch2c = self.conv(bn3d_branch2b, 1, 1, 512, 1, 1, biased=False, relu=False, name='res3d_branch2c')
        bn3d_branch2c = self.batch_normalization(res3d_branch2c, name='bn3d_branch2c')

        res3d = self.add((res3c_relu, bn3d_branch2c), name='res3d')
        res3d_relu = self.relu(res3d, name='res3d_relu')
        res4a_branch1 = self.conv(res3d_relu, 1, 1, 1024, 2, 2, biased=False, relu=False, name='res4a_branch1')
        bn4a_branch1 = self.batch_normalization(res4a_branch1, name='bn4a_branch1')

        res4a_branch2a = self.conv(res3d_relu, 1, 1, 256, 2, 2, biased=False, relu=False, name='res4a_branch2a')
        bn4a_branch2a = self.batch_normalization(res4a_branch2a, relu=True, name='bn4a_branch2a')
        res4a_branch2b = self.conv(bn4a_branch2a, 3, 3, 256, 1, 1, biased=False, relu=False, name='res4a_branch2b')
        bn4a_branch2b = self.batch_normalization(res4a_branch2b, relu=True, name='bn4a_branch2b')
        res4a_branch2c = self.conv(bn4a_branch2b, 1, 1, 1024, 1, 1, biased=False, relu=False, name='res4a_branch2c')
        bn4a_branch2c = self.batch_normalization(res4a_branch2c, name='bn4a_branch2c')

        res4a = self.add((bn4a_branch1, bn4a_branch2c), name='res4a')
        res4a_relu = self.relu(res4a, name='res4a_relu')
        res4b_branch2a = self.conv(res4a_relu, 1, 1, 256, 1, 1, biased=False, relu=False, name='res4b_branch2a')
        bn4b_branch2a = self.batch_normalization(res4b_branch2a, relu=True, name='bn4b_branch2a')
        res4b_branch2b = self.conv(bn4b_branch2a, 3, 3, 256, 1, 1, biased=False, relu=False, name='res4b_branch2b')
        bn4b_branch2b = self.batch_normalization(res4b_branch2b, relu=True, name='bn4b_branch2b')
        res4b_branch2c = self.conv(bn4b_branch2b, 1, 1, 1024, 1, 1, biased=False, relu=False, name='res4b_branch2c')
        bn4b_branch2c = self.batch_normalization(res4b_branch2c, name='bn4b_branch2c')

        res4b = self.add((res4a_relu, bn4b_branch2c), name='res4b')
        res4b_relu = self.relu(res4b, name='res4b_relu')
        res4c_branch2a = self.conv(res4b_relu, 1, 1, 256, 1, 1, biased=False, relu=False, name='res4c_branch2a')
        bn4c_branch2a = self.batch_normalization(res4c_branch2a, relu=True, name='bn4c_branch2a')
        res4c_branch2b = self.conv(bn4c_branch2a, 3, 3, 256, 1, 1, biased=False, relu=False, name='res4c_branch2b')
        bn4c_branch2b = self.batch_normalization(res4c_branch2b, relu=True, name='bn4c_branch2b')
        res4c_branch2c = self.conv(bn4c_branch2b, 1, 1, 1024, 1, 1, biased=False, relu=False, name='res4c_branch2c')
        bn4c_branch2c = self.batch_normalization(res4c_branch2c, name='bn4c_branch2c')

        res4c = self.add((res4b_relu, bn4c_branch2c), name='res4c')
        res4c_relu = self.relu(res4c, name='res4c_relu')
        res4d_branch2a = self.conv(res4c_relu, 1, 1, 256, 1, 1, biased=False, relu=False, name='res4d_branch2a')
        bn4d_branch2a = self.batch_normalization(res4d_branch2a, relu=True, name='bn4d_branch2a')
        res4d_branch2b = self.conv(bn4d_branch2a, 3, 3, 256, 1, 1, biased=False, relu=False, name='res4d_branch2b')
        bn4d_branch2b = self.batch_normalization(res4d_branch2b, relu=True, name='bn4d_branch2b')
        res4d_branch2c = self.conv(bn4d_branch2b, 1, 1, 1024, 1, 1, biased=False, relu=False, name='res4d_branch2c')
        bn4d_branch2c = self.batch_normalization(res4d_branch2c, name='bn4d_branch2c')

        res4d = self.add((res4c_relu, bn4d_branch2c), name='res4d')
        res4d_relu = self.relu(res4d, name='res4d_relu')
        res4e_branch2a = self.conv(res4d_relu, 1, 1, 256, 1, 1, biased=False, relu=False, name='res4e_branch2a')
        bn4e_branch2a = self.batch_normalization(res4e_branch2a, relu=True, name='bn4e_branch2a')
        res4e_branch2b = self.conv(bn4e_branch2a, 3, 3, 256, 1, 1, biased=False, relu=False, name='res4e_branch2b')
        bn4e_branch2b = self.batch_normalization(res4e_branch2b, relu=True, name='bn4e_branch2b')
        res4e_branch2c = self.conv(bn4e_branch2b, 1, 1, 1024, 1, 1, biased=False, relu=False, name='res4e_branch2c')
        bn4e_branch2c = self.batch_normalization(res4e_branch2c, name='bn4e_branch2c')

        res4e = self.add((res4d_relu, bn4e_branch2c), name='res4e')
        res4e_relu = self.relu(res4e, name='res4e_relu')
        res4f_branch2a = self.conv(res4e_relu, 1, 1, 256, 1, 1, biased=False, relu=False, name='res4f_branch2a')
        bn4f_branch2a = self.batch_normalization(res4f_branch2a, relu=True, name='bn4f_branch2a')
        res4f_branch2b = self.conv(bn4f_branch2a, 3, 3, 256, 1, 1, biased=False, relu=False, name='res4f_branch2b')
        bn4f_branch2b = self.batch_normalization(res4f_branch2b, relu=True, name='bn4f_branch2b')
        res4f_branch2c = self.conv(bn4f_branch2b, 1, 1, 1024, 1, 1, biased=False, relu=False, name='res4f_branch2c')
        bn4f_branch2c = self.batch_normalization(res4f_branch2c, name='bn4f_branch2c')

        res4f = self.add((res4e_relu, bn4f_branch2c), name='res4f')
        res4f_relu = self.relu(res4f, name='res4f_relu')
        res5a_branch1 = self.conv(res4f_relu, 1, 1, 2048, 2, 2, biased=False, relu=False, name='res5a_branch1')
        bn5a_branch1 = self.batch_normalization(res5a_branch1, name='bn5a_branch1')

        res5a_branch2a = self.conv(res4f_relu, 1, 1, 512, 2, 2, biased=False, relu=False, name='res5a_branch2a')
        bn5a_branch2a = self.batch_normalization(res5a_branch2a, relu=True, name='bn5a_branch2a')
        res5a_branch2b = self.conv(bn5a_branch2a, 3, 3, 512, 1, 1, biased=False, relu=False, name='res5a_branch2b')
        bn5a_branch2b = self.batch_normalization(res5a_branch2b, relu=True, name='bn5a_branch2b')
        res5a_branch2c = self.conv(bn5a_branch2b, 1, 1, 2048, 1, 1, biased=False, relu=False, name='res5a_branch2c')
        bn5a_branch2c = self.batch_normalization(res5a_branch2c, name='bn5a_branch2c')

        res5a = self.add((bn5a_branch1, bn5a_branch2c), name='res5a')
        res5a_relu = self.relu(res5a, name='res5a_relu')
        res5b_branch2a = self.conv(res5a_relu, 1, 1, 512, 1, 1, biased=False, relu=False, name='res5b_branch2a')
        bn5b_branch2a = self.batch_normalization(res5b_branch2a, relu=True, name='bn5b_branch2a')
        res5b_branch2b = self.conv(bn5b_branch2a, 3, 3, 512, 1, 1, biased=False, relu=False, name='res5b_branch2b')
        bn5b_branch2b = self.batch_normalization(res5b_branch2b, relu=True, name='bn5b_branch2b')
        res5b_branch2c = self.conv(bn5b_branch2b, 1, 1, 2048, 1, 1, biased=False, relu=False, name='res5b_branch2c')
        bn5b_branch2c = self.batch_normalization(res5b_branch2c, name='bn5b_branch2c')

        res5b = self.add((res5a_relu, bn5b_branch2c), name='res5b')
        res5b_relu = self.relu(res5b, name='res5b_relu')
        res5c_branch2a = self.conv(res5b_relu, 1, 1, 512, 1, 1, biased=False, relu=False, name='res5c_branch2a')
        bn5c_branch2a = self.batch_normalization(res5c_branch2a, relu=True, name='bn5c_branch2a')
        res5c_branch2b = self.conv(bn5c_branch2a, 3, 3, 512, 1, 1, biased=False, relu=False, name='res5c_branch2b')
        bn5c_branch2b = self.batch_normalization(res5c_branch2b, relu=True, name='bn5c_branch2b')
        res5c_branch2c = self.conv(bn5c_branch2b, 1, 1, 2048, 1, 1, biased=False, relu=False, name='res5c_branch2c')
        bn5c_branch2c = self.batch_normalization(res5c_branch2c, name='bn5c_branch2c')

        res5c = self.add((res5b_relu, bn5c_branch2c), name='res5c')
        res5c_relu = self.relu(res5c, name='res5c_relu')
        layer1 = self.conv(res5c_relu, 1, 1, 1024, 1, 1, biased=True, relu=False, name='layer1')
        layer1_BN = self.batch_normalization(layer1, relu=False, name='layer1_BN')
        layer2 = self.up_project(layer1_BN, [3, 3, 1024, 512], id='2x', stride=1, BN=True)
        layer3 = self.up_project(layer2, [3, 3, 512, 256], id='4x', stride=1, BN=True)
        layer4 = self.up_project(layer3, [3, 3, 256, 128], id='8x', stride=1, BN=True)
        layer5 = self.up_project(layer4, [3, 3, 128, 64], id='16x', stride=1, BN=True)
        layer5_drop = self.dropout(layer5, name='drop', keep_prob=1.)
        self.predict = self.conv(layer5_drop, 3, 3, 1, 1, 1, name='ConvPred')

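        if trainable:
            # Ground-truth depth placeholder; it should have the same shape as self.predict
            ts = self.inputs['true']
            # L1 loss: mean absolute difference between predicted and true depth
            differ = tf.subtract(x=self.predict, y=ts)
            abs_differ = tf.abs(differ)
            self.loss = tf.reduce_mean(abs_differ, name='loss')
            # Plain gradient descent with learning rate 0.001 (TF1 API)
            self.train_step = tf.train.GradientDescentOptimizer(0.001).minimize(self.loss)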

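self.inputs is a dict. When trainable == True, it holds a placeholder for the input image ('data') and a placeholder for the ground-truth image ('true'); when it is False, only the input-image placeholder is needed for the program to run correctly.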
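Because the loss subtracts `ts` from `self.predict`, the ground-truth image fed into `self.inputs['true']` has to match the prediction's height and width, which is smaller than the input's. The sketch below traces the spatial size through the layers of `setup()`. It assumes 'SAME' padding (output = ceil(input / stride)) and that each `up_project` doubles the resolution; both are my assumptions about the repository's helpers. With a 228×304 input (the size `predict.py` resizes images to, as far as I know) it yields a 128×160 prediction.

```python
import math

# (stage, number of bottleneck blocks, output channels, stride of first block),
# read off setup(): res2a..c, res3a..d, res4a..f, res5a..c
STAGES = [("res2", 3, 256, 1), ("res3", 4, 512, 2), ("res4", 6, 1024, 2), ("res5", 3, 2048, 2)]


def same_out(size, stride):
    # Output length of a conv/pool with 'SAME' padding (assumed here).
    return math.ceil(size / stride)


def trace_shapes(h, w):
    shapes = []
    h, w = same_out(h, 2), same_out(w, 2)          # conv1, stride 2
    h, w = same_out(h, 2), same_out(w, 2)          # pool1, stride 2
    shapes.append(("pool1", h, w, 64))
    for stage, n_blocks, ch, stride in STAGES:
        for i in range(n_blocks):
            s = stride if i == 0 else 1            # only block 'a' downsamples
            h, w = same_out(h, s), same_out(w, s)
            shapes.append((f"{stage}{'abcdef'[i]}", h, w, ch))
    shapes.append(("layer1", h, w, 1024))          # 1x1 conv, stride 1
    for tag, ch in [("2x", 512), ("4x", 256), ("8x", 128), ("16x", 64)]:
        h, w = 2 * h, 2 * w                        # each up_project assumed to double H and W
        shapes.append((f"up_{tag}", h, w, ch))
    shapes.append(("ConvPred", h, w, 1))           # 3x3 conv, stride 1, one depth channel
    return shapes


for row in trace_shapes(228, 304):
    print(row)
```

So, under these assumptions, ground-truth depth maps would need to be resized to 128×160 before being fed to the 'true' placeholder.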

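This is still under investigation, but the choice of loss function seems to affect estimation accuracy. For additional training, I'd like to use whichever loss function gives the best accuracy.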
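One candidate worth trying: the FCRN paper itself ("Deeper Depth Prediction with Fully Convolutional Residual Networks", Laina et al., 2016) trains with the reverse Huber (berHu) loss rather than plain L1. Below is a NumPy sketch of both for comparison. It is my reading of the paper's definition, not code from the repository; porting it into the TensorFlow graph (for example with `tf.where`) would be a separate step.

```python
import numpy as np


def l1_loss(pred, gt):
    # The loss used in the setup() code above: mean absolute error.
    return np.mean(np.abs(pred - gt))


def berhu_loss(pred, gt, c_ratio=0.2):
    """Reverse Huber (berHu) loss.

    Residuals up to c are penalised linearly (like L1), larger ones
    quadratically. c = c_ratio * max |residual| over the batch; the paper
    uses 1/5.
    """
    x = np.abs(pred - gt)
    c = c_ratio * np.max(x)
    if c == 0:                      # perfect prediction: avoid dividing by zero
        return 0.0
    quad = (x ** 2 + c ** 2) / (2 * c)
    return np.mean(np.where(x <= c, x, quad))


# Usage with made-up values (not real results)
gt = np.array([1.0, 2.0, 3.0])
pred = np.array([1.1, 2.2, 4.0])
print(l1_loss(pred, gt), berhu_loss(pred, gt))
```

Since the two branches meet at |x| = c, the berHu loss is continuous; it behaves like L1 for small residuals while penalising large errors more heavily.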
