LoginSignup
2
2

AIによる背景判定処理の実験 前編 Python、PyTorch編

Last updated at Posted at 2023-12-30

はじめに

VRデバイスのMeta Quest3でどの程度のAI処理ができるか知るために実験したので記事にしました。
前編はPCで期待する効果が得られるかテストします。
具体的にはPythonとPyTorchを使ってAI(CNN、画像認識)により背景か前景かを判定します。
後編ではMeta Quest3で実際に動作させます。
アマレコVR※の背景透過処理※への応用を想定した実験となっています。

※ アマレコVR
私が制作しているMeta Quest用 VR動画プレイヤー

※背景透過処理
再生中の動画の画像をリアルタイムで解析し 変化のない部分を透明にする機能
他の動画やパススルー映像と合成することができます

参考記事

この実験は私個人のAIの勉強を兼ねています。以下の記事から多くを学び参考にさせていただきました。

PyTorchではなくkerasの使用例です。正解率が高かったのでCNNモデルを参考にしました。



PyTorchの使用例です。こちらの記事にそって実験しています。

環境

  • Windows10
  • Python Ver 3.11.4
  • PyTorch Ver 2.2.0.dev20230915+cu121

実験内容

あらかじめ犬や猫、車などの写真(一般例)を学習してから 入力された画像(未知の写真)に何が写っているかを判定する 画像処理系AIの代表例であるCNN(Convolutional Neural Network)を使います。

最初に、犬や猫の代わりに背景に属する画像と前景(被写体)に属する画像をCNNモデルを使って学習し保存します。
次に入力された画像を縦横16分割(全部で256分割)した各エリアに対し 保存したCNNモデルを使って背景と前景のどちらの特徴が強いか推論します。

上手くいけば次のような結果が得られるはずです。

  • 背景(被写体が含まれないエリア)
    どちらの数値(背景、前景)も低いか 背景の数値が高くなる

  • 前景(被写体が含まれるエリア)
    含まれる被写体の面積により 前景の数値が高くなる

結果

使用したサンプル動画

学習に使った画像

背景画像

動画のスクリーンショットから被写体が映っていないところを切り取って背景画像にしました。

前景画像

動画から被写体だけを切り取り、さらにできるだけ背景が残らないようにマスクします。こんな画像を14枚作成しました。(ピンクの部分は学習しません)

実際に学習する画像

それぞれの画像から32x32の領域を切り出して 一つ一つを小さい1枚の画像として学習します。
chip_bg.png
chip_fg.png

CNNモデル

import torch
import torch.nn as nn
from torch import Tensor

# ニューラルネットワークの定義
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
    
        # 画像処理
        self.feature = nn.Sequential(
            # ブロック1
            nn.Conv2d( 3, 16, kernel_size=3, padding=(1,1), padding_mode="replicate"), #448
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, groups=16, padding=(1,1), padding_mode="replicate"), #160
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=1), #272
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Dropout(0.25),
        
            # ブロック2
            nn.Conv2d(16, 16, kernel_size=3, groups=16, padding=(1,1), padding_mode="replicate"), #160
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=1), #544
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, groups=32, padding=(1,1), padding_mode="replicate"), #320
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=1), #1056
            nn.ReLU(),
            #nn.MaxPool2d((2,2)),
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Dropout(0.25)
        )
        
        # 平滑化
        self.flatten = nn.Flatten()
        
        # 全結合
        self.classifier = nn.Sequential(
            nn.Linear(32, 256), #8448
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 2), #514
            #nn.Softmax()
        )
    
    
    def forward(self, input_data:Tensor):
        input_data = self.feature(input_data)
        input_data = self.flatten(input_data)
        input_data = self.classifier(input_data)
        #input_data = F.softmax(input_data)
        return input_data

最終的に非力なスマホで使うので 性能(画像認識能力)を多少削ってでも計算コストが低くなるように設計しています。具体的には コンボリューション回数やパラメーター数が少なくなるようにしています。

コンボリューションは最初のRGBの3ch入力、16ch出力のフィルターを除きDepthwise Separable Convolutionを採用しました。
性能はあまり変わらず、劇的にコンボリューション回数を減らすことができます。

続いて、画像処理のブロック2の最後をMaxPoolからGAPへ変更しました。これもパラメータ数を劇的に減らすことができます。
約50万パラメータ(8x8x32x256+256)を8000程度(32x256+256)まで減らしています。そのかわり性能が低下します。

参考記事

後編ではこのCNNモデルをMeta Quest3で実行します。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2