1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

植物の写真を認識するAIアプリケーション

Last updated at Posted at 2024-09-24

はじめに

1.プログラミングスクールで学んだディープラーニング(Kerasを使用)を応用して画像処理を習得しました。
2.この技術を活かして、植物の写真を認識し、どの植物か判定するAIアプリケーションを作成しました。
3.学習内容の総まとめとしてQiitaの記事にしました。

解決したい課題

1.ユーザーの学習支援
植物の特徴や育て方、病害虫対策などを提供し、ユーザーの植物知識を向上させます。
2.パーソナライズされたアドバイス
ユーザーの特定のニーズに基づいて、植物の育成方法や手入れに関するアドバイスを提供します。

主な機能

  1. 植物識別
    ユーザーがアップロードした植物の写真を解析し、その種類を特定します。

今後さらに追加したい機能
2. 情報提供
学名、和名、生態、育成方法、適切な環境条件などの基本情報を提供します。
3. 病気診断
植物の病気や害虫を識別し、適切な対策や治療法を提案します。

データセット

Pl@ntNet-300K
このデータセットは、1081種類(クラス)をカバーし、306,146枚の植物画像が含まれています。クラスの曖昧さが高く、クラスの不均衡が特徴です。

実行環境

パソコン:MacBook Air
開発環境:Cursor
言語:Python
ライブラリ:TensorFlow/Keras Pillow (PIL) Django JSON Numpy

開発の流れ

1.アプリのディレクトリ構成
2.モデルの訓練、評価、予測
3.仮装環境構築と各ファイルの設定、ブラウザの動作確認

以下の手順で画像の分類をするコードを書きました。

アプリのディレクトリ構成

plantnetvgg16_2/                 # conda 仮想環境ディレクトリ

└── plantnet_clone_1/            # Django プロジェクトディレクトリ
    ├── manage.py                # Djangoプロジェクト管理スクリプト
    ├── plantnet_clone_1/        # プロジェクト設定フォルダ
       ├── __init__.py
       ├── settings.py          # プロジェクトの設定ファイル
       ├── urls.py              # プロジェクト全体のURL設定
       ├── asgi.py              # ASGI設定ファイル
       └── wsgi.py              # WSGI設定ファイル
    
    ├── identify/                # アプリディレクトリ
       ├── __init__.py
       ├── admin.py             # 管理画面設定
       ├── apps.py              # アプリ設定ファイル
       ├── models.py            # モデル定義ファイル
       ├── views.py             # ビュー(処理)定義ファイル
       ├── urls.py              # アプリのURL設定(必要に応じて作成)
       ├── migrations/          # マイグレーションファイル
          └── __init__.py
       └── templates/           # テンプレート(HTML)ディレクトリ
           └── identify/        # アプリ専用テンプレート(任意)
               └── result.html
               └── upload.html
    
    └── db.sqlite3               # デフォルトのSQLiteデータベースファイル(生成される)

実行したコード

モデル訓練用コード(データの読み込み、画像前処理、モデル保存)その後新たなモデルができる。

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import json
import os
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import mixed_precision

# Mixed precision policy
mixed_precision.set_global_policy('mixed_float16')

# MobileNetV2 のインポート
from tensorflow.keras.applications import MobileNetV2

# GPU メモリの設定
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

# データパスの設定
train_data_dir = '/Users/mame/Downloads/plantnet_300K/images/train'
validation_data_dir = '/Users/mame/Downloads/plantnet_300K/images/validation'

# データジェネレーターの設定
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

val_datagen = ImageDataGenerator(rescale=1./255)

# データの読み込み
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode='categorical',
    shuffle=True
)

validation_generator = val_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode='categorical'
)

# クラス数の設定
num_classes = len(train_generator.class_indices)

# MobileNetV2ベースモデルの作成
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# 新しいトップ層の追加
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)  # 中間層
predictions = Dense(num_classes, activation='softmax')(x)  # 出力層

# モデルの定義
model = Model(inputs=base_model.input, outputs=predictions)

# ベースモデルの重みを一部固定
for layer in base_model.layers[:-30]:  # 下位の層は凍結
    layer.trainable = False

# モデルのコンパイル
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# コールバック設定
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)

# steps_per_epoch と validation_steps を設定
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = validation_generator.samples // validation_generator.batch_size

