LoginSignup
101
63

More than 5 years have passed since last update.

誰でもできる深度推定 ~Depth Map Prediction in TensorFlow from beginning to end~

Last updated at Posted at 2016-12-18

誰でも出来る深度推定 ~Depth Map Prediction in TensorFlow from beginning to end~

これは、TensorFlow Advent Calendar 2016 19日目の記事です。関連記事は目次にまとめられています。

はじめに

TensorFlowは汎用的な機械学習フレームワークですが、特にニューラルネットワークを記述するために便利なAPIをたくさん備えています。
今回はTensorFlowを使って、チュートリアルでは扱っていない構造の新しいネットワークを構築し、学習から推定までをやってみます。わかりやすいよう、解説とコードをセットで進めていきます。

テーマ

深度推定のネットワークをテーマとし、実装の流れがシンプルな以下を選びました。

Depth Map Prediction from a Single Image using a Multi-Scale Deep Network
https://arxiv.org/abs/1406.2283

2014年の論文で、1枚の画像からカメラと物体との距離を推定する深度推定を行うためのアーキテクチャです。
少し古いのですが、この論文を機に深度推定は人気が出た(だと思っているのですが)ので、きっかけとして興味深い論文であると思っています。
今月開催された NIPS2016 でも複数アクセプトされています。

では、簡単にアーキテクチャを見てみましょう。

アーキテクチャ

非常にシンプルな畳込みネットワークの 2 段構成です。
入力は1枚の画像、出力としてピクセル単位で深度を推定した画像を出力します。

qiita-1.png

1段目を Coarse ネットワーク、2段目を Refine ネットワークと呼び、1段ずつ学習します。
Refine ネットワークには出力のエッジを強調する狙いがあるため、後半では畳み込みのstrideを1にしてサイズを保持しています。

早速実装してみます。

実装

TensorFlowによる実装では以下の4つのパートに分けると簡単に整理ができます。

  • データの入出力
  • ニューラルネットワークのアーキテクチャ
  • 誤差関数
  • 確率的な学習ループ

データの入出力

入力データとなる画像と教師画像を準備します。

例えば、NYU Depth Dataset V2 が手頃で良いと思います。
http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html
オリジナルのRGB画像と、深度を計測して画像にしたファイルがセットになっています。

NYU Depth V2 セットはmat形式で提供されているので、8-bit pngに変換しておきます。

mat2png.py

import os
import numpy as np
import h5py
from PIL import Image


def convert_nyu(path):
    print("load dataset: %s" % (path))
    f = h5py.File(path)

    with open('train.csv', 'w') as output:
        for i, (image, depth) in enumerate(zip(f['images'], f['depths'])):
            ra_image = image.transpose(2, 1, 0)
            ra_depth = depth.transpose(1, 0)
            re_depth = (ra_depth/np.max(ra_depth))*255.0
            image_pil = Image.fromarray(np.uint8(ra_image))
            depth_pil = Image.fromarray(np.uint8(re_depth))
            image_name = os.path.join("data", "nyu_datasets", "%05d.jpg" % (i))
            image_pil.save(image_name)
            depth_name = os.path.join("data", "nyu_datasets", "%05d.png" % (i))
            depth_pil.save(depth_name)
            output.write("%s,%s" % (image_name, depth_name))
            output.write("\n")


if __name__ == '__main__':
    current_directory = os.getcwd()
    nyu_path = 'data/nyu_depth_v2_labeled.mat'
    convert_nyu(nyu_path)

準備した画像のリストをCSVファイルとして保存する処理も記述しておきました。

さて、TensorFlowには、自分でミニバッチ生成を記述しなくても良い便利な仕組みが備わっています。
それでもいくつかの選択肢が残されており、その中でも便利な2つの方法があります。

  • ファイルにデータのリストを記述しておく
  • tfrecords形式にデータを変換しておく

tfrecords形式とは、TensorFlowが公式にサポートしているデータの保存形式です。v.0.10から画像の圧縮もサポートされて便利になりました。
しかしtfrecordsにするために一手間必要なので、ここでは先程出力したCSVファイルを使ってミニバッチを作成する処理を行います。

