LoginSignup
0
0

美味しいバナナ判別ウェブアプリ作った話

Last updated at Posted at 2023-12-25

はじめに

大学の夏休み中に開催されたハッカソンに参加してきました。今回のハッカソンで開発した作品は、バナナの熟成度を判定し、完熟になるまでのおおよその日数を予測し提示するウェブアプリケーションです。

この作品では、Yolov5を使用してバナナの熟成度を判定する技術を採用しました。しかし、この技術をウェブ上で実行するための具体的なガイドやリソースが不足しており、比較的多くの人から実装方法について質問を受けたことから、この記事を作成した。

この記事では、Yolov5を使用してバナナの熟成度を判定し、それをウェブアプリケーションに統合する方法について詳しく説明します。実際のコード例やステップバイステップのガイドを提供し、興味を持つ人々が同様のプロジェクトを実施する際に役立つ情報を共有する。

仮想環境の構築、Yolov5の導入,学習済みデータの作成の仕方

仮想環境の構築とYolov5の導入,学習済みデータの作成の仕方について、以下のQiitaの記事を参考にすることをお勧めします。

この記事は、仮想環境のセットアップからYolov5のインストールまでのステップを詳しく解説しています。

ラベル付けの注意点

私のチームは、"ripe"、"half ripe"、"unripe" の3つのクラスに画像をラベル付けする作業を行いました。画像の数が多かったため、2人のメンバーが協力して、それぞれのパソコンで学習を行いました。学習したテキストデータは、以下のように記述されています。

テキストデータの例:

0 0.708833 0.540125 0.225667 0.577250

テキストデータの先頭にはクラスが割り当てられています。一つのパソコンで "ripe"、"half ripe"、"unripe" のラベル付けを行うと、それぞれのクラスに 0、1、2 のラベルが割り当てられます。このため、二つのパソコンでlabelimgを行うとクラス "0" が重複し、"ripe" と "unripe" が "ripe" として誤判定される問題が発生しました。

この時のクラス:

unripe = "0",ripe = "0", half ripe = "1"

私はこの問題に対処するために、重複したクラスのどちらかのテキストデータのクラスを "2" にしました。これにより、 "unripe" も検出することができました。

Streamlitでweb上で動かす

インポート

最初に必要なライブラリ等をインポートする。今回、streamlitを用いてweb上で動かすためにstreamlitが必須である。また、torchはPythonの機械学習用フレームワークです。通常PyTorchと呼ばれている。

import streamlit as st
import torch
import cv2
import numpy as np
from PIL import Image

YOLOv5モデルの読み込み

"torch.hub.load"は、PyTorchのモデルハブから事前にトレーニングされたモデルをダウンロードして読み込むための便利な関数である。今回は学習済みデータを作成したので学習モデルまでのパスをコード上の"path"に設定することで、モデルの事前トレーニング済みの重みや設定をダウンロードし、簡単にモデルを利用できる。
パスの指定で注意することはパスをc:...ではなくc:\...にすることである。理由はパスを C:... のように単一のバックスラッシュで指定すると、エスケープ文字として解釈され、予期せぬ動作が発生する可能性があるため、バックスラッシュを2つ重ねて C:\... のように指定することが必要である。

# 例: path='C:\\...\\yolov\\yolov5\\runs\\train\\exp2\\weights\\best.pt'
model = torch.hub.load("yolov5", 'custom', path='Path to the training data',source='local')

streamlitの画面レイアウトの設定

streamlitの基礎は以下のサイトを参照してください。

外付けカメラを使う際の注意点

外付けカメラを使用する場合は () 内の数字を変更する必要がある。カメラが内カメラのみの (既存の環境) ノートパソコンの場合は (0) で内カメラ、 (1) で外付けカメラとなる。

cap = cv2.videoCapture()

学習済みデータを用いた処理

BGR形式からRGB形式に変換

OpenCVは、画像をBGR形式で扱うが、ほとんどの画像処理ライブラリやモデルはRGB形式を前提としている。そのため、OpenCVで読み込んだ画像をそのまま他のライブラリやモデルに渡すと、色情報が正しく解釈されず、結果も正しくなくなる。今回のバナナ処理においては、バナナの熟度を判定するためのモデルを使用しているが、バナナの熟度は色によって大きく変わるため、色情報が正しくないと、モデルの判定結果も正しくなくなる可能性が高くなる。
よって、 OpenCV のピクセル形式は BGR 形式であるため、読み込んだ画像をBGR形式からRGB形式に変換する必要がある。

# BGR画像をRGBに変換(Convert BGR image to RGB)
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

なぜ OpenCV のピクセル形式は BGR 形式なのか?

OpenCVの初期の開発者がBGRカラーフォーマットを選択した理由は、当時BGRカラーフォーマットがカメラメーカーやソフトウェアプロバイダーの間で人気があったのが理由とのこと。

検出確信度とクラスIDをもとに処理を行う

