LoginSignup
1
0

【開発】YOLOv8を使用した睡眠検知アプリを作ってみた

Posted at

はじめに

この記事では、YOLOv8を使用した睡眠検知アプリの開発についてまとめました。

去年から、筆者は資格勉強やプログラミングが生活の一部となり、徐々に生活習慣が変化しました。

以前は7時に起きて24時に就寝する規則正しい生活を送っていましたが、今では5時に起きて23時前には眠気に耐えきれず寝落ちするという不安定な生活になってしまいました。

この生活習慣の問題に対して、AIを使用した技術が活かせないかという思い付きが今回の開発の発端になります。

開発目的としては、「睡眠状態を検知し自動で消灯してくれるアプリ作って、健康的な睡眠を取ろう!」です。

YOLOの技術的説明はとても長く、複雑になるので以下の記事を参考にしてみてください。

また作成したコードは以下のリポジトリに上げていますので、是非参考にしていただければ幸いです。

1. YOLOの学習・推論

YOLOの使用には ultralytics という便利なライブラリがあります。このライブラリを使用することでYOLOの学習済みモデルのダウンロードや学習、推論が一気通貫的に行うことが出来ます。

1.1. 学習

YOLOには事前学習済みモデルがいくつかありますが、全ての物体、行動を検出できるわけではありません。

そのため、自身が行いたい、検出したいタスクに対して最適なモデルを学習させる必要があります。

ここではYOLOの学習手順について説明します。処理のフローとしては以下のようになります。

  1. ディレクトリの整理
  2. データ(画像ファイル)の用意
  3. データに対するアノテーション
  4. 学習

ディレクトリの整理

学習をする前に、作成するデータを保存するフォルダをしっかりと決めましょう。

具体的には以下の配置になります。

現在のディレクトリ
train.py
|
|
dataset
   |
    _ images
   |      |
   |       _ train
   |       _ val
    _ labels
          |
           _ train
           _ val

各train, vakには訓練データ、評価データとなる画像ファイルとアノテーションのテキストファイルを格納します。

データ(画像ファイル)の用意

この処理については自身が行いたいタスクに適した画像ファイルを用意する、もしくは作成する必要があります。筆者は、対象の状態が起床か睡眠かを判別する必要があったので、この2つに属する画像を自身で作成しました。

具体的な作成手順は以下になります。

以下はデータセット取得の手順である。
1.データセットはwebカメラ(BRIO 500)を使用。フレームサイズ:640 × 640とする。
2. 起床時、睡眠時の動画(1分)を録画する。服装の違いで推論結果が変わらないよう、異なる服装でも同じ動作を行うよう注意を払う。服装の組み合わせごとに、訓練動画:評価動画:テスト動画を3:2:1の割合(動画の枚数)で取得している。ここでは、一定時間ごとに異なる動作をしている。
3. 動画から一定時間ごとにフレームを抽出する。そのため一つ動画から複数枚の異なる動作のフレームが抽出できる。私のデータセットでは訓練データ:評価データ:テストデータを720:480:240枚とした。
4. 取得した画像ファイルをlabelImgを使用して、アノテーションを行う。

実行したコードは以下になります。

get_video.py
import cv2
import time
from glob import glob
import os


def get_output_video_dir(output_dir, train_status, awake_status, shirt_color, pant_color):
    video_path_per_status = os.path.join(output_dir, train_status, awake_status)
    if not os.path.exists(video_path_per_status):
        os.makedirs(video_path_per_status)

    #既存の動画ファイルを取得
    files = glob(video_path_per_status + "/*.mp4")
    if len(files) == 0:
        new_number = str(0).zfill(2)
    else:
        max_number = 0
        for file in files:
            number = int(file[-6:-4])
            if number > max_number:
                max_number = number
        new_number = str(max_number + 1).zfill(2)

        #新しいファイル名を作成
    output_video_dir = os.path.join(video_path_per_status,
                                     train_status + "_" + 
                                     awake_status + "_" + 
                                     shirt_color + "_" +
                                     pant_color + "_" +
                                     new_number + ".mp4")
    return output_video_dir

