3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

*Interp移植録 - 顔検出/CenterFace (OnnxInterp)

Posted at

0.Prologue

暇つぶしに、興味を引いた DNNアプリを *Interpに移植して遊んでいる。
本稿はその雑記&記録。

先日、「RetinaFace」と言う顔検出モデルを OnnxInterpに移植して遊んでみた[*1]。その検出力の高さに一人驚いてた訳だが、世間の噂によると「RetinaFace」は検出力は高いもののレスポンスは良くないとのことだ。マイPCは GPUを持たぬオンボロPCなので、まぁこんなものかと思っていたが、いわれてみれば確かに一呼吸おいて結果が返ってくるように思えて来た。そんなRetinaFaceの弱点を克服しようと、「YuNet」や「Ultra-Light-Fast-Generic-Face-Detector」などが提案されてきたというのがこの界隈の時の流れのようだ。

今回移植してみる「CenterFace」もそれら群雄割拠の1つ。さて、どのくらい速くなるのだろうか?

[*1]*Interp移植録 - 顔検出/RetinaFace (OnnxInterp)

1.Original Work

「CenterFace」の特徴は、アンカーフリーなモデルと言うところであろうか。RetinaFaceやYuNetは、グリッド毎に張られた複数のアンカーボックスを足場にして顔の検出を試みている。そのため、同じグリッドの回りのデータをアンカーボックスの個数分だけ繰返して処理することになる。アンカーフリーにして、この重複を取り除こうと言う戦術のようだ。反面、メッシュ・グリッドのピッチは細かくする必要があり、標準では4pixのグリッドを採用している。

CenterFaceのモデル・アーキテクチャは下図の様になっている。Backboneには MobileNetを配して、ここでも高速化を志向している。

image.png
(抜粋:「CenterFace: Joint Face Detection and Alignment Using Face as Point」より)

classificationヘッドの結果は、メッシュ・グリッドの大きさ(入力画像の1/16)を持つ一枚の Heatmapとして出力される。Heatmapの行・列はグリッドの位置に対応しており、値はそのグリッドに顔の中心点がある確率を表している。例えば、下図左の入力画像に対し右の Heatmapが出力される。(緑~赤が確率が高いグリッド)

img-heatmap.jpg

2.準備

CenterFaceの ONNXモデルは、上の「CenterFace」プロジェクトから調達する。

さて、ここで2晩試行錯誤することになった。入手したモデルは、入力がfloat32{10,3,32,32}に固定されており、どうすれば640x640の画像を与えることができるか分からなかった[*2]。ある投稿によると気にせずにそのまま640x640の画像を渡せば良いとあったが、試してみるとOnnxInterpが例外を吐く。またある投稿では画像の縦横のサイズは32の倍数でなればならないとあるが、640は条件を満たしている。はてな?である。

[*2]このモデルは Conv等の入力データのシェイプに左右されないオペレータだけで構成されているので、可変シェイプ入力が可能なモデルだ。

image.png

結局、次の Pythonスクリプト[*3]で、モデルの入力定義を可変シェイプに書き換えることで一件落着した。「そのまま渡せば」と言うケースは、そこで使用しているライブラリがこっそりと入力定義を書き換えているのだと想像する。OnnxInterpはそういうお行儀の悪いことはしないので嵌ったのだろう。

[*3]Elixirでやってやれなくはないが、まだ道具を整えていないので先輩の Pythonに頼った。

モデル入力定義の書き換え
import onnx

model = onnx.load_model("centerface.onnx")
d = model.graph.input[0].type.tensor_type.shape.dim
d[0].dim_value = 1
d[2].dim_value = -1  # dynamic dimension
d[3].dim_value = -1  # dynamic dimension
onnx.save_model(model,"centerface_dynamic.onnx" )

3.OnnxInterp用のLivebookノート

Mix.installの依存リストに記述するモジュールは下記の通り。PostDNN.meshgridが必要なので PostDNNを含める。

setup cell
File.cd!(__DIR__)
# for windows JP
System.shell("chcp 65001")

Mix.install([
  {:onnx_interp, path: ".."},
  {:cimg, "~> 0.1.16"},
  {:postdnn, "~> 0.1.4"},
  {:nx, "~> 0.4.0"},
  {:kino, "~> 0.7.0"}
])

モデルの出力は、入力画像の1/16の大きさを持つ2次元マップ(表)。先にも触れたが、マップの行・列がグリッドの位置に対応づく。本家の実装コード(下記)では、この2次元マップをそのまま参照しつつ BBoxのデコードを行っている。

