6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

画像生成AIにとって一番嫌なノイズを探す -実装編-

Last updated at Posted at 2024-11-12

概要

生成AIが発展する一方で、権利者が望まない無断のAI学習が問題になってきています
この記事では、学習用画像にノイズが入っていた場合、画像生成AIにどのような影響があるのかを調査しました

※この記事群は2024/10末時点の調査に基づいています

目次

  1. 画像生成AIにとって一番嫌なノイズを探す -結論編-
  2. 画像生成AIにとって一番嫌なノイズを探す -準備編-
  3. 画像生成AIにとって一番嫌なノイズを探す -調査編(1)-
  4. 画像生成AIにとって一番嫌なノイズを探す -調査編(2)-
  5. 画像生成AIにとって一番嫌なノイズを探す -調査編(3)-
  6. 画像生成AIにとって一番嫌なノイズを探す -調査編(4)-
  7. 画像生成AIにとって一番嫌なノイズを探す -実装編-

Stable Diffusion Web UI の導入

Dockerfile の修正

  • 10/27時点で stable-diffusion-webui-docker が参照している stable-diffusion-webui が v1.9.4 だったので、v1.10.0 にあげる
    • Stable Diffusion 3 medium がサポートされたのは v1.10.0 以降
    • PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION の定義が不足しているとエラーになったので追加
ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
WORKDIR /
RUN --mount=type=cache,target=/root/.cache/pip \
  git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git && \
  cd stable-diffusion-webui && \
  # 現在の最新バージョンに変更
  git reset --hard v1.10.1 && \
  pip install -r requirements_versions.txt
RUN pip install -U typing_extensions
docker compose --profile download up --build
docker compose --profile auto up --build

kohya-ss/sd-scripts の導入

git clone https://github.com/kohya-ss/sd-scripts.git
sd-scripts>git checkout tags/v0.8.7
conda create -n lora pip python=3.10
conda activate lora
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
(lora) C:\MMD\sd-scripts>accelerate config
------------------------------------------------------------------------------------------------------------------------------------------------------In which compute environment are you running?
This machine
------------------------------------------------------------------------------------------------------------------------------------------------------Which type of machine are you using?
No distributed training
Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:NO
Do you wish to optimize your script with torch dynamo?[yes/NO]:NO
Do you want to use DeepSpeed? [yes/NO]: NO
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:all
------------------------------------------------------------------------------------------------------------------------------------------------------Do you wish to use FP16 or BF16 (mixed precision)?
fp16
accelerate configuration saved at C:\Users\celes/.cache\huggingface\accelerate\default_config.yaml

タグの編集

image.png

Additional tags: toya
Weight threshold: 0.7

image.png

学習させたいワードを削除(トリガーとして加えた「toya」だけ残す)
※後日「watermark,english text」も学習に加えた

LoRA学習用バッチ

複数のデータセットを一度に学習できるよう、ループ処理で学習を実行出来るようにバッチ作成

launch.bat

@echo off
cls

setlocal

set PRETRAINED_MODEL_NAME_OR_PATH=<stable-diffusion-webui-dockerへのフルパス>\data\models\Stable-diffusion\anyloraCheckpoint_bakedvaeBlessedFp16.safetensors
set OUTPUT_DIR=<stable-diffusion-webui-dockerへのフルパス>\data\models\LoRA
set DATASET_DIR=<stable-diffusion-webui-dockerへのフルパス>\data\dataset
rem Windowsではxformersはエラーとなるため、無効化しておく
set XFORMERS_FORCE_DISABLE_TRITON=1

rem データ名のリストを定義
set DATA_NAMES=case03_72_watermark_20_mix_70_overlay_margin_35 case03_72_watermark_20_mix_M_70_overlay_margin_35

for %%D in (%DATA_NAMES%) do (
    echo =====================================================================================
    echo train_data_dir: %DATASET_DIR%\%%D%
    echo output_dir: %OUTPUT_DIR%\%%D
    mkdir "%OUTPUT_DIR%\%%D"
    
    echo Running training for %%D...
    
    accelerate launch --config_file="..\default_config.yaml" ..\train_network.py ^
        --pretrained_model_name_or_path="%PRETRAINED_MODEL_NAME_OR_PATH%"  ^
        --train_data_dir="%DATASET_DIR%\%%D%" ^
        --caption_extension=".txt" ^
        --resolution=512 ^
        --enable_bucket ^
        --dataset_repeats=1 ^
        --output_dir="%OUTPUT_DIR%\%%D%" ^
        --output_name=%%D ^
        --save_every_n_epochs=2 ^
        --train_batch_size=1 ^
        --max_train_epochs=20 ^
        --prior_loss_weight=1.0  ^
        --save_model_as=safetensors ^
        --learning_rate=1e-4  ^
        --optimizer_type="Adafactor"  ^
        --xformers  ^
        --mixed_precision="fp16"  ^
        --save_precision="fp16"  ^
        --cache_latents  ^
        --persistent_data_loader_workers  ^
        --gradient_checkpointing ^
        --max_data_loader_n_workers=16  ^
        --network_module=networks.lora 

)

endlocal
echo All training processes completed.

学習結果生成用スクリプト

各epochの学習結果画像を生成するのも一度に出力出来るよう、pythonスクリプトを作成

save_sd_images.py

import os
import requests
import base64
import io
from PIL import Image

# 使用するLoRAのリスト
loras = [
    "case03_72_watermark_20_mix_70_overlay_margin_35",
    "case03_72_watermark_20_mix_M_70_overlay_margin_35",
]