dataset.py
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
from PIL import Image

IMAGE_HEIGHT = 228
IMAGE_WIDTH = 304
TARGET_HEIGHT = 55
TARGET_WIDTH = 74

class DataSet:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def csv_inputs(self, csv_file_path):
        filename_queue = tf.train.string_input_producer([csv_file_path], shuffle=True)
        reader = tf.TextLineReader()
        _, serialized_example = reader.read(filename_queue)
        filename, depth_filename = tf.decode_csv(serialized_example, [["path"], ["annotation"]])
        # input
        jpg = tf.read_file(filename)
        image = tf.image.decode_jpeg(jpg, channels=3)
        image = tf.cast(image, tf.float32)       
        # target
        depth_png = tf.read_file(depth_filename)
        depth = tf.image.decode_png(depth_png, channels=1)
        depth = tf.cast(depth, tf.int64)
        # resize
        image = tf.image.resize_images(image, (IMAGE_HEIGHT, IMAGE_WIDTH))
        depth = tf.image.resize_images(depth, (TARGET_HEIGHT, TARGET_WIDTH))
        invalid_depth = tf.sign(depth)
        # generate batch
        images, depths, invalid_depths = tf.train.batch(
            [image, depth, invalid_depth],
            batch_size=self.batch_size,
            num_threads=4,
            capacity= 50 + 3 * self.batch_size
        )
        return images, depths, invalid_depths


def output_predict(depths, images, output_dir):
    print("output predict into %s" % len(output_dir))
    if not gfile.Exists(output_dir):
        gfile.MakeDirs(output_dir)
    for i, (image, depth) in enumerate(zip(images, depths)):
        pilimg = Image.fromarray(np.uint8(image))
        image_name = "%s/%05d_org.png" % (output_dir, i)
        pilimg.save(image_name)
        depth = depth.transpose(2, 0, 1)
        if np.max(depth) != 0:
            ra_depth = (depth/np.max(depth))*255.0
        else:
            ra_depth = depth*255.0
        depth_pil = Image.fromarray(np.uint8(ra_depth[0]), mode="L")
        depth_name = "%s/%05d.png" % (output_dir, i)
        depth_pil.save(depth_name)

深度の教師として無効な点も同時に計算します。結果の画像を出力するコードも書いておきます。

ニューラルネットワークのアーキテクチャ

TensorFlowでは、最初は以下の点に注意すると良いと思います。

  • スコープを使う
  • 再利用可能なように工夫しておく

これらを意識するだけで、アーキテクチャの構造がコードから汲み取りやすくなります。
まずは、再利用可能なスコープ付きレイヤー関数を定義します。

model_part.py
import tensorflow as tf

TOWER_NAME = 'tower'
UPDATE_OPS_COLLECTION = '_update_ops_'


def _variable_with_weight_decay(name, shape, stddev, wd, trainable=True):
    var = _variable_on_gpu(name, shape, tf.truncated_normal_initializer(stddev=stddev))
    if wd:
        weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var


def _variable_on_gpu(name, shape, initializer):
    var = tf.get_variable(name, shape, initializer=initializer)
    return var


def conv2d(scope_name, inputs, shape, bias_shape, stride, padding='VALID', wd=0.0, reuse=False, trainable=True):
    with tf.variable_scope(scope_name) as scope:
        if reuse is True:
            scope.reuse_variables()
        kernel = _variable_with_weight_decay(
            'weights',
            shape=shape,
            stddev=0.01,
            wd=wd,
            trainable=trainable
        )
        conv = tf.nn.conv2d(inputs, kernel, stride, padding=padding)
        biases = _variable_on_gpu('biases', bias_shape, tf.constant_initializer(0.1))
        bias = tf.nn.bias_add(conv, biases)
        conv_ = tf.nn.relu(bias, name=scope.name)
        return conv_


def fc(scope_name, inputs, shape, bias_shape, wd=0.04, reuse=False, trainable=True):
    with tf.variable_scope(scope_name) as scope:
        if reuse is True:
            scope.reuse_variables()
        flat = tf.reshape(inputs, [-1, shape[0]])
        weights = _variable_with_weight_decay(
            'weights',
            shape,
            stddev=0.01,
            wd=wd,
            trainable=trainable
        )
        biases = _variable_on_gpu('biases', bias_shape, tf.constant_initializer(0.1))
        fc = tf.nn.relu_layer(flat, weights, biases, name=scope.name)
        return fc