本家実装コードの抜粋
    def decode(self, heatmap, scale, offset, landmark, size, threshold=0.1):
        heatmap = np.squeeze(heatmap)
        scale0, scale1 = scale[0, 0, :, :], scale[0, 1, :, :]
        offset0, offset1 = offset[0, 0, :, :], offset[0, 1, :, :]
        c0, c1 = np.where(heatmap > threshold)
        if self.landmarks:
            boxes, lms = [], []
        else:
            boxes = []
        if len(c0) > 0:
            for i in range(len(c0)):
                s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4
                o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]]
                s = heatmap[c0[i], c1[i]]
                x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(0, (c0[i] + o0 + 0.5) * 4 - s0 / 2)
                x1, y1 = min(x1, size[1]), min(y1, size[0])
                boxes.append([x1, y1, min(x1 + s1, size[1]), min(y1 + s0, size[0]), s])
                if self.landmarks:
                    lm = []
                    for j in range(5):
                        lm.append(landmark[0, j * 2 + 1, c0[i], c1[i]] * s1 + x1)
                        lm.append(landmark[0, j * 2, c0[i], c1[i]] * s0 + y1)
                    lms.append(lm)
            boxes = np.asarray(boxes, dtype=np.float32)
            keep = self.nms(boxes[:, :4], boxes[:, 4], 0.3)
            boxes = boxes[keep, :]
            if self.landmarks:
                lms = np.asarray(lms, dtype=np.float32)
                lms = lms[keep, :]
        if self.landmarks:
            return boxes, lms
        else:
            return boxes

しかしながら、本家設計者の思いを横に置き、モデルの出力データとその処理内容にフォーカスして考え直してみると、2次元マップに拘る必要がないことに気づく。要は、グリッド・ローカルな座標系で表されたBBoxの座標を、メッシュ全体のグローバルな座標系に変換したいのだ。それならば、RetinaFace移植でも同じことをしたではないか。RetinaFace移植コードを流用すれば、Nxの tensor演算との親和性が良く、さらに Elixirが苦手とするループ構造が要らないというおまけが付いてくる。ふむ、方針は流用と決まった。[*4]

残る検討事項は、本家実装コードでは NMSの前に Heatmapの値(評価値)による出力データのフィルタリングを行っており、これをどうするかだ。たぶん、BBoxデコードの計算量を減らすことが目的であろう。PostDNN.sieveを使えば同じ機能を実現できるのだが、仮に全ての BBoxをデコードするとしても高々25,600個(入力画像640x640の場合)なので、その節約効果は実装によるコードの複雑化に見合わないと思う。よって、"えいやぁ"と全てのBBoxをデコードし NMSに放り込むことにしよう。

以上を踏まえて起こしたコードが下記の CenterFaceモジュールだ。ほとんど RetinaFace移植のそれと同じ。主な違いは、アンカーボックスPostDNN.prioriboxがメッシュ・グリッドPostDNN.meshgridに変わった辺りだろうか。

入力画像のresizeは、aspect比保存で行う(:ulオプション)。画素値の型変換は CImg.to_binaryの {:range, {0.0, 255.0}}オプションで指定。fit2image_with_landmark/4では、各座標値を入力画像の座標系に戻す逆aspect変換(?)を行い、ランドマークのデコード済み座標を添付している。

[*4]もちろん前提条件が異なれば設計方針も異なる。例えば、リソースが貧弱なエッジデバイスに実装すならば、(C言語を用いて)本家と同じ設計を選択するだろう。

[モデル・カード]

  • inputs:
    [0] f32:{1,3,height,width} - RGB画像,NCHWレイアウト,画素はRGB各値(符号なし8bit整数)をそのままfloat 32bit型に変換

  • outputs:
    [0] f32:{1,1,height/4,width/4} - heatmap グリッドに顔(BBox)中心がある確からしさ
    [1] f32:{1,2,height/4,width/4} - scale BBoxのサイズ, グリッドに対する比率表記
    [2] f32:{1,2,height/4,width/4} - offset グリッド中心に対するBBoxの中心のオフセット
    [3] f32:{1,10,height/4,width/4} - landm ランドマーク(Xi,Yi) x 5, グリッドに対する比率表記

  • mesh grid:
    格子間隔 4の一様なグリッド, アンカーフリー