"results.pred" はモデルが予測した結果全体を格納する配列である。
"detection[0]"から"detection[3]"までは、検出された物体のバウンディングボックス(物体を囲む矩形)の座標を表している。これらは通常、座標(x1, y1, x2, y2)を示している。
"detection[4]"は、検出の信頼度スコアを表していて、スコアは0から1の間であり、1に近いほどモデルがその検出を信頼していることを示している。このコードでは、信頼度スコアが0.6以上の検出のみを考慮する。
"detection[5]"は、検出された物体のクラスIDを表している。このコードでは、クラスIDが0、1、2の場合、それぞれ"半熟のバナナ"、“完熟のバナナ”、"未熟のバナナ"とラベル付けされており、それ以外のクラスIDの場合は、"unknown"とラベル付けされる。

# バナナの検出確信度が0.6以上の場合のみ処理を行う
for detection in results.pred[0]:
    if detection[4] >= 0.6:
        # クラスIDと検出確信度を取得
        class_id = int(detection[5])
        #confidence = detection[5]

        # 検出されたクラスを出力
        if class_id == 0:
            detected_class = "半熟のバナナ"
        elif class_id == 1:
            detected_class = "完熟のバナナ"
        elif class_id == 2:
            detected_class = "未熟のバナナ"
        else:
            detected_class = "unknown"

        st.markdown(f"<h2>このバナナは {detected_class} です。</h2>", unsafe_allow_html=True)

Yolov5を触って思ったこと

利点

・初心者でもモデルを作成できる
・静止画は正確性が高い

欠点

・リアルタイム処理の誤検出が多い
・モデル作成に時間がかかる(デスクトップパソコンで一日かかりました)

ハッカソンの振り返りと感想

参考資料がまだ少ない最新の技術に挑戦したので、バックエンドを一人で取り組んだのは大変だった。結果は受賞はできなかったが、高評価を貰えた。でも、やっぱり悔しい!!

全体のソースコード

main8202.py
# 2023/08/02 14:00 Tokol1
# yolov5と同じフォルダにmain8202.pyを置く


import streamlit as st
import torch
import cv2
import numpy as np
from PIL import Image


# YOLOv5モデルの読み込み(Load My learned YOLOv5 model)
# 例: path='C:\\~\\yolov\\yolov5\\runs\\train\\exp2\\weights\\best.pt'
model = torch.hub.load("yolov5", 'custom', path='Path to the training data',source='local')



# Streamlitアプリ名(Streamlit app name)
st.title("   Ripenana  ")

# ウェブカメラを起動するためのスタートボタン
start_button1 = st.button("スマホカメラ")

start_button2 = st.button("ウェブカメラ")

# Upload image through Streamlit
uploaded_image = st.file_uploader("画像をアップロード↓", type=["jpg", "jpeg", "png"])

end_button = st.button("終了")

if end_button:
    st.write("終了します")
    st.stop()


# ウェブカメラを開く(Open webcam)
# cap = cv2.VideoCapture(0)の場合は、PCに接続されたカメラを使用します。
# If cap = cv2.VideoCapture(0), use the camera connected to the PC.
# cap = cv2.VideoCapture(1)の場合は、PCに接続されたカメラのうち、2番目に接続されたカメラを使用します。
# If cap = cv2.VideoCapture(1), use the second camera connected to the PC.
if start_button1:
    judge = 0

    cap = cv2.VideoCapture(0)

    # 画像表示のための最大幅を設定(Set maximum width for displaying images)
    max_width = 800

    # 画像表示用のプレースホルダーを作成(Create a placeholder for the image)
    placeholder = st.empty()

    while cap.isOpened():
        ret, frame = cap.read()

        if not ret:
         break

        # BGR画像をRGBに変換(Convert BGR image to RGB)
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # 推論を実行(Perform inference)
        image = Image.fromarray(frame_rgb)
        results = model(image)

        # 検出結果を表示(Display detection results)
        # 必要なし(Not needed)
        #st.write("### Detection Results:")
        #st.dataframe(results.pandas().xyxy[0])

        # フレームに境界ボックスを描画(Draw bounding boxes on the frame)
        annotated_frame = np.array(results.render()[0])

        # 境界ボックス付きのリアルタイム画像を表示(Display image with bounding boxes)
        placeholder.image(annotated_frame, caption="Detection Result", use_column_width=True, channels="RGB", output_format="auto")

        # バナナの検出確信度が0.6以上の場合のみ処理を行う
        for detection in results.pred[0]:
            if detection[4] >= 0.6:
                # クラスIDと検出確信度を取得
                class_id = int(detection[5])
                #confidence = detection[5]

                # 検出されたクラスを出力
                if class_id == 0:
                    detected_class = "半熟のバナナ"
                elif class_id == 1:
                    detected_class = "完熟のバナナ"
                elif class_id == 2:
                    detected_class = "未熟のバナナ"
                else:
                    detected_class = "unknown"

                st.markdown(f"<h2>このバナナは {detected_class} です。</h2>", unsafe_allow_html=True)

                if class_id == 0:
                    st.markdown(" このバナナは常温で3~7日後に完熟バナナになります。3日後にもう一度確認しよう!! ")
                elif class_id == 1:
                    st.markdown(" このバナナは食べごろです。今食べるべし!! ") 
                elif class_id == 2:
                    st.markdown(" このバナナは常温で4~7日後に半熟バナナになります。4日後もう一度確認しよう!! ") 

                # 折りたたみ要素
                with st.expander('検出信頼度を表示'):
                    st.subheader('検出信頼度')
                    st.write(f"検出信頼度は {detection[4]:.2f}", unsafe_allow_html=True)

                # 信頼度が0.6以上のクラスが検出されたら、ループを抜ける
                judge = 1

        if judge == 1:
            break
    
    # キャプチャを解放(Release the capture)
    cap.release()
    cv2.destroyAllWindows()