# モデルのトレーニング
history = model.fit(
    train_generator,  # 修正: tf.data.Dataset.from_generator を削除
    steps_per_epoch=steps_per_epoch,
    epochs=10,
    validation_data=validation_generator,  # 修正: tf.data.Dataset.from_generator を削除
    validation_steps=validation_steps,
    callbacks=[early_stopping, reduce_lr]
)

# モデルとクラスインデックスの保存
os.makedirs('data', exist_ok=True)
model.save('data/MobileNetV2_model.keras')

with open('data/class_indices.json', 'w') as f:
    json.dump(train_generator.class_indices, f)

Figure_1.png

モデル評価

import tensorflow as tf
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# モデルの読み込み
model = tf.keras.models.load_model('/Users/mame/plantnetvgg16_2/plantnet_clone_1/data/MobileNetV2_model.keras')

# テストデータのディレクトリ
test_image_dir = '/Users/mame/Downloads/plantnet_300K/images/test'

# 画像データの前処理
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# バッチサイズを設定してデータを生成
test_generator = test_datagen.flow_from_directory(
    test_image_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',  # クラスラベルがone-hotエンコーディングされている場合
    shuffle=False
)

# モデルを評価 (損失値と精度を返す)
loss, accuracy = model.evaluate(test_generator)

# 評価結果を表示
print(f"Test Loss: {loss}")
print(f"Test Accuracy: {accuracy}")

Test Loss: 1.8676471710205078
Test Accuracy: 0.6171895265579224
Figure_2.png

仮装環境の構築

conda create -n plantnetvgg16_2 python=3.11 #仮想環境の作成
conda activate plantnetvgg16_2 #作成した環境を有効化します
pip install django tensorflow pillow scipy #必要なパッケージのインストール
django-admin startproject plantnet_clone_1
cd plantnet_clone_1 #Django プロジェクトの作成
python manage.py startapp identify #Django アプリ「identify」の作成

settings.pyの設定

INSTALLED_APPS = [
    # 既存のアプリ
    'identify',  # 追加するアプリ
]

モデルの定義

from django.db import models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import numpy as np

# MobileNetV2モデルをグローバルに準備してキャッシュする
model = MobileNetV2(weights='imagenet', include_top=False)



class PlantImage(models.Model):
    image = models.ImageField(upload_to='images/')
    uploaded_at = models.DateTimeField(auto_now_add=True)
    features = models.BinaryField(default=b'')  # 特徴量をバイナリとして保存

    def extract_features(self):
        # 画像の前処理
        img_tensor = image.load_img(self.image.path, target_size=(224, 224))
        img_array = image.img_to_array(img_tensor)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        # 特徴量の抽出
        features = model.predict(img_array)

        # 特徴量を1次元配列に変換
        features = features.flatten()
        if np.all(features == 0):
            print("Warning: All features are zero!")

        return features

class PlantFeature(models.Model):
    features = models.BinaryField()  # 特徴量をバイナリとして保存

    def save_features(self, features):
        # 特徴量をバイナリ形式で保存
        features_bin = features.tobytes()
        self.features = features_bin
        self.save()

    def get_features(self):
        # バイナリデータを numpy 配列に変換して返す
        features_array = np.frombuffer(self.features, dtype=np.float32)


        return features_array

Viewの設定

import tensorflow as tf
from django.conf import settings
from django.shortcuts import render
from django.http import JsonResponse
from .forms import UploadImageForm
from PIL import Image
import numpy as np
import json
import os

# Djangoプロジェクトのベースディレクトリを取得
base_dir = settings.BASE_DIR

# データフォルダへのパスを設定
data_dir = os.path.join(base_dir, 'data')

# クラスインデックスのパスを設定
class_indices_path = os.path.join(data_dir, 'class_indices.json')

# クラスインデックスのロード
with open(class_indices_path, 'r') as f:
    class_indices = json.load(f)

# クラスインデックスを反転させる
class_indices_reversed = {v: k for k, v in class_indices.items()}

# モデルをロード
model = tf.keras.models.load_model('data/MobileNetV2_model.keras')

def preprocess_image(image):
    """画像を前処理してモデル入力用に変換する。"""
    image = Image.open(image).resize((224, 224))
    image_array = np.array(image) / 255.0
    image_array = np.expand_dims(image_array, axis=0)
    return image_array