directory = "video_data"

is_neoti = 0
frame_list = []
#最高録画時間は4時間
over_time = 60

cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 640) 
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_fps = int(cap.get(cv2.CAP_PROP_FPS))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
output_video_dir = get_output_video_dir(directory, "test", "sleep", "white", "blue")
print(output_video_dir)
out = cv2.VideoWriter(output_video_dir, fourcc, original_fps, (width, height), isColor=True)

frame_count = 0

print("start")

time.sleep(3)

print("record")

start_time = time.time()
while True:

    ret, frame = cap.read()

    current_time = time.time()  # 現在時刻を取得
    elapsed_time = current_time - start_time  # 経過時間を計算

    # 画像を保存
    frame_list.append(frame)

    if elapsed_time > over_time:
        break

for frame in frame_list:
     out.write(frame)

print(len(frame_list))

cap.release()
out.release()

video2frame.py
import cv2
from glob import glob
import os




#動画からフレーム画像を抽出する関数
#一つの動画から任意の数のフレームを抽出する
def output_frame(num_frame):

    data_dir = './video_data'

    output_dir = './dataset'

    print(output_dir)

    train_video_dir_list = glob(os.path.join(data_dir, 'train') + '/*' + '/*.mp4')
    valid_video_dir_list = glob(os.path.join(data_dir, 'val') + '/*' + '/*.mp4')
    test_video_dir_list = glob(os.path.join(data_dir, 'test') + '/*' + '/*.mp4')

    print('train_video_dir_list:', train_video_dir_list)

    #画像ファイルの出力先   
    train_frame_dir = os.path.join(output_dir, 'images', 'train')
    valid_frame_dir = os.path.join(output_dir, 'images', 'val')
    test_frame_dir = os.path.join(output_dir, 'images', 'test')

    print('フレーム画像の出力を開始します。')

    #動画ファイルをフレーム画像に変換

    #訓練データ
    video2frame(train_video_dir_list, train_frame_dir, num_frame)

    #検証データ
    video2frame(valid_video_dir_list, valid_frame_dir, num_frame)

    #テストデータ
    video2frame(test_video_dir_list, test_frame_dir, num_frame)


    print('フレーム画像の出力が完了しました。')


#動画からフレーム画像を等間隔に抽出する関数
def video2frame(video_dir_list, frame_dir, num_frame=10):
    for video_dir in video_dir_list:
        cap = cv2.VideoCapture(video_dir)
        if not cap.isOpened():
            print('Error: Could not open video.')
            return

        frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_rate = int(frame_num / num_frame)

        for i in range(num_frame):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * frame_rate)
            ret, frame = cap.read()
            if ret:
                frame_name = os.path.splitext(os.path.basename(video_dir))[0] + '_frame' + str(i) + '.jpg'
                frame_path = os.path.join(frame_dir, frame_name)
                cv2.imwrite(frame_path, frame)
            else:
                print('Error: Could not read frame.')
                break

        cap.release()

def main():
    output_frame(20)

if __name__ == '__main__':
    main()

次にデータに対するアノテーション

上記の手順では動画から画像ファイルの生成を行いました。しかしこの画像に対して領域とラベルが定義されていません。そのため、画像1枚ごとに領域とラベルを対応付ける必要があります。この作業を アノテーション と呼びます。

画像のアノテーションを行うモジュールにはいくつか種類がありますが、筆者は labelImg というモジュールを使用しました。

labelImgの使い方はとても簡単です。以下の手順を行うだけで利用できます。

pip install labelImg
labelImg

labelImgを開くと以下のようなGUIが表示されます。

image.png

アノテーションの具体的な手順については以下になります。

  1. 左の欄にあるアノテーション方式を YOLO に変更。(おそらく最初はPascalVOCになっていると思われます)
  2. Open dir からアノテーションを行いたい画像ファイルが保存されたフォルダを選択。
  3. Change Save dir からアノテーション結果を保存するフォルダを選択。
  4. 表示されている画像を右クリックし Create Rect Box、またはキーボードの Wを押し判定したい領域を描画。
  5. 判定したいラベルを記入。
  6. キーボードの Dを押し、次の画像ファイルのアノテーションを行う。
  7. 3 ~ 5をアノテーションが終了するまで繰り返す。