# その他の設定
positive_prompt = "<ポジティブプロンプト>"
negative_prompt = "<ネガティブプロンプト>"
output_dir = "<出力先ディレクトリ>"
lora_dir = "<stable-diffusion-webui-dockerへのフルパス>/data/models/Lora"

# APIエンドポイント
api_url = "http://localhost:7860/sdapi/v1/txt2img"


def adjust_lora_epoch_name(lora_epoch_name):
    # 末尾が数字かどうかを判定
    if lora_epoch_name[-6:].isdigit():
        return lora_epoch_name  # 末尾が数字の場合はそのまま返す
    else:
        return f"{lora_epoch_name}-000020"  # 末尾が数字でない場合は -000020 を追加


# 各プロンプトに対して画像生成リクエストを送信
for idx, lora_name in enumerate(loras):
    for append_negative_prompt in [
        "",
        "watermark,english text,",
    ]:
        negative_dir_name = "1_normal" if not append_negative_prompt else "2_negative"

        # 出力ディレクトリが存在しない場合は作成
        os.makedirs(
            os.path.join(output_dir, lora_name, negative_dir_name, "parts"),
            exist_ok=True,
        )
        os.makedirs(
            os.path.join(output_dir, lora_name, negative_dir_name, "thumb"),
            exist_ok=True,
        )

        # フォルダ内のファイルを走査
        for root, dirs, files in os.walk(f"{lora_dir}/{lora_name}"):
            for file in files:
                # 指定の拡張子を持つファイルを抽出
                if file.endswith(".safetensors"):
                    # 拡張子を除いたファイル名を取得
                    lora_epoch_name = os.path.splitext(file)[0]

                    for n in range(3):
                        print(
                            f"Generating images for prompt: '{lora_epoch_name} - {negative_dir_name}[{n}]' ..."
                        )

                        # APIリクエスト用のペイロード
                        payload = {
                            "prompt": f"{positive_prompt}<lora:{lora_epoch_name}:1>",
                            "negative_prompt": append_negative_prompt + negative_prompt,
                            "styles": [],
                            "seed": -1,
                            "subseed": -1,
                            "subseed_strength": 0,
                            "seed_resize_from_h": -1,
                            "seed_resize_from_w": -1,
                            "sampler_name": "DPM++ 2M",
                            "scheduler": "Automatic",
                            "batch_size": 1,
                            "n_iter": 6,
                            "steps": 20,
                            "cfg_scale": 7,
                            "width": 512,
                            "height": 512,
                            "restore_faces": False,
                            "tiling": False,
                            "do_not_save_samples": False,
                            "do_not_save_grid": False,
                            "eta": 0,
                            "denoising_strength": 0,
                            "s_min_uncond": 0,
                            "s_churn": 0,
                            "s_tmax": 0,
                            "s_tmin": 0,
                            "s_noise": 0,
                            "override_settings": {
                                "sd_model_checkpoint": "anyloraCheckpoint_bakedvaeBlessedFp16"
                            },
                            "override_settings_restore_afterwards": True,
                            "sampler_index": "Euler",
                            "send_images": True,
                            "save_images": False,
                        }

                        # POSTリクエストを送信
                        response = requests.post(api_url, json=payload)

                        # レスポンスが正常か確認
                        if response.status_code == 200:
                            response_data = response.json()

                            # 各生成画像の保存と結合用リスト
                            images = []

                            for m, image_base64 in enumerate(response_data["images"]):
                                # 画像データをデコードしてPIL形式に変換
                                image_data = base64.b64decode(image_base64)
                                image = Image.open(io.BytesIO(image_data))
                                image_path = os.path.join(
                                    output_dir,
                                    lora_name,
                                    negative_dir_name,
                                    "parts",
                                    f"{lora_epoch_name}_{negative_dir_name}_parts_{m + 1}.png",
                                )
                                image.save(image_path)
                                # print(f"Image saved to {image_path}")

                                # 画像サイズを256x256にリサイズ
                                thumb_image = image.resize((256, 256))
                                thumb_image_path = os.path.join(
                                    output_dir,
                                    lora_name,
                                    negative_dir_name,
                                    "thumb",
                                    f"{lora_epoch_name}_{negative_dir_name}_thumb_{m + 1}.png",
                                )
                                thumb_image.save(thumb_image_path)

                                # print(f"Image saved to {thumb_image_path}")

                                # 結合するのは小さい方
                                images.append(thumb_image)

                            # 画像の結合(2行3列のグリッド)
                            if len(images) == 6:
                                combined_width = 3 * images[0].width
                                combined_height = 2 * images[0].height
                                combined_image = Image.new(
                                    "RGB", (combined_width, combined_height)
                                )

                                # 各位置に画像を配置
                                for i in range(6):
                                    x = (i % 3) * images[i].width
                                    y = (i // 3) * images[i].height
                                    combined_image.paste(images[i], (x, y))

                                combined_lora_epoch_name = adjust_lora_epoch_name(
                                    lora_epoch_name
                                )

                                # 結合画像の保存
                                combined_image_path = os.path.join(
                                    output_dir,
                                    lora_name,
                                    negative_dir_name,
                                    f"{combined_lora_epoch_name}_{negative_dir_name}_{(n + 1):02d}.png",
                                )
                                combined_image.save(combined_image_path)
                                print(f"Image saved to {combined_image_path}")
                        else:
                            print(f"Error {response.status_code}: {response.text}")

print("All prompts processed.")

Appendix

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?