基本的なレイヤー(畳込みと全結合層)しか定義していませんが、今回のアーキテクチャの範囲では十分(活性化関数が固定されていることに注意してください)です。次に、アーキテクチャ全体を記述します。

model.py
import tensorflow as tf
import math
from model_part import conv2d
from model_part import fc

def inference(images, reuse=False, trainable=True):
    coarse1_conv = conv2d('coarse1', images, [11, 11, 3, 96], [96], [1, 4, 4, 1], padding='VALID', reuse=reuse, trainable=trainable)
    coarse1 = tf.nn.max_pool(coarse1_conv, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID', name='pool1')
    coarse2_conv = conv2d('coarse2', coarse1, [5, 5, 96, 256], [256], [1, 1, 1, 1], padding='VALID', reuse=reuse, trainable=trainable)
    coarse2 = tf.nn.max_pool(coarse2_conv, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
    coarse3 = conv2d('coarse3', coarse2, [3, 3, 256, 384], [384], [1, 1, 1, 1], padding='VALID', reuse=reuse, trainable=trainable)
    coarse4 = conv2d('coarse4', coarse3, [3, 3, 384, 384], [384], [1, 1, 1, 1], padding='VALID', reuse=reuse, trainable=trainable)
    coarse5 = conv2d('coarse5', coarse4, [3, 3, 384, 256], [256], [1, 1, 1, 1], padding='VALID', reuse=reuse, trainable=trainable)
    coarse6 = fc('coarse6', coarse5, [6*10*256, 4096], [4096], reuse=reuse, trainable=trainable)
    coarse7 = fc('coarse7', coarse6, [4096, 4070], [4070], reuse=reuse, trainable=trainable)
    coarse7_output = tf.reshape(coarse7, [-1, 55, 74, 1])
    return coarse7_output


def inference_refine(images, coarse7_output, keep_conv, reuse=False, trainable=True):
    fine1_conv = conv2d('fine1', images, [9, 9, 3, 63], [63], [1, 2, 2, 1], padding='VALID', reuse=reuse, trainable=trainable)
    fine1 = tf.nn.max_pool(fine1_conv, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='fine_pool1')
    fine1_dropout = tf.nn.dropout(fine1, keep_conv)
    fine2 = tf.concat(3, [fine1_dropout, coarse7_output])
    fine3 = conv2d('fine3', fine2, [5, 5, 64, 64], [64], [1, 1, 1, 1], padding='SAME', reuse=reuse, trainable=trainable)
    fine3_dropout = tf.nn.dropout(fine3, keep_conv)
    fine4 = conv2d('fine3', fine3_dropout, [5, 5, 64, 1], [1], [1, 1, 1, 1], padding='SAME', reuse=reuse, trainable=trainable)
    return fine4

inference に coarseネットワーク、inference_refine に refineネットワークを記述しました。
関数化したことによって、ネットワークの構造が見やすくなりました。

誤差関数

TensorFlowはバージョンを追うごとに誤差関数を実装するためのAPIも便利になっています。

model.py
def loss(logits, depths, invalid_depths):
    logits_flat = tf.reshape(logits, [-1, 55*74])
    depths_flat = tf.reshape(depths, [-1, 55*74])
    invalid_depths_flat = tf.reshape(depths, [-1, 55*74])
    predict = tf.mul(logits_flat, invalid_depths_flat)
    target = tf.mul(depths_flat, invalid_depths_flat)
    d = tf.sub(predict, target)
    square_d = tf.square(d)
    sum_square_d = tf.reduce_sum(square_d, 1)
    sum_d = tf.reduce_sum(d, 1)
    sqare_sum_d = tf.square(sum_d)
    cost = tf.reduce_mean(sum_square_d / 55*74 - FLAGS.si_lambda*sqare_sum_d / math.pow(55*74, 2))
    return cost

loss関数はニューラルネットワークの目標との差を計算する尺度です。TensorFlowはミニバッチを前提としたloss計算に対応していますので、ミニバッチを意識せずに記述することが可能です。教師として無効な画素への対応も忘れずに行います。

確率的な学習ループ

バックプロパゲーションによる微分の評価と勾配法のループを実行するコードを記述します。
はじめて計算グラフにデータを流します。誤差関数までのコードでは、グラフを構築しているだけで計算していません。

普通は別の勾配法によるデバッグが必要ですが、TensorFlowのコードを信頼して記述するとすごく短くてすみます。

今回のネットワークは 2 段階で学習することを思いだします。実現するにはいくつか方法が考えられますが、ここでは理解し易いように愚直にそれぞれの実行コードを分けます。別に、勾配の伝播を止めることでも実現可能でよりスマートに記述できると思います。

スコープを使ったので、以下のように必要なパラメータを選択できます。

task.py
...
        # parameters
        coarse_params = {}
        refine_params = {}
        if REFINE_TRAIN:
            for variable in tf.all_variables():
                variable_name = variable.name
                if variable_name.find("/") < 0 or variable_name.count("/") != 1:
                    continue
                if variable_name.find('coarse') >= 0:
                    coarse_params[variable_name] = variable
                if variable_name.find('fine') >= 0:
                    refine_params[variable_name] = variable
        else:
            for variable in tf.trainable_variables():
                variable_name = variable.name
                if variable_name.find("/") < 0 or variable_name.count("/") != 1:
                    continue
                if variable_name.find('coarse') >= 0:
                    coarse_params[variable_name] = variable
                if variable_name.find('fine') >= 0:
                    refine_params[variable_name] = variable
...

パラメータの選択を含んだコード全体は以下のようになります。

task.py
from datetime import datetime
from tensorflow.python.platform import gfile
import numpy as np
import tensorflow as tf
from dataset import DataSet
from dataset import output_predict
import model
import train_operation as op

MAX_STEPS = 10000000
LOG_DEVICE_PLACEMENT = False
BATCH_SIZE = 5
TRAIN_FILE = "train.csv"
COARSE_DIR = "coarse"
REFINE_DIR = "refine"

REFINE_TRAIN = False
FINE_TUNE = False

def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        dataset = DataSet(BATCH_SIZE)
        images, depths, invalid_depths = dataset.csv_inputs(TRAIN_FILE)

        keep_conv = tf.placeholder(tf.float32)
        keep_hidden = tf.placeholder(tf.float32)

        if REFINE_TRAIN:
            print("refine train.")
            coarse = model.inference(images, keep_conv, trainable=False)
            logits = model.inference_refine(images, coarse, keep_conv, keep_hidden)
        else:
            print("coarse train.")
            logits = model.inference(images, keep_conv, keep_hidden)
        # loss
        loss = model.loss(logits, depths, invalid_depths)
        # train operation
        train_op = op.train(loss, global_step, BATCH_SIZE)
        # init operation
        init_op = tf.initialize_all_variables()

        # Session
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=LOG_DEVICE_PLACEMENT))
        sess.run(init_op)

        # parameters
        coarse_params = {}
        refine_params = {}
        if REFINE_TRAIN:
            for variable in tf.all_variables():
                variable_name = variable.name
                if variable_name.find("/") < 0 or variable_name.count("/") != 1:
                    continue
                if variable_name.find('coarse') >= 0:
                    coarse_params[variable_name] = variable
                if variable_name.find('fine') >= 0:
                    refine_params[variable_name] = variable
        else:
            for variable in tf.trainable_variables():
                variable_name = variable.name
                if variable_name.find("/") < 0 or variable_name.count("/") != 1:
                    continue
                if variable_name.find('coarse') >= 0:
                    coarse_params[variable_name] = variable
                if variable_name.find('fine') >= 0:
                    refine_params[variable_name] = variable
        # define saver
        print coarse_params
        saver_coarse = tf.train.Saver(coarse_params)
        if REFINE_TRAIN:
            saver_refine = tf.train.Saver(refine_params)

        # train
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for step in xrange(MAX_STEPS):
            index = 0
            for i in xrange(1000):
                _, loss_value, logits_val, images_val = sess.run([train_op, loss, logits, images], feed_dict={keep_conv: 0.8, keep_hidden: 0.5})
                if index % 10 == 0:
                    print("%s: %d[epoch]: %d[iteration]: train loss %f" % (datetime.now(), step, index, loss_value))
                    assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
                if index % 500 == 0:
                    if REFINE_TRAIN:
                        output_predict(logits_val, images_val, "data/predict_refine_%05d_%05d" % (step, i))
                    else:
                        output_predict(logits_val, images_val, "data/predict_%05d_%05d" % (step, i))
                index += 1

            if step % 5 == 0 or (step * 1) == MAX_STEPS:
                if REFINE_TRAIN:
                    refine_checkpoint_path = REFINE_DIR + '/model.ckpt'
                    saver_refine.save(sess, refine_checkpoint_path, global_step=step)
                else:
                    coarse_checkpoint_path = COARSE_DIR + '/model.ckpt'
                    saver_coarse.save(sess, coarse_checkpoint_path, global_step=step)
        coord.request_stop()
        coord.join(threads)
        sess.close()

