LoginSignup
19
17

More than 3 years have passed since last update.

単眼深度推定モデル MiDaS の解説と SageMaker へのデプロイ

Last updated at Posted at 2021-01-19

TL;DR

  • 単眼深度推定モデル MiDaS の論文を少し解説します。
  • MiDaS の PyTorch モデルを SageMaker でデプロイして試してみました。
  • 推定した深度から法線(傾き)も計算してみました。
  • SageMaker デプロイから法線推定までのコードはこちらです。

MiDaS の概要

論文のタイトルのとおりなんですが、複数種のデータセットで学習された、Zero-shot (Fine-tuning なし) で使える単眼深度推定モデルです。MiDaS v2.1は10個のデータセット (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS) で学習されています。これまでだと、自動運転などに向けた屋外の深度推定や、SLAM などに向けた屋内の深度推定といった用途ごとにモデルを作っているケースがありますが、MiDaS だと1つで完結するので便利です。

R. Ranftl, K. Lasinger, D. Hafner, K. Schindler, V. Koltun,
Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer,
TPAMI 2020, https://arxiv.org/abs/1907.01341

実際、PyTorch の MiDas のページ (https://pytorch.org/hub/intelisl_midas_v2/) をみるとこんな図が示されています。確かに屋内・屋外問わず対応していて便利そうです!しかも、Zero-shot、つまりは Fine-tuning なしで、いきなりこうした結果をえることができます。

スクリーンショット 2021-01-19 11.34.45.png

MiDaS の論文解説

特にこの論文の本編である3章から5章までを解説したいと思います。これらをざっくりと説明すると

  • 3章: 既存のデータセットもたくさんあって、そもそも深度の測り方がぜんぜん違うよ
  • 4章: 3D Movies を加えてデータを増やしたいけど、正確な深度を得られないよ
  • 5章:上記の多様なデータに対して学習するため、ロス関数と最適化を工夫するよ といったものです。

3章 既存データについて

Computer Vision をやっている人には馴染みがあるかもしれませんが、深度推定のデータセットは結構バラバラです。

  • 環境(屋内、屋外)や対象(静止物体、移動物体)
  • 深度の数値的な特徴(Sparse/Dense, 絶対/相対距離)
  • 測定・アノテーションに依存する精度(レーザー測距、ToF: Time-of-Flight, SfM, Stereo, 人手のアノテーション、合成データ)

例えばですが、人間は2つの目を使って奥行きを検知しますが、それは相対的なものです。物に手を伸ばすときに、物と手が近いかどうかを認識して、物を手に持つことができます。ただ、そのとき、物までの距離が何cmとかはわからないです。このアイデアを使った Stereo のデータセットは、データが相対的な距離として示されます。一方で、レーザー測距だと絶対距離を測ることができます。
深度データは精度も課題です。例えば、光をあてて距離を測る TOF だと、光を吸収する黒色の物体に対して、正しい距離を測るのは難しいです。

4章 3D Movies について

3D Movies は、左目・右目用の映像が同時に投影されて、専用のメガネをかけてフィルターして3次元っぽく見えるというものですね。つまりは、上で述べたようなステレオの原理で3次元ぽく見えるようにしている動画です。左目・右目用の映像は、奥行きを見せるために若干ずれています。このずれた画像に対して、Stereo matching を使うことで、画像のどこがどの程度ずれているのかというを知ることができます。このずれ (視差, Disparity) は相対的な深さに変換でき、視差が大きいほど近いことを表します。例えば、目の近くにものをおいて、右目を閉じて、左目を閉じて、とすると場所が大きく変わって見えますよね。

上記のようにStereo Matching を使えば、3D Movies からそのまま相対的な深度を得られそうですが、論文によるとそうではないようです。理由として、3D Movies は見た目の良さを重視して、視差を調整しているようなのです。具体的には、視差のばらつきが大きくならないよう、一定の範囲に抑えているようです。

正しいDisparityを利用するため、この論文では Optical Flow を使っています。Optical Flow は、2枚のずれた画像(左画像と右画像と呼びます)において、左画像における任意のピクセルが右画像のどのピクセルに対応するかを特定し、その差を計算します。Disparity を計算する Stereo Matching と似ていますが、Stereo Matching は差 (Disparity) が正であることを仮定しますが、Optical Flow では差 (Flow) が負であっても扱えるとのことです。そこで、左画像から右画像への差$D_{L\rightarrow R}$と、右画像から左画像への差$D_{R\rightarrow L}$を、それぞれ Optical Flow で計算して、その差が大きいものはデータから除外しています (Left-Right Consistency Check)。てっきり Disparity をチェックするのかと思ったのですが、"We retain the horizontal component of the flow as a proxy for disparity." とあるので、Disparity の代わりに Flow でチェックするってことですね。
最後にダメ押し的に、Semantic Segmentationを使って空を検出して、とても遠くにある空の Disparity を小さくしています。

5章 多様なデータに対する学習

この論文では、データセットの多様性の中で、特に以下の3つの課題にアプローチしています。

  1. 深度の表現方法の違い (Direct vs Inverse Depth)
  2. スケールの曖昧性(スケールが与えられていないデータセットがある)
  3. シフトの曖昧性(映像の後処理で、シフトと呼ぶ定数のDisparityが与えられている)

これに対していくつかのロスと最適化方法を提案しています。

Scale- and shift-invariant loss

スケールとシフトが未知の場合に備えて、スケールとシフトを推定・調整しながら計算するロスです。
まず普通のロスだと推定した Disparity $d$と正解のDisparity $d^* $の差になりますね。しかし、それぞれスケールやシフトを補正しないといけない可能性があるので、補正済みのものを$\hat{d}, \hat{d^*}$として表しておき、その差にもとづくロス関数を以下のように定義しておきます。補正をどうするかは後で述べます。関数 $\rho$ は適当なロス関数で、$M$ はピクセル数です。

{\mathcal{L}}_{ssi}(\hat{d}, \hat{d^*}) = \frac{1}{2M}\sum_{i=1}^M \rho(\hat{d_i} - \hat{d^*_i})

では、次に補正をどうするか考えましょう。スケール(つまり定数倍)とシフト(つまり定数加算)を考えればよいので、それぞれの未知変数を $s, t$ とおいて、以下のような最適化問題を解けば、$s, t$ が求まります。$d_i, d_i^* $は既知なので、いわゆる最小二乗法で解けますね。

 (s,t) = \mathrm{arg}\min_{s,t}\sum_{i=1}^M(sd_i + t-d_i^*)^2

その上で以下のような補正を行います。

 \hat{d} = s d + t, \hat{d^*}=d^*

上記の方法では、正解の $d^* $ は補正していません。実際には $d^* $には不正確な値や外れ値が入っているので、$d^* $をそのまま使うのも良くない可能性がありますし、これをベースにした補正済みの推定値$\hat{d}$も正しくない可能性があります。そこで論文ではヒューリスティックな方法も提案しています。

t(d) = \mathrm{median}(d), s(d) = \frac{1}{M}\sum_{i=1}^M|d-t(d)|

つまり、シフトは画像内のdisparity の中央値、スケールはシフトからの差分の絶対値の平均です。この $s,t$ を使ったシフト・スケールの変換を $d, d^* $の両方に対して適用して、もとのロス関数を計算する方法を提案しています。ロス関数はL2ロスだと外れ値に弱いので、L1ロスを利用しています。

Related loss functions

ここでは既存のロス関数のなかで、関連するロス関数を取り上げています。例えば、${\mathcal{L}}_{silog}(z, z^* ) $は、スケールを考慮したロスを対数深度から計算します。また、異なる位置の深度の関係を考慮したロス、disparity そのものでなく勾配を利用したロスなどを挙げています。ここは既存のロス関数を使うだけなので、詳細は論文を確認してみてください。

Final loss

まず、上の Related loss functions の1つである、勾配を利用したロス Normalized Multiscale Gradient Loss (NMG Loss) に似たような関数を使います。

{\mathcal{L}}_{reg}(\hat{d}, \hat{d}^*) = \frac{1}{M}\sum_{k=1}^K\sum_{i=1}^M (|\nabla_x R^k_i| +|\nabla_y R^k_i|)

$R_i$ は disparity の正解との差分 $\hat{d_i}-\hat{d^*}_i$です。推定値と正解値でそれぞれ、$x$方向、$y$方向で微分、つまり隣のピクセルの Disparity から引き算をしていったものを作って、その微分したもの同士の絶対値誤差を計算します。$k$はスケールを表していているようです。

そして最初のロス ${\mathcal{L}}_{ssi}$を組み込んで最終的には以下のとおりです。

{\mathcal{L}}_l = \frac{1}{N_l} \sum_{n=1}^{N_l} {\mathcal{L}_{ssi}}(\hat{d^n_i} - (\hat{d^*_i})^n)+\alpha {\mathcal{L}_{reg}}(\hat{d^n_i} - (\hat{d^*_i})^n)

ここで $l$ は学習データセットを指しているので、ロスは学習データごとに求まるということですね。

Mixing strategy

いよいよどうやって複数のデータセットを混ぜて学習するかを説明します。

すごく単純な方法ですが、すべてのデータセットからデータを均等に引いてミニバッチを作ります。例えば、ミニバッチのサイズが50、データセットの数が10なら、1つのデータセットから5データずつ引いてきます。

ロスは学習データごとに定義されていて複数あるので、パラメータの良し悪しを一概に判定することはできません。例えば、あるロスAは小さいがロスBが大きい場合、それはロスAが大きくロスBが小さい場合と比べて良いのでしょうか?しかし、少なくとも言えそうなのは、ロスAとBがともに大きい場合よりも良いということは言えそうです。

このように、あるパラメータ$\theta_i$ によって得られるすべてのロスの値が、他のすべてのパラメータ$\theta_j$から計算されるすべてのロスの値よりも悪くなければ、そのパラメータ$\theta_i$ はパレート最適であるといい、こういったパラメータを求めるように学習します。具体的には以下の論文紹介にある方法が採用されているとのことです。

論文紹介: Multi-Task Learning as Multi-Objective Optimization
https://qiita.com/koreyou/items/57c00bc314a68432de25

(補足) 使っているネットワーク構造について

論文中では明言されていないような気がしますが、Github のコードを見た感じでは、RefineNet がベースのようで、Encoder 部分は ResNeXt101をデフォルトで使っているようです。ちなみに、RefineNet のオリジナルは ResNet101を Encoder に使っていました。

SageMaker へのモデルデプロイ

ここまではアルゴリズムについていろいろと説明しましたが、MiDas は Zero-shot なモデルなので学習は不要で、学習アルゴリズムを理解しなくても、モデルの使い方さえわかれば使えます。早速 SageMaker にデプロイして推論 API を作ってみましょう。以下を読むのも面倒な方は、gistにあげたコードを SageMaker Python SDK が入った環境で試してみてくだい。

デプロイを始めるにあたって

MiDaS は v2.1 を使います。PyTorch hub のサイト にいくと推論スクリプトがありますが、残念ながらそのまま SageMaker では使うことができませんでした。私の場合、以下のモデル読み込みで、PyTorchの不具合のようなものにあたってしまいました。推奨バージョンが Python 3.7 なので、 Python 3.6 の SageMaker でエラーが出ているのかもしれません。

import torch
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
midas.eval()

気を取り直して、torch hub からモデルを読むのではなく、モデルの定義やパラメータを別途用意して、SageMaker にホストするようにします。

必要なファイルの取得

torch hub からいきなりモデルをロードできなかったので、モデルを定義する python ファイルと、PyTorch のモデルパラメータを取得しましょう。

!mkdir ./midas_model
!wget https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-f6b98070.pt -O ./midas_model/model.pt

!mkdir ./src
!git clone https://github.com/intel-isl/MiDaS.git
!mv ./MiDaS/midas ./src

モデルを定義するファイルを ./src において、モデルパラメータは./midas_model/model.pt におきました。

推論コードの作成

SageMaker のいつもの通り、 model_fn, transform_fn (または、input_fn, predict_fn, output_fn) を作成します。ここでは、model_fn と input_fn と predict_fn を作成しています。

model_fn

さきほどのモデルを定義するpythonファイルのなかに midas_net.py というファイルがあり、そのなかに定義がありますので、パラメータファイル midas_model/model.pt を指定して読み込みます。このパラメータファイルは、デプロイ時に自動的に model_dir に展開されますので、そこから読むようにしましょう。

input_fn

この関数は本来省略可能ですが、省略するとすべてのデータは PyTorch Tensor の形式で predict_fn に送られます。このあと、predict_fn では前処理 transformをかけないといけないのですが、これは numpy 形式を想定しています。そのため、input_fn は numpy のデータをうけて、そのまま numpy にして predict_fn に送るようにします。

predict_fn

model_fn から model, input_fn から numpy を受け取って前処理と推論を行います。

デプロイ

デプロイは機械的な作業です。実質的なコードは5行ですね。

  1. モデル (パラメータのファイル)をtar.gzに固めてS3にアップロード
  2. アップロードしたモデル、src においたモデルの定義ファイルを指定して、PyTorchModel を使ってデプロイ

注意事項として、この時点では pytorch 1.5.0 でデプロイしてください。1.6.0はデプロイの仕組みが異なるようなのでうまく動きません。

推論の実行

MiDaSは384x384のサイズで精度が良くなるとの説明があるのでリサイズしておくりましょう。imagesにある適当な画像を読んで、numpy 形式で送ります。

new_length = 384

img = Image.open(os.path.join("images",os.listdir("./images")[4]))

w, h = img.size
resize_img = img.resize((new_length, new_length))
data = np.array(resize_img).transpose([2,0,1])

output = predictor.predict(img)

これは某社の写真ですが、暗いほど遠いところにあることを表しているので、Zero-shot にもかかわらずかなり精度が良さそうです!初めて気がついたのですが、天井のごちゃごちゃした部分も部分的に認識されていますね。
result.png

法線の推定

これはおまけですが、深度がもとまれば、その微分を応用して法線 (傾き) を求めることができます。微分のコードは gist の方を見てください。

index.png

色と方向の対応がわかりにくいと思うので、試しに半球を描いて法線をもとめてみます。
index.png

つまり天井は下の方を向いている(ピンク)、床は上の方をむいている(緑)、並んでいる机の左側は手前左を向いている(紫)など、きれいに面の向きが取れていますね。

さいごに

MiDaS はなかなか精度の良いモデルであることがわかりました。もし SageMaker でデプロイしたら、最後にエンドポイントを削除し忘れないように注意しておきましょう。gist の最後のコードを実行すると消えると思います。

19
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
17