Change Save dirで選択したフォルダにはアノテーションデータとは別に classes.tex というラベルの名称が保存されたテキストファイルが作成されます。

この手順により、YOLOの学習に使用する画像ファイルとアノテーションデータ(テキストファイル)を作成できました。

それでは実際に学習をしていきましょう。

学習

学習をさせる前に、YOLOに学習データと評価データがどのディレクトリにあるかを教える必要があります。ultralyticsでは 以下のような yamlファイルを作成し、データのパスが記載されたyamlファイルのパスを渡すことでYOLOの学習を行うことが出来ます。

config.yaml
path: C:/Users/user/VScode/python/neoti/yolov8_sleep_recognition/dataset"  # dataset root dir
train: C:/Users/user/VScode/python/neoti/yolov8_sleep_recognition/dataset/images/train  # train images
val: C:/Users/user/VScode/python/neoti/yolov8_sleep_recognition/dataset/images/val  # valid images 
test: C:/Users/user/VScode/python/neoti/yolov8_sleep_recognition/dataset/images/test  # test images

# Classes
names: 
  0: awake
  1: sleep

nc: 2  # number of classes

各項目は以下のような要素で構成されています。

path: 訓練データ、評価データなどが格納されたフォルダを格納しているフォルダへのパス(ここではdataset)
train: 訓練データが格納されたフォルダのパス
val: 評価データが格納されたフォルダのパス
test: テストデータが可能されたフォルダのパス
names: 検出したいラベル名
nc: 検出したいラベルの数

namesにはlabelImgで作成したラベル名と同じ順番、同じ名称を記載するようにしましょう。

それでは、用意したデータを用いて学習させていきましょう。

train.py
from ultralytics import YOLO

def train(model_name):
    # モデルの初期化
    model = YOLO(model_name)

    data_path = "C:/Users/user/VScode/python/neoti/yolov8_sleep_recognition/config.yaml"

    # モデルの学習
    model.train(data=data_path, epochs=500, patience=50, batch=8, imgsz=640, save=True,
                device="cuda:0", verbose=True, val=True)

    print('学習が完了しました。')

def main():
    # モデルの学習
    model_name = 'yolov8n'
    train(model_name)

if __name__ == '__main__':
    main()

学習結果は以下のようになりました。一番軽量なモデルを使用したため、学習時間は1時間もかかりませんでした。

1.2. 推論

ultralyticsを使用したYOLOの学習では、評価データなどを使用して様々な評価指標基にした結果を保存してくれます。

以下はその結果になります。しっかりと起床と睡眠の判別が出来ていることが分かります。

image.png
image.png
image.png
image.png
image.png

2. アプリ開発

それでは、自身のデータセットで学習・推論させた最適なYOLOモデルを使用してアプリ開発を行っていきましょう。

ちなみに最適なモデルは学習終了時に作成される runs/detect/train{n}/weights/best.pt に保存されています。

学習、推論を行ったファイルの一つ上の階層に app.pyを作成します。

作成したコードは以下になります。(以下の記事を参考にさせていただきました。)

app.py
import tkinter as tk
from tkinter import font
import numpy as np
import matplotlib.pyplot as plt
import cv2
from ultralytics import YOLO
from PIL import Image, ImageTk
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import matplotlib.animation as animation

def get_class_name(all_classes, all_confidences):
    if len(all_classes) == 0:
        return 0  # 検出されない場合もawakeとする
    else:
        awake_index = np.array(np.where(all_classes == 0))
        sleep_index = np.array(np.where(all_classes == 1))

        if awake_index.shape[1] == 0 and sleep_index.shape[1] != 0:
            return 1
        elif awake_index.shape[1] != 0 and sleep_index.shape[1] == 0:
            return 0
        else:
            awake_probs = np.max(all_confidences[awake_index])
            sleep_probs = np.max(all_confidences[sleep_index])

            if awake_probs > sleep_probs:
                return 0
            else:
                return 1