def identify_plant(request):
    # JSONファイルのパスを設定
    metadata_file_path = os.path.join(base_dir, 'data', 'plantnet300K_metadata.json')
    species_file_path = os.path.join(base_dir, 'data', 'plantnet300K_species_id_2_name.json')

    # JSONファイルの読み込み
    with open(metadata_file_path, 'r') as f:
        metadata = json.load(f)

    with open(species_file_path, 'r') as f:
        species_data = json.load(f)

    if request.method == 'POST':
        form = UploadImageForm(request.POST, request.FILES)
        if form.is_valid():
            image = form.cleaned_data['image']
            image_array = preprocess_image(image)
            predictions = model.predict(image_array)

            # 予測結果のログ
            predicted_index = np.argmax(predictions)
            print(f"Predicted index: {predicted_index}")

            # インデックスからクラスIDに変換
            predicted_class_id = class_indices_reversed.get(predicted_index, None)
            if predicted_class_id is None:
                print("Predicted class ID not found")
                return JsonResponse({"name": "エラーが発生しました", "description": "Class ID not found", "metadata": "情報が見つかりません"})

            print(f"Predicted class ID: {predicted_class_id}")

            # クラスIDから植物情報を取得
            try:
                # クラスIDからspecies_infoを取得
                species_info = species_data.get(predicted_class_id, None)
                print(f"Species Info: {species_info}")

                if species_info is None:
                    raise ValueError("Species info not found in species_data")

                # メタデータから該当するspecies_idを持つデータを検索
                plant_metadata = None
                for entry_key, entry_value in metadata.items():
                    if entry_value["species_id"] == predicted_class_id:
                        plant_metadata = entry_value
                        break

                if plant_metadata is None:
                    raise ValueError(f"Metadata not found for species_id: {predicted_class_id}")

                print(f"Plant Metadata: {plant_metadata}")

                predicted_class_name = species_info
                    
            except Exception as e:
                print(f"An error occurred: {e}")
                predicted_class_name = 'エラーが発生しました'
                plant_metadata = '情報が見つかりません'

            # すべてのメタデータ情報をJSONレスポンスに含める
            result = {
                "name": predicted_class_name,
                "description": f"This is a description of the predicted plant: {predicted_class_name}.",
                "metadata": plant_metadata  # すべてのメタデータ情報を含める
            }
            print("Result JSON:", result)
            return JsonResponse(result)

    else:
        form = UploadImageForm()

    return render(request, 'identify/upload.html', {'form': form})

Form(フォーム)の設定

from django import forms
from .models import PlantImage



class UploadImageForm(forms.Form):
    image = forms.ImageField()

Template(テンプレート)の設定

result.html

