LoginSignup
2
2

Rust WebAssembly でリアルタイムで顔にモザイク処理

Last updated at Posted at 2024-01-14

rustface

rustfaceというクレートを見つけました。

rustfaceSeetaface Engineという C++ で書かれた顔認識ライブラリ(検出・位置調整・認識)のうち「検出」の部分を Rust で書いたものらしい。Seetafaceは CPU だけ使用してそれなりに高速に動くのが特徴のようです。Rust で書かれてるなら wasm-pack で簡単に WebAssembly 化できそうなのでやってみました。

成果物(1/18 追記:モザイク処理を追加しました。)

output.gif

安心してください。顔画像はどこにも送信されません。(カメラによってはアスペクトレシオが崩れるようです:sweat:

wasm-pack

wasm-packのインストール。簡単。

cargo install wasm-pack

私の環境では何やらパッケージが足りないとかエラーが出ました。以下の2つを追加しました。

sudo apt install pkg-config libssl-dev

新しいプロジェクトの作成。

wasm-pack new face-detection

Cargo.toml

wasm-pack newで生成されたCargo.tomlrustfaceを追加。デフォルトではマルチスレッドが有効になっているのでdefault-features = falserayonが使われないようします。

Cargo.toml
[package]
name = "face-detection"
version = "0.1.0"
authors = ["benki"]
edition = "2021"

[lib]
crate-type = ["cdylib", "rlib"]

[features]
default = ["console_error_panic_hook"]

[dependencies]
# rayon クレートを無効化するために default-features を無効に
+ rustface = { version = "0.1", default-features = false }
wasm-bindgen = "0.2"
console_error_panic_hook = { version = "0.1", optional = true }

[dev-dependencies]
wasm-bindgen-test = "0.3"

[profile.release]
opt-level = "s"
# 最適化
+ codegen-units = 1
+ lto = "fat"
+ strip = true

# wasm-opt の無効化
+ [package.metadata.wasm-pack.profile.release]
+ wasm-opt = false

src/lib.rs

モデル

rustfaceの検出器を初期化するときに seeta_fd_frontal_v1.0.bin というモデルが必要になります。wasm で読み込むのはちょっと手間がかかるのでバイナリに埋め込みます。include_bytes!マクロを使えばよいです。

src/lib.rs
// seeta_fd_frontal_v1.0.bin を model.bin にリネームして
// Cargo.toml と同じ場所に置いてます。
const MODEL: &[u8] = include_bytes!("../model.bin");

検出器の初期化・設定

検出器はthread_local!マクロとRefCellでグローバル変数に保持しておきます。setup関数でモデルの読み込み、検出器の初期化、設定を行います。諸々の設定の意味はよく分かりません(おい)。set_pyraid_scale_factorがパフォーマンスへの影響が大きいようです。0.1〜0.99の値を入力しますが、値が大きいほど遅く検出精度がよくなり、小さいほど速く精度が悪くなります。

src/lib.rs
// 検出器を保持しておくためのグローバル変数を初期化
thread_local! {
    static DETECTOR: RefCell<Option<Box<dyn Detector>>> = RefCell::new(None);
}

#[wasm_bindgen]
pub fn setup(min_face_sizes: u32, score_thresh: f64, pyramid_scale_factor: f32, slide_window_step: u32) {
    // モデルを読み込む
    let model = rustface::read_model(MODEL).expect("failed to read model.");
    // 検出器を初期化
    let mut detector = rustface::create_detector_with_model(model);
    // 諸々の設定
    detector.set_min_face_size(min_face_sizes);
    detector.set_score_thresh(score_thresh);
    detector.set_pyramid_scale_factor(pyramid_scale_factor);
    detector.set_slide_window_step(slide_window_step, slide_window_step);
    // 検出器をグローバル変数にセット
    DETECTOR.set(Some(detector));
}

顔検出

検出する処理です。detect関数の引数は canvas context から取得した RGBA 画像データと、その幅・高さです。まず、RGBA 画像データをグレースケールに変換します。検出器に渡す画像はグレースケールである必要があります。検出器のdetect関数にグレースケール画像を渡すとVec<FaceInfo>が返ってくるので。FaceInfo のうちの最小限必要な顔の x・y 座標、幅、高さだけを JavaScript 側に返すようにしています。

src/lib.rs
#[wasm_bindgen]
pub fn detect(rgba: &[u8], width: u32, height: u32) -> Vec<Info> {
    // RGBA 画像をグレースケールに変換。
    let grayscale = rgba
        .chunks(4)
        // 整数演算で高速化
        .map(|v| ((19 * v[0] as u16) >> 8) + ((183 * v[1] as u16) >> 8) + ((53 * v[2] as u16) >> 8))
        .map(|v| v as u8)
        .collect::<Vec<_>>();

    // ImageData 形式に変換
    let img = rustface::ImageData::new(&grayscale, width, height);

    // グローバル変数に保持している検出器を取得
    DETECTOR.with(|detector| {
        let Some(ref mut detector) = *detector.borrow_mut() else {
            return vec![];
        };
        // 検出
        detector
            .detect(&img)
            .iter()
            // x, y 座標、幅、高さだけ抜き出す
            .map(|info| Info {
                x: info.bbox().x(),
                y: info.bbox().y(),
                width: info.bbox().width(),
                height: info.bbox().height(),
            })
            .collect()
    })
}
src/lib.rs 全文
src/lib.rs
mod utils;

use rustface::Detector;
use std::cell::RefCell;
use wasm_bindgen::prelude::*;

const MODEL: &[u8] = include_bytes!("../model.bin");

thread_local! {
    static DETECTOR: RefCell<Option<Box<dyn Detector>>> = RefCell::new(None);
}

#[wasm_bindgen]
pub fn setup(min_face_sizes: u32, score_thresh: f64, pyramid_scale_factor: f32, slide_window_step: u32) {
    let model = rustface::read_model(MODEL).expect("failed to read model.");
    let mut detector = rustface::create_detector_with_model(model);
    detector.set_min_face_size(min_face_sizes);
    detector.set_score_thresh(score_thresh);
    detector.set_pyramid_scale_factor(pyramid_scale_factor);
    detector.set_slide_window_step(slide_window_step, slide_window_step);
    DETECTOR.set(Some(detector));
}

#[wasm_bindgen]
pub struct Info {
    x: i32,
    y: i32,
    width: u32,
    height: u32,
}

#[wasm_bindgen]
impl Info {
    #[wasm_bindgen(getter)]
    pub fn x(&self) -> i32 {
        self.x
    }
    #[wasm_bindgen(getter)]
    pub fn y(&self) -> i32 {
        self.y
    }
    #[wasm_bindgen(getter)]
    pub fn width(&self) -> u32 {
        self.width
    }
    #[wasm_bindgen(getter)]
    pub fn height(&self) -> u32 {
        self.height
    }
}

#[wasm_bindgen]
pub fn detect(rgba: &[u8], width: u32, height: u32) -> Vec<Info> {
    let grayscale = rgba
        .chunks(4)
        .map(|v| ((19 * v[0] as u16) >> 8) + ((183 * v[1] as u16) >> 8) + ((53 * v[2] as u16) >> 8))
        .map(|v| v as u8)
        .collect::<Vec<_>>();

    let img = rustface::ImageData::new(&grayscale, width, height);

    DETECTOR.with(|detector| {
        let Some(ref mut detector) = *detector.borrow_mut() else {
            return vec![];
        };
        detector
            .detect(&img)
            .iter()
            .map(|info| Info {
                x: info.bbox().x(),
                y: info.bbox().y(),
                width: info.bbox().width(),
                height: info.bbox().height(),
            })
            .collect()
    })
}

ビルド

--target webで Web 用のface_detection.jsface_detection_bg.wasmが ./pkg ディレクトリ以下に生成されます。

wasm-pack build --target web

index.html

JavaScript 側の処理を書いていきます。

index.html
<!DOCTYPE html>
<html>

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>face detection wasm</title>
</head>

<body>
    <div id="fps"></div>
    <video autoplay muted hidden id="v1"></video>
    <canvas id="c1"></canvas>
    <script type="module">
        // wasm の関数をインポート
        import init, {setup, detect} from "./pkg/face_detection.js"
        const onload = async () => {
            // wasm モジュールの初期化
            await init();
            // 検出器の初期化・設定
            setup(20, 2.8, 0.5, 4);

            const width = 320;
            const height = 240;

            // カメラの設定
            const constraints = {
                audio: false,
                video: true,
                faceingMode: "user",
                width: {exact: width},
                height: {exact: height},
                aspectRatio: {exact: width / height}
            };
            // カメラを取得
            const stream = await navigator.mediaDevices.getUserMedia(constraints);
            // ビデオタグにカメラ映像を再生
            // ビデオタグは hidden なので実際には表示されない
            const video = document.querySelector("#v1");
            video.srcObject = stream;
            video.play();

            // カメラ映像を表示するためのキャンバスを取得
            const canvas = document.querySelector("#c1");
            canvas.width = width;
            canvas.height = height;

            const ctx = canvas.getContext("2d");
            ctx.strokeStyle = "red";
            ctx.lineWidth = 3;

            // FPS の表示処理
            let last_time;
            const fps_div = document.querySelector("#fps");

            const show_fps = () => {
                if (!last_time) {
                    last_time = performance.now();
                    return;
                }
                let delta = (performance.now() - last_time) / 1000;
                last_time = performance.now();
                const fps = Math.floor(1 / delta);
                fps_div.innerText = `FPS: ${fps}`;
            };

            const draw_image = () => {
                ctx.clearRect(0, 0, width, height);
                // カメラ映像の描画
                ctx.drawImage(video, 0, 0, width, height);
                // 描画された映像の RGBA データを取得
                const rgba = ctx.getImageData(0, 0, width, height).data;
                // 顔の位置を検出
                detect(rgba, width, height).forEach((info) => {
                    // 検出された位置を矩形で描画
                    ctx.strokeRect(info.x, info.y, info.width, info.height);
                });
                show_fps();
                requestAnimationFrame(draw_image);
            };
            draw_image();
        };
        window.addEventListener("DOMContentLoaded", onload);
    </script>
</body>

</html>

まとめ

第 8 世代のモバイル向け Core i5 搭載ノート PC で 30 FPS ぐらい出るのでまあまあ速いのかな?:thinking:

モザイク処理

顔の位置を検出して赤枠で囲むだけはつまらないので、検出された顔の位置にモザイクをかける処理を追加しました。8 x 8 ピクセルの RGB 平均値を Rust WebAssembly 側で計算し配列に格納して、JavaScript 側に返す実装となります。8 x 8 ピクセル平均値を計算するのにfor文でゴリゴリ回してもいいんですが、イテレータで回してtraitも使うとコードがスッキリして良い感じになりますね。モザイク処理するコードは以下のようになりました。(コメントで何をしているか書いているつもりですが、なんやら意味不明ですね:sweat_smile:

src/lib.rs
fn mosaic(
    rgba: &[u8],
    img_width: u32,
    x: i32,
    y: i32,
    face_width: u32,
    face_height: u32,
    block_size: usize,
) -> Vec<Row> {
    rgba
        // ビデオ画像の1行を rgba (4バイト) × 画像幅の配列に分ける
        .chunks(4 * img_width as usize)
        // 検出された顔の位置までスキップ
        .skip(y as usize)
        // 検出された顔の高さ分を取得
        .take(face_height as usize)
        // 検出された顔のビューを取得
        .collect::<Vec<_>>()
        // 高さ方向のモザイクブロックサイズごとに処理する
        .chunks_exact(block_size)
        .filter_map(|rows| {
            // 1 行ごとに処理する
            rows.iter()
                .filter_map(|row| {
                    let left = 4 * x as usize;
                    // 検出された顔の位置のスライスを取得
                    let cols = row
                        .get(left..left + 4 * face_width as usize)?
                        // 横方向のブロックサイズ分のピクセルを取得
                        .chunks_exact(4 * block_size)
                        .filter_map(|pixels| {
                            pixels
                                .chunks(4)
                                // RGBA を RGB に変換
                                .map(Rgb::new)
                                // 横方向の平均値を計算
                                .reduce(|acc, e| acc.average(e))
                        })
                        // 横方向の平均値をまとめる
                        .collect();
                    Some(Row { cols })
                })
                // 縦方向のブロックサイズ分の平均値をまとめる
                .reduce(|acc, e| acc.average(e))
        })
        .collect()
}

Rgbの平均値を求めるために下記のようなtraitRgb構造体に実装しています。

src/lib.rs
trait Average {
    type Output;
    fn average(self, rhs: Self::Output) -> Self::Output;
}

impl Average for Rgb {
    type Output = Self;
    fn average(self, rhs: Self::Output) -> Self::Output {
        Rgb {
            r: ((self.r() as u16 + rhs.r() as u16) >> 1) as u8,
            g: ((self.g() as u16 + rhs.g() as u16) >> 1) as u8,
            b: ((self.b() as u16 + rhs.b() as u16) >> 1) as u8,
        }
    }
}
モザイク処理のコード全部
mod utils;

use rustface::Detector;
use std::cell::RefCell;
use wasm_bindgen::prelude::*;

const MODEL: &[u8] = include_bytes!("../model.bin");

thread_local! {
    static DETECTOR: RefCell<Option<Box<dyn Detector>>> = RefCell::new(None);
}

#[wasm_bindgen]
pub fn setup() {
    let model = rustface::read_model(MODEL).expect("failed to read model.");
    let mut detector = rustface::create_detector_with_model(model);
    detector.set_min_face_size(20);
    detector.set_score_thresh(2.8);
    detector.set_pyramid_scale_factor(0.5);
    detector.set_slide_window_step(4, 4);
    DETECTOR.set(Some(detector));
}

#[wasm_bindgen]
pub struct Info {
    x: i32,
    y: i32,
    mosaic: Vec<Row>,
}

#[wasm_bindgen]
impl Info {
    #[wasm_bindgen(getter)]
    pub fn x(&self) -> i32 {
        self.x
    }
    #[wasm_bindgen(getter)]
    pub fn y(&self) -> i32 {
        self.y
    }
    #[wasm_bindgen(getter)]
    pub fn mosaic(self) -> Vec<Row> {
        self.mosaic
    }
}

trait Average {
    type Output;
    fn average(self, rhs: Self::Output) -> Self::Output;
}

#[wasm_bindgen]
pub struct Row {
    cols: Vec<Rgb>,
}

#[wasm_bindgen]
impl Row {
    #[wasm_bindgen(getter)]
    pub fn cols(self) -> Vec<Rgb> {
        self.cols
    }
}

impl Average for Row {
    type Output = Self;
    fn average(self, rhs: Self::Output) -> Self::Output {
        let cols = self
            .cols()
            .into_iter()
            .zip(rhs.cols())
            .map(|(acc, e)| acc.average(e))
            .collect();
        Row { cols }
    }
}

#[wasm_bindgen]
pub struct Rgb {
    r: u8,
    g: u8,
    b: u8,
}

impl Rgb {
    fn new(rgba: &[u8]) -> Self {
        Self {
            r: rgba[0],
            g: rgba[1],
            b: rgba[2],
        }
    }
}

#[wasm_bindgen]
impl Rgb {
    #[wasm_bindgen(getter)]
    pub fn r(&self) -> u8 {
        self.r
    }
    #[wasm_bindgen(getter)]
    pub fn g(&self) -> u8 {
        self.g
    }
    #[wasm_bindgen(getter)]
    pub fn b(&self) -> u8 {
        self.b
    }
}

impl Average for Rgb {
    type Output = Self;
    fn average(self, rhs: Self::Output) -> Self::Output {
        Rgb {
            r: ((self.r() as u16 + rhs.r() as u16) >> 1) as u8,
            g: ((self.g() as u16 + rhs.g() as u16) >> 1) as u8,
            b: ((self.b() as u16 + rhs.b() as u16) >> 1) as u8,
        }
    }
}

#[wasm_bindgen]
pub fn detect(rgba: &[u8], img_width: u32, img_height: u32, block_size: usize) -> Vec<Info> {
    let grayscale = rgba
        .chunks(4)
        .map(|v| ((19 * v[0] as u16) >> 8) + ((183 * v[1] as u16) >> 8) + ((53 * v[2] as u16) >> 8))
        .map(|v| v as u8)
        .collect::<Vec<_>>();

    let img = rustface::ImageData::new(&grayscale, img_width, img_height);

    DETECTOR.with(|detector| {
        let Some(ref mut detector) = *detector.borrow_mut() else {
            return vec![];
        };
        detector
            .detect(&img)
            .iter()
            .map(|info| {
                let x = info.bbox().x();
                let y = info.bbox().y();
                let mosaic = mosaic(
                    rgba,
                    img_width,
                    x,
                    y,
                    info.bbox().width(),
                    info.bbox().height(),
                    block_size,
                );
                Info { x, y, mosaic }
            })
            .collect()
    })
}

fn mosaic(
    rgba: &[u8],
    img_width: u32,
    x: i32,
    y: i32,
    face_width: u32,
    face_height: u32,
    block_size: usize,
) -> Vec<Row> {
    rgba.chunks(4 * img_width as usize)
        .skip(y as usize)
        .take(face_height as usize)
        .collect::<Vec<_>>()
        .chunks_exact(block_size)
        .filter_map(|rows| {
            rows.iter()
                .filter_map(|row| {
                    let left = 4 * x as usize;
                    let cols = row
                        .get(left..left + 4 * face_width as usize)?
                        .chunks_exact(4 * block_size)
                        .filter_map(|pixels| {
                            pixels
                                .chunks(4)
                                .map(Rgb::new)
                                .reduce(|acc, e| acc.average(e))
                        })
                        .collect();
                    Some(Row { cols })
                })
                .reduce(|acc, e| acc.average(e))
        })
        .collect()
}
index.html
<!DOCTYPE html>
<html>

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>face detection wasm</title>
</head>

<body>
    <video autoplay muted hidden id="v1"></video>
    <canvas id="c1"></canvas>
    <div id="fps"></div>
    <input id="blocksize" type="range" min="4" max="32" step="2" value="8" />
    <label for="blocksize">Block size: 8</label>
    <script type="module">
        import init, {setup, detect} from "./pkg/face_detection.js"
        const onload = async () => {
            await init();
            setup();

            const width = 320;
            const height = 240;

            const constraints = {
                video: true,
                faceingMode: {exact: "user"},
                width: width,
                height: height,
            };
            const stream = await navigator.mediaDevices.getUserMedia(constraints);
            const video = document.querySelector("#v1");
            video.srcObject = stream;
            video.play();

            const canvas = document.querySelector("#c1");
            canvas.width = width;
            canvas.height = height;

            const ctx = canvas.getContext("2d");

            let last_time;
            const fps_div = document.querySelector("#fps");

            const show_fps = () => {
                if (!last_time) {
                    last_time = performance.now();
                    return;
                }
                let delta = (performance.now() - last_time) / 1000;
                last_time = performance.now();
                const fps = Math.floor(1 / delta);
                fps_div.innerText = `FPS: ${fps}`;
            };

            const slider = document.querySelector("#blocksize");
            const label = document.querySelector("label");
            slider.oninput = () => {
                label.innerText = `Block size: ${slider.value}`;
            };

            const draw_image = () => {
                ctx.drawImage(video, 0, 0, width, height);
                const rgba = ctx.getImageData(0, 0, width, height).data;
                const block_size = slider.value;

                detect(rgba, width, height, block_size).forEach((info) => {
                    const top_ = info.x;
                    const left_ = info.y;

                    info.mosaic.forEach((row, j) => {
                        row.cols.forEach((rgb, i) => {
                            const x = top_ + i * block_size;
                            const y = left_ + j * block_size;
                            ctx.fillStyle = `rgb(${rgb.r}, ${rgb.g}, ${rgb.b})`
                            ctx.fillRect(x, y, block_size, block_size);
                        });
                    });
                });
                show_fps();
                requestAnimationFrame(draw_image);
            };
            draw_image();
        };
        window.addEventListener("DOMContentLoaded", onload);
    </script>
</body>

</html>
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2