class Application(tk.Frame):
    def __init__(self, master, video_source=0, model_weight=None):
        super().__init__(master)

        self.master.geometry("1500x750")
        self.master.title("Tkinter with Video Streaming and Capture")
        self.model = YOLO(model_weight)

        # ---------------------------------------------------------
        # ポイント設定
        # ---------------------------------------------------------
        self.camera_x = 10
        self.camera_y = 10
        self.camera_w = 640
        self.camera_h = 480
        self.camera_padx = 10
        self.camera_pady = 10

        self.graph_x = self.camera_x + self.camera_w + self.camera_padx
        self.graph_y = self.camera_y
        self.graph_w = 640
        self.graph_h = 480
        self.graph_padx = 10
        self.graph_pady = 10

        # ---------------------------------------------------------
        # グラフの設定
        # ---------------------------------------------------------
        self.class_array = np.array([])
        self.mean_rate_array = np.array([])
        self.window_time = 100
        self.next_time = 3
        self.processing_time = 40
        self.delay = self.next_time + self.processing_time  # [mili seconds]
        self.reset_threshold = int(self.window_time / (self.delay / 1000))
        self.xrange = np.arange(0, self.window_time, self.delay / 1000)
        self.warn_threshold = 0.5
        self.sleep_threshold = 0.9
        self.wait_time = 10

        # ---------------------------------------------------------
        # フォント設定
        # ---------------------------------------------------------
        self.font_frame = font.Font(family="Meiryo UI", size=15, weight="normal")
        self.font_btn_big = font.Font(family="Meiryo UI", size=20, weight="bold")
        self.font_btn_small = font.Font(family="Meiryo UI", size=15, weight="bold")

        self.font_lbl_bigger = font.Font(family="Meiryo UI", size=45, weight="bold")
        self.font_lbl_big = font.Font(family="Meiryo UI", size=30, weight="bold")
        self.font_lbl_middle = font.Font(family="Meiryo UI", size=15, weight="bold")
        self.font_lbl_small = font.Font(family="Meiryo UI", size=12, weight="normal")

        # ---------------------------------------------------------
        # ビデオソースのオープン
        # ---------------------------------------------------------

        self.vcap = cv2.VideoCapture(video_source)
        self.width = self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
        self.height = self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)

        # ---------------------------------------------------------
        # ウィジェットの作成
        # ---------------------------------------------------------

        self.create_widgets()


        # ---------------------------------------------------------
        # 状態の更新フラグ
        # ---------------------------------------------------------
        self.is_updated = False

        self.update()

    def create_widgets(self):
        # Frame_Camera
        self.frame_cam = tk.LabelFrame(self.master, text='Camera', font=self.font_frame)
        self.frame_cam.place(x=self.camera_x, y=self.camera_y, width=self.camera_w, height=self.camera_h)
        self.frame_cam.grid_propagate(0)

        # 画像用Canvas
        self.canvas1 = tk.Canvas(self.frame_cam, width=self.camera_w, height=self.camera_h)
        self.canvas1.grid(column=0, row=0, padx=10, pady=10)

        # Graph
        self.frame_graph = tk.LabelFrame(self.master, text='Graph', font=self.font_frame)
        self.frame_graph.place(x=self.graph_x, y=self.graph_y, width=self.graph_w, height=self.graph_h)
        self.frame_graph.grid_propagate(0)

        self.fig = plt.Figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.axhline(y=self.warn_threshold, color='orange', linestyle='--')
        self.ax.axhline(y=self.sleep_threshold, color='red', linestyle='--')
        self.ax.set_xlabel('Time [s]')
        self.ax.set_ylabel('Sleep Rate')
        self.ax.text(0, self.warn_threshold, 'warn', color='orange')
        self.ax.text(0, self.sleep_threshold, 'sleep', color='red')
        self.ax.set_xlim(0, self.window_time)
        self.ax.set_ylim(0, 1)

        self.canvas2 = FigureCanvasTkAgg(self.fig, master=self.frame_graph)
        self.canvas2.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Control
        self.control = tk.LabelFrame(self.master, text='Control', font=self.font_frame)
        self.control.place(x=10, y=550, width=self.camera_w + self.graph_padx + self.graph_w + self.graph_padx + 150, height=120)
        self.control.grid_propagate(0)

        # Window Time
        self.lbl_window_time = tk.Label(self.control, text="Window Time", font=self.font_lbl_small)
        self.lbl_window_time.grid(column=0, row=0, padx=10, pady=10)
        self.entry_window_time = tk.Entry(self.control, font=self.font_lbl_small)
        self.entry_window_time.grid(column=1, row=0, padx=10, pady=10)

        # Warn Threshold
        self.lbl_warn_threshold = tk.Label(self.control, text="Warn Threshold", font=self.font_lbl_small)
        self.lbl_warn_threshold.grid(column=2, row=0, padx=10, pady=10)
        self.entry_warn_threshold = tk.Entry(self.control, font=self.font_lbl_small)
        self.entry_warn_threshold.grid(column=3, row=0, padx=10, pady=10)

        # Sleep Threshold
        self.lbl_sleep_threshold = tk.Label(self.control, text="Sleep Threshold", font=self.font_lbl_small)
        self.lbl_sleep_threshold.grid(column=4, row=0, padx=10, pady=10)
        self.entry_sleep_threshold = tk.Entry(self.control, font=self.font_lbl_small)
        self.entry_sleep_threshold.grid(column=5, row=0, padx=10, pady=10)

        # Update Button
        self.btn_update = tk.Button(self.control, text='Update', font=self.font_btn_big, command=self.update_settings)
        self.btn_update.grid(column=6, row=0, padx=20, pady=10)

        # Close Button
        self.btn_close = tk.Button(self.control, text='Close', font=self.font_btn_big, command=self.press_close_button)
        self.btn_close.grid(column=7, row=0, padx=20, pady=10)

    def update_settings(self):
        try:
            self.window_time = int(self.entry_window_time.get())
            self.warn_threshold = float(self.entry_warn_threshold.get())
            self.sleep_threshold = float(self.entry_sleep_threshold.get())
            
            self.reset_threshold = int(self.window_time / (self.delay / 1000))
            self.xrange = np.arange(0, self.window_time, self.delay / 1000)

            self.mean_rate_array = np.array([])
            self.class_array = np.array([])
            self.xrange = np.arange(0, self.window_time, self.delay / 1000)

            self.x = self.xrange[:len(self.mean_rate_array)]
            self.y = self.mean_rate_array

            self.ax.clear()
            self.ax.axhline(y=self.warn_threshold, color='orange', linestyle='--')
            self.ax.axhline(y=self.sleep_threshold, color='red', linestyle='--')

            self.ax.set_xlabel('Time [s]')
            self.ax.set_ylabel('Sleep Rate')

            self.ax.text(0, self.warn_threshold, 'warn', color='orange')
            self.ax.text(0, self.sleep_threshold, 'sleep', color='red')

            self.ax.set_xlim(0, self.window_time)
            self.ax.set_ylim(0, 1)

            self.canvas2.draw()
        except ValueError:
            print("Invalid input for one of the settings")

    def update(self):
        # ビデオソースからフレームを取得
        _, frame = self.vcap.read()

        results = self.model(frame, show=False, conf=0.65, iou=0.5, device='cuda:0')

        all_classes = results[0].boxes.cls.cpu().numpy()
        all_confidences = results[0].boxes.conf.cpu().numpy()

        detected_class = get_class_name(all_classes, all_confidences)  # 0:awake, 1:sleep

        self.class_array  = np.append(self.class_array, detected_class)

        self.reset_threshold = int(self.window_time / (self.delay / 1000))

        if len(self.class_array) > self.reset_threshold:
            self.class_array = self.class_array[1:]

        self.wait_threshold = int(self.wait_time / (self.delay / 1000))
        if len(self.class_array) > self.wait_threshold:
            mean_rate = np.sum(self.class_array[self.class_array != None]) / len(self.class_array)
        else:
            mean_rate = 0  # 初期のの推論結果は変動幅が大きいため0とする

        self.mean_rate_array = np.append(self.mean_rate_array, mean_rate)

        if len(self.mean_rate_array) > self.reset_threshold:
            self.mean_rate_array = self.mean_rate_array[1:]

        self.x = self.xrange[:len(self.mean_rate_array)]
        self.y = self.mean_rate_array

        # グラフの更新
        self.ax.clear()

        # グラフに閾値を表示
        self.ax.axhline(y=self.warn_threshold, color='orange', linestyle='--')
        self.ax.axhline(y=self.sleep_threshold, color='red', linestyle='--')

        # ラベルの設定
        self.ax.set_xlabel('Time [s]')
        self.ax.set_ylabel('Sleep Rate')

        self.ax.text(0, self.warn_threshold, 'warn', color='orange')
        self.ax.text(0, self.sleep_threshold, 'sleep', color='red')

        # 閾値を超えたら色を変える
        if self.y[-1] > self.sleep_threshold:
            self.ax.plot(self.x, self.y, color='red')
        elif self.y[-1] > self.warn_threshold:
            self.ax.plot(self.x, self.y, color='orange')
        else:
            self.ax.plot(self.x, self.y, color='blue')
        self.canvas2.draw()
        
        # 画像の更新
        annotated_img = results[0].plot()
        annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
        self.photo = ImageTk.PhotoImage(image=Image.fromarray(annotated_img))

        # self.photo -> Canvas
        self.canvas1.create_image(0, 0, image=self.photo, anchor=tk.NW)

        self.master.after(self.next_time, self.update)

    def press_close_button(self):
        self.master.destroy()