<!DOCTYPE html>
<html>
<head>
    <title>Plant Identification Result</title>
    <link rel="stylesheet" href="{% static 'css/styles.css' %}">
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            background-color: #f4f4f4;
            color: #333;
        }
        .container {
            max-width: 800px;
            margin: 20px auto;
            padding: 20px;
            background-color: white;
            border-radius: 8px;
            box-shadow: 0 4px 8px rgba(0,0,0,0.1);
        }
        h1 {
            color: #4CAF50;
            text-align: center;
        }
        .result-card {
            background-color: #ffffff;
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 20px;
            margin-top: 20px;
            text-align: center;
            box-shadow: 0 4px 8px rgba(0,0,0,0.1);
            word-wrap: break-word; /* 長いテキストを折り返す */
            white-space: pre-wrap; /* 空白や改行を保持 */
        }
        img {
            max-width: 100%;
            height: auto;
            border-radius: 8px;
            margin-top: 10px;
        }
        .no-image {
            color: #FF6347;
            font-weight: bold;
            text-align: center;
            margin-top: 20px;
        }
        .metadata {
            word-wrap: break-word; /* 長いテキストを折り返す */
            white-space: pre-wrap; /* 空白や改行を保持 */
            text-align: left;
            background-color: #f9f9f9;
            padding: 10px;
            border-radius: 8px;
            margin-top: 10px;
            border: 1px solid #ddd;
            overflow-wrap: break-word; /* 長い単語を折り返す */           
            max-width: 100%; /* メタデータの最大幅を100%に制限 */
            box-sizing: border-box; /* パディングやボーダーを含む幅を制限 */
        }


        .btn {
            display: inline-block;
            padding: 10px 20px;
            margin-top: 10px;
            background-color: #4CAF50;
            color: white;
            text-decoration: none;
            border-radius: 4px;
            font-weight: bold;
        }
        .btn:hover {
            background-color: #45a049;
        }
        @media (max-width: 600px) {
            .container {
                padding: 10px;
            }
            .btn {
                padding: 8px 16px;
                font-size: 14px;
            }
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>予測されたクラス: {{ predicted_class }}</h1>
        <div class="result-card">
            <p>予測結果: {{ name }}</p>
            <p>説明: {{ description }}</p>
            <div class="metadata">
                <strong>メタデータ:</strong><br>
                {{ metadata }}
            </div>
        </div>
    </div>
</body>
</html>

upload.html

<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Plant Identifier</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            background-color: #f4f4f4;
            color: #333;
        }
        h1 {
            color: #4CAF50;
            text-align: center;
            margin-top: 20px;
        }
        .container {
            max-width: 600px;
            margin: 0 auto;
            padding: 20px;
            background-color: white;
            border-radius: 8px;
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        }
        form {
            margin-top: 20px;
        }
        input[type="file"] {
            margin-bottom: 20px;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            width: 100%;
        }
        button {
            background-color: #4CAF50;
            color: white;
            padding: 15px 20px;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-size: 16px;
            box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        }
        button:hover {
            background-color: #45a049;
        }
        #result {
            margin-top: 20px;
            padding: 15px;
            background-color: #e7f9e7;
            border-radius: 8px;
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
            display: none;
        }
        #loading {
            display: none;
            text-align: center;
        }
        @media (max-width: 600px) {
            .container {
                padding: 10px;
            }
            button {
                padding: 12px 15px;
                font-size: 14px;
            }
        }
    </style>
    <script>
        function showLoading() {
            document.getElementById('loading').style.display = 'block';
        }

        function hideLoading() {
            document.getElementById('loading').style.display = 'none';
        }

        function displayResult(data) {
            hideLoading();
            document.getElementById('result').style.display = 'block';
            document.getElementById('result').innerHTML = `
                <h2>Prediction Result</h2>
                <p><strong>Name:</strong> ${data.name}</p>
                <p><strong>Description:</strong> ${data.description}</p>
                <p><strong>Metadata:</strong> ${JSON.stringify(data.metadata)}</p>
            `;
        }

        async function handleSubmit(event) {
            event.preventDefault();
            showLoading();
            
            const formData = new FormData(event.target);
            const response = await fetch("/", {
                method: "POST",
                body: formData
            });
            const result = await response.json();
            displayResult(result);
        }
    </script>
</head>
<body>
    <div class="container">
        <h1>Upload Plant Image</h1>
        <form method="POST" enctype="multipart/form-data" onsubmit="handleSubmit(event)">
            {% csrf_token %}
            <input type="file" name="image" accept="image/*" required>
            <br>
            <button type="submit">Identify Plant</button>
        </form>
        
        <div id="loading">
            <p>Loading...</p>
        </div>

        <div id="result"></div>
    </div>
</body>
</html>

URL設定

INSTALLED_APPS = [
    # 既存のアプリ
    'identify',  # 追加するアプリ
]

ブラウザの動作確認

python manage.py runserver

截屏2024-09-26 23.02.55.png

認識結果

截屏2024-09-26 23.05.44.png

課題

上記のように、テストデータに対する精度(約61.72%)が表示されています。モデルの性能としてはまだ改善の余地がありますが、訓練は成功しています。

植物画像認識モデルの評価結果とまとめ
精度について
テストデータに対する精度は約61.72%であり、モデルの性能としてはまだ改善の余地がありますが、訓練自体は成功しています。

まとめ
これまで学んだ技術を応用し、画像認識がどこまでできるかを試してみました。今回は植物の写真の認識に焦点を当て、いくつかのテストを実施しました。その結果、一般的によく見かける植物であれば正確に認識できることがわかりました。

特に、plantnet-300kのオープンソースから訓練したモデルの評価結果は以下の通りです。

Test Loss (損失): 0.411
Test Accuracy (精度): 87.47%
Top-k Accuracy:
Top-1 精度: 87.47%
Top-3 精度: 97.24%
Top-5 精度: 98.11%
Top-10 精度: 99.58%
自作モデルの精度はまだ低いものの、性能は改善の余地があり、今後の課題となります。

現段階では、学んだ範囲内で画像認識プログラムを作成しましたが、今後はUIデザインの改善や、植物の育て方に関する情報提供機能などを追加していきたいと考えています。

今後の展望
このプロジェクトに取り組んで楽しかったため、技術向上を目指し、今後も継続して勉強を続けていきたいと思います。

参考にした資料、文献

https://plantnet.org/en/
https://github.com/plantnet/PlantNet-300K.git

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?