centerface
defmodule CenterFace do
  import Nx.Defn

  @width  640
  @height 640

  alias OnnxInterp, as: NNInterp
  use NNInterp,
    model: "./model/centerface_dynamic.onnx",
    url: "https://github.com/shoz-f/onnx_interp/releases/download/models/centerface_dynamic.onnx",
    inputs: [f32: {1,3,@height,@width}],
    outputs: [f32: {1,1,div(@height,4),div(@width,4)}, f32: {1,2,div(@height,4),div(@width,4)}, f32: {1,2,div(@height,4),div(@width,4)}, f32: {1,10,div(@height,4),div(@width,4)}]

  def apply(img) do
    # preprocess
    bin = CImg.builder(img)
      |> CImg.resize({@width, @height}, :ul, 0)
      |> CImg.to_binary([{:range, {0.0, 255.0}}, :nchw])

    # prediction
    outputs = session()
      |> NNInterp.set_input_tensor(0, bin)
      |> NNInterp.invoke()

    [heatmap, scale, offset, landm] = Enum.with_index([1, 2, 2, 10], fn dim,i ->
        NNInterp.get_output_tensor(outputs, i) |> Nx.from_binary(:f32) |> Nx.reshape({dim, :auto})
      end)

    # postprocess
    scores = Nx.transpose(heatmap)
    boxes  = decode_boxes(offset, scale)
    landm  = Nx.transpose(landm)

    {:ok, res} = NNInterp.non_max_suppression_multi_class(__MODULE__,
        Nx.shape(scores), Nx.to_binary(boxes), Nx.to_binary(scores),
        iou_threshold: 0.2, score_threshold: 0.2,
        boxrepr: :corner)

    {:ok, fit2image_with_landmark(landm, res["0"], inv_aspect(img))}
  end


  @grid PostDNN.meshgrid({@width, @height}, [4], [:center, :normalize, :transpose])

  defp decode_boxes(offset, size) do
    # decode box center coordinate on {1.0, 1.0}
    center = offset
      |> Nx.reverse(axes: [0])     # swap (y,x) -> (x,y)
      |> Nx.multiply(@grid[2..3]) # * grid_pitch(x,y)
      |> Nx.add(@grid[0..1])      # + grid(x,y)

    # decode box half size
    half_size = size
      |> Nx.reverse(axes: [0])     # swap (y,x) -> (x,y)
      |> Nx.exp()
      |> Nx.multiply(@grid[2..3]) # * grid_pitch(x,y)
      |> Nx.divide(2.0)

    # decode boxes
    [Nx.subtract(center, half_size), Nx.add(center, half_size)]
      |> Nx.concatenate()
      |> PostDNN.clamp({0.0, 1.0})
      |> Nx.transpose()
  end

  defp fit2image_with_landmark(landm, nms_res, {inv_x, inv_y} \\ {1.0, 1.0}) do
    Enum.map(nms_res, fn [score, x1, y1, x2, y2, index] ->
      grid = Nx.slice_along_axis(@grid, index, 1, axis: 1) |> Nx.squeeze()

      landmark = landm[index]
        |> Nx.reshape({:auto, 2})
        |> Nx.reverse(axes: [0])
        |> Nx.multiply(grid[2..3]) # * grid_pitch(x,y)
        |> Nx.add(grid[0..1])      # + grid(x,y)
        |> Nx.multiply(Nx.tensor([inv_x, inv_y]))
        |> Nx.to_flat_list()
        |> Enum.chunk_every(2)

      [score, x1*inv_x, y1*inv_y, x2*inv_x, y2*inv_y, landmark]
    end)
  end

  defp inv_aspect(img) do
    {w, h, _, _} = CImg.shape(img)
    if w > h, do: {1.0, w / h}, else: {h / w, 1.0}
  end
end

デモ・モジュール DemoCenterFaceは、RetinaFace移植のそれに同じ。

demo_centerface
defmodule DemoCenterFace do
  def run(path) do
    img = CImg.load(path)

    with {:ok, res} = CenterFace.apply(img) do
      res
      |> draw_item(CImg.builder(img), {0, 255, 0})
      |> CImg.display_kino(:jpeg)
    end
  end

  defp draw_item(boxes, canvas, color \\ {255, 255, 255}) do
    Enum.reduce(boxes, canvas, fn [_score, x1, y1, x2, y2, _landmark], canvas ->
      CImg.fill_rect(canvas, x1, y1, x2, y2, color, 0.3)
    end)
  end
end

4.デモンストレーション

CenterFaceを起動する。

CenterFace.start_link([])

画像を与え、顔検出を行う。

DemoCenterFace.run("10.jpg")

image.png

5.Epilogue

誤検出がちらほらと見受けられるものの、レスポンスは RetinaFace-ResNet50に比して約2倍速くなった。顔の検出力も RetinaFaceに比べ遜色ないように思う。これは使えそうだ。

顔検出モデルの移植は、これで4つ目となる。今回の CenterFaceはアンカーフリーなモデルで、RetinaFaceやYuNetとは後処理が大きく異なるのではと身構えていたが、蓋を開けてみれば前例と似たり寄ったりのコードとなった。もしかしたらこの分野の後処理は、同様の処理に収斂するのかも知れないなぁ。

(END)

Appendix

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?