ほとんど完成ですが、coarseとrefineネットワークを順番に学習するためのパラメータロード部分の記述を加えます。

task.py
...
        # fine tune
        if FINE_TUNE:
            coarse_ckpt = tf.train.get_checkpoint_state(COARSE_DIR)
            if coarse_ckpt and coarse_ckpt.model_checkpoint_path:
                print("Pretrained coarse Model Loading.")
                saver_coarse.restore(sess, coarse_ckpt.model_checkpoint_path)
                print("Pretrained coarse Model Restored.")
            else:
                print("No Pretrained coarse Model.")
            refine_ckpt = tf.train.get_checkpoint_state(REFINE_DIR)
            if refine_ckpt and refine_ckpt.model_checkpoint_path:
                print("Pretrained refine Model Loading.")
                saver_refine.restore(sess, refine_ckpt.model_checkpoint_path)
                print("Pretrained refine Model Restored.")
            else:
                print("No Pretrained refine Model.")
...

完成しました。学習してみましょう。

実行

コード全体はgithubから取得して動かすことが出来ます。
https://github.com/MasazI/cnn_depth_tensorflow

学習とその過程で推定結果を出力します。

学習結果

学習が進む過程で、どのように推定できるか確認してみましょう。

学習初期

  • 入力
    initail
  • 推定
    initail_p