def main():
    pretrained_weight_path = "yolov8_sleep_recognition/runs/detect/train10/weights/best.pt"

    root = tk.Tk()
    app = Application(master=root, video_source=0, model_weight=pretrained_weight_path)
    app.mainloop()

if __name__ == "__main__":
    main()

おおまかな処理の流れは以下になります。

  1. opencvから取得したフレームをYOLOに渡し、推論結果を取得(ここでは信頼度、バウンディボックス付きの画像)。
  2. 各対象の信頼度を比較し、最大値を持つ対象のラベルを取得。
  3. 任意で設定した window_time 内にあるラベル郡の平均値を取る。
  4. 取得したラベル郡の平均値を睡眠状態の比率として取得し、現在から最大で window_time前までをグラフに描画する。
  5. 2で取得したバウンディボックス付きの画像を表示。
  6. 1~5を繰り返す。

他にも、window_timeや各閾値の更新処理なども実装しています。

このようにして、設定した閾値を超えた時点で対象が睡眠状態であると判別できるようになりました。

まとめ

ここまで読んでくださり、ありがとうございます。

これまでの開発記事は自然言語処理方面に特化していて、画像処理方面は手薄だったのでいい経験になりました。

実は開発当初はYOLOを使用する予定はありませんでした。計画段階では画像認識ではなく、あくまで 動画分類 のタスクから睡眠を検出しようとしていました。mediapipeから各関節点を取得しLSTMで学習させたところ精度が向上しませんでした。プライベートの予定も忙しかったので、急遽同様の動画データを使用できる方法を探したところ、高速かつ精度が高いYOLOを使用することに決定しました。

しかし、YOLOを実行したことはあっても実際に学習させて実環境に応用したことが無かったので、とても充実した時間でした。

本来SwitcbotのAPIが利用できるボタンを押すボットを使用して、自動消灯まで実装を考えたのですが、どうしてかWifiに繋がらず自動消灯まではたどり着けませんでした。どうやらSesameという企業も同様のデバイスを販売しているとのことなので、そちらも試してみたいと思います。

(断念したLSTM+mediapipeのリポジトリは以下にありますので、興味がある方は見てみてください…)

参考文献

【物体検出手法の歴史 : YOLOの紹介】
Ultralytics YOLOv8 モード
YOLO v8/YOLO v9で物体検出|独自(カスタム)データの学習と推論を実践
パソコンに接続したカメラの映像をGUIに表示する。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0