if start_button2:
    judge = 0

    cap = cv2.VideoCapture(1)

    # 画像表示のための最大幅を設定(Set maximum width for displaying images)
    max_width = 800

    # 画像表示用のプレースホルダーを作成(Create a placeholder for the image)
    placeholder = st.empty()

    while cap.isOpened():
        ret, frame = cap.read()

        if not ret:
         break

        # BGR画像をRGBに変換(Convert BGR image to RGB)
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # 推論を実行(Perform inference)
        image = Image.fromarray(frame_rgb)
        results = model(image)

        # 検出結果を表示(Display detection results)
        # 必要なし(Not needed)
        #st.write("### Detection Results:")
        #st.dataframe(results.pandas().xyxy[0])

        # フレームに境界ボックスを描画(Draw bounding boxes on the frame)
        annotated_frame = np.array(results.render()[0])

        # 境界ボックス付きのリアルタイム画像を表示(Display image with bounding boxes)
        placeholder.image(annotated_frame, caption="Detection Result", use_column_width=True, channels="RGB", output_format="auto")

        # バナナの検出確信度が0.6以上の場合のみ処理を行う
        for detection in results.pred[0]:
            if detection[4] >= 0.6:
                # クラスIDと検出確信度を取得
                class_id = int(detection[5])
                #confidence = detection[5]

                # 検出されたクラスを出力
                if class_id == 0:
                    detected_class = "半熟のバナナ"
                elif class_id == 1:
                    detected_class = "完熟のバナナ"
                elif class_id == 2:
                    detected_class = "未熟のバナナ"
                else:
                    detected_class = "unknown"

                st.markdown(f"<h2>このバナナは {detected_class} です。</h2>", unsafe_allow_html=True)

                if class_id == 0:
                    st.markdown(" このバナナは常温で3~7日後に完熟バナナになります。3日後にもう一度確認しよう!! ")
                elif class_id == 1:
                    st.markdown(" このバナナは食べごろです。今食べるべし!! ") 
                elif class_id == 2:
                    st.markdown(" このバナナは常温で4~7日後に半熟バナナになります。4日後もう一度確認しよう!! ") 

                # 折りたたみ要素
                with st.expander('検出信頼度を表示'):
                    st.subheader('検出信頼度')
                    st.write(f"検出信頼度は {detection[4]:.2f}", unsafe_allow_html=True)

                # 信頼度が0.6以上のクラスが検出されたら、ループを抜ける
                judge = 1

        if judge == 1:
            break
    
    # キャプチャを解放(Release the capture)
    cap.release()
    cv2.destroyAllWindows()

if uploaded_image is not None:
    # Display the uploaded image
    # st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

    # Perform inference if the user uploads an image
    image = Image.open(uploaded_image)
    results = model(image)

    # Display detection results
    # st.write("### Detection Results:")
    # st.dataframe(results.pandas().xyxy[0])

    # Display image with bounding boxes
    st.image(results.render()[0], caption="Detection Result", use_column_width=True)

    # バナナの検出確信度が0.6以上の場合のみ処理を行う
    for detection in results.pred[0]:
        if detection[4] >= 0.3:
            # クラスIDと検出確信度を取得
            class_id = int(detection[5])
            #confidence = detection[5]

            # 検出されたクラスを出力
            if class_id == 0:
                detected_class = "半熟のバナナ"
            elif class_id == 1:
                detected_class = "完熟のバナナ"
            elif class_id == 2:
                detected_class = "未熟のバナナ"
            else:
                detected_class = "unknown"

            st.markdown(f"<h2>このバナナは {detected_class} です。</h2>", unsafe_allow_html=True)

            if class_id == 0:
                st.markdown(" このバナナは常温で3~7日後に完熟バナナになります。3日後にもう一度確認しよう!! ")
            elif class_id == 1:
                st.markdown(" このバナナは食べごろです。 今食べるべし!! ") 
            elif class_id == 2:
                st.markdown(" このバナナは常温で4~7日後に半熟バナナになります。4日後もう一度確認しよう!! ") 

            # 折りたたみ要素
            with st.expander('検出信頼度を表示'):
                st.subheader('検出信頼度')
                st.write(f"検出信頼度は {detection[4]:.2f}", unsafe_allow_html=True)
            
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0