学習が進む過程

  • 入力
    interm
  • 推定
    interm_p

少しずつ大まかな近い、遠いがわかるようになってきました。

学習後期

  • 入力
    interm
  • 推定
    interm_p

はっきりとしてきました。

補足

githubのコードはテストケースを含んでいませんが、ほとんど学習と同様に記述可能ですし、メタデータから構築すればコード量を削減できます。
画像1枚に対する推定で、画像中のコンテキストなどを利用しない場合であれば、この枠組で解ける問題も多いでしょう。

一方で、今回紹介したアーキテクチャと同じ枠組みでは解けない問題も多いとともに、テーマに絞って考えてみても、複雑な状況や屋外ではうまく推定できませんから、広く実用的なものにするためには改善が必要です。TensorFlowを使うと、改善に関する本質的な部分に注力できることが多いため、実験段階で使うメリットは十分にあると考えています。

また、今回のコードはGoogleが提供するTensorFlowの学習・推定プラットフォーム Cloud Machine Learning に組み込みやすい構造になっていますので、比較的簡単に試すことが可能だと思います。

おわりに

今回は深度推定する簡単なネットワークを記述してきましたが、TensorFlow はより興味深いニューラルネットワーク、確率分布の近似・推定、強化学習、分散処理などに応用ができるパワフルなフレームワークです。皆さんで楽しんで使っていきましょう。

追記 6th Nov 2018

長らくTensorFlowのバージョンが古いことが原因で、使えないと言ったISSUEをいただいたのですが、Githubのリポジトリにおいては親切に新しいTensorFlowに対応したコードをプルリクエストいただきました。ありがとうございます。
Thank you for your informative PR, Bernhard!

101
63
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
101
63