3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

BedrockのStabilityAIを使って画像生成をするgradioアプリ

Posted at

はじめに

 9月になってオレゴンリージョンのBedrockにStable Image Ultra, Stable Diffusion 3 Large Stable Image Coreの3つのモデルが追加されたので、例によってgradioでアプリ化したいと思います。

 新しい3つのモデルについてはこちらの記事に詳しく書かれています。

boto3の設定

 事前にboto3をインストールをインストール、環境変数にAWSのAPI キーを設定しておきましょう。

 そのあたりはこちらの記事が分かりやすかったです。

準備

 pipでgradioをインストールしておいて下さい。

ソースコード

gradio_bedrock_stable_diffusion.py
import os
import json
import base64
import logging
from io import BytesIO
from datetime import datetime
from typing import Dict, List, Tuple, Optional

import boto3
import gradio as gr
from PIL import Image

# ログ設定
logging.basicConfig(level=logging.ERROR, format='%(asctime)s [%(levelname)s] %(message)s')

# 定数
STABILITY_MODELS: Dict[str, str] = {
    "Stable Image Core": "stability.stable-image-core-v1:0",
    "SD3 Large": "stability.sd3-large-v1:0",
    "Stable Image Ultra": "stability.stable-image-ultra-v1:0"
}

ASPECT_RATIOS: List[str] = ["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]
FILE_FORMATS: List[str] = ["jpeg", "png"]

AWS_REGION: str = 'us-west-2'
OUTPUT_FOLDER: str = "output"


class BedrockImageGenerator:
    def __init__(self) -> None:
        """AWS Bedrockランタイムクライアントを初期化します。"""
        self.bedrock = boto3.client(
            service_name='bedrock-runtime',
            region_name=AWS_REGION
        )

    def get_model_id(self, model_name: str) -> str:
        """
        モデル名からモデルIDを取得します。

        Args:
            model_name (str): モデル名

        Returns:
            str: モデルID
        """
        return STABILITY_MODELS[model_name]

    def _invoke_model(self, model_id: str, body: str) -> Dict:
        """
        モデルを呼び出し、レスポンスを取得します。

        Args:
            model_id (str): モデルID
            body (str): モデルへのリクエストボディ(JSON形式)

        Returns:
            Dict: モデルのレスポンスデータ
        """
        try:
            response = self.bedrock.invoke_model(modelId=model_id, body=body)
            output_body = json.loads(response["body"].read().decode("utf-8"))
            return output_body
        except Exception as e:
            logging.error(f"モデル呼び出しエラー: {str(e)}")
            raise

    def _decode_image(self, base64_image: str) -> Image.Image:
        """
        Base64エンコードされた画像データをデコードします。

        Args:
            base64_image (str): Base64エンコードされた画像データ

        Returns:
            Image.Image: デコードされた画像オブジェクト
        """
        image_data = base64.b64decode(base64_image)
        image = Image.open(BytesIO(image_data))
        return image

    def _generate_filename(self, file_format: str) -> str:
        """
        タイムスタンプを基にファイル名を生成します。

        Args:
            file_format (str): 画像ファイルの形式("jpeg" または "png")

        Returns:
            str: 生成されたファイル名
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"{timestamp}.{file_format.lower()}"

    def _save_image(self, image: Image.Image, file_format: str) -> str:
        """
        画像を指定された形式で保存し、保存先のファイルパスを返します。

        Args:
            image (Image.Image): 保存する画像オブジェクト
            file_format (str): 保存するファイル形式

        Returns:
            str: 保存された画像のファイルパス
        """
        if not os.path.exists(OUTPUT_FOLDER):
            os.makedirs(OUTPUT_FOLDER)

        filename = self._generate_filename(file_format)
        filepath = os.path.join(OUTPUT_FOLDER, filename)

        if file_format.lower() == "jpeg":
            image = image.convert("RGB")  # JPEGはRGBに変換する必要がある

        image.save(filepath, format=file_format.upper())
        logging.info(f"画像が {filepath} に保存されました")

        return filepath

    def generate_image(
        self,
        model_name: str,
        prompt: str,
        negative_prompt: str,
        aspect_ratio: str,
        seed: int,
        file_format: str
    ) -> Tuple[Optional[Image.Image], str]:
        """
        指定されたパラメータを基に画像を生成し、保存します。

        Args:
            model_name (str): 使用するモデル名
            prompt (str): 画像生成のためのテキストプロンプト
            negative_prompt (str): 画像生成において除外する要素
            aspect_ratio (str): アスペクト比
            seed (int): シード値
            file_format (str): 保存する画像のファイル形式("jpeg" または "png")

        Returns:
            Tuple[Optional[Image.Image], str]: 生成された画像と保存ファイルのパス
        """
        body = json.dumps({
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "aspect_ratio": aspect_ratio,
            "seed": seed
        })
        model_id = self.get_model_id(model_name)

        try:
            output_body = self._invoke_model(model_id, body)
            base64_output_image = output_body["images"][0]
            image = self._decode_image(base64_output_image)
            filepath = self._save_image(image, file_format)
            return image, filepath
        except Exception as e:
            logging.error(f"画像生成エラー: {str(e)}")
            return None, ""


def create_gradio_interface() -> gr.Interface:
    """
    Gradioインターフェースを作成します。

    Returns:
        gr.Interface: Gradioインターフェースオブジェクト
    """
    generator = BedrockImageGenerator()

    def generate_wrapper(
        model_name: str,
        prompt: str,
        negative_prompt: str,
        aspect_ratio: str,
        seed: int,
        file_format: str
    ) -> Tuple[Optional[Image.Image], str]:
        """Gradioインターフェースのラッパー関数"""
        image, filepath = generator.generate_image(
            model_name, prompt, negative_prompt, aspect_ratio, seed, file_format
        )
        if image:
            return image, f"画像が保存されました: {filepath}"
        return None, "画像の生成に失敗しました。"

    return gr.Interface(
        fn=generate_wrapper,
        inputs=[
            gr.Radio(choices=list(STABILITY_MODELS.keys()), label="モデル", value="Stable Image Core"),
            gr.Textbox(lines=3, label="プロンプト"),
            gr.Textbox(lines=3, label="ネガティブプロンプト", value="worst quality,ugly,bad anatomy,jpeg artifacts"),
            gr.Radio(choices=ASPECT_RATIOS, label="アスペクト比", value="1:1"),
            gr.Slider(minimum=0, maximum=1000, value=0, label="シード値"),
            gr.Radio(choices=FILE_FORMATS, label="出力ファイル形式", value="png")
        ],
        outputs=[
            gr.Image(label="生成された画像"),
            gr.Textbox(label="保存結果")
        ],
        title="AWS Bedrock Stable Diffusion Text-to-Image",
        description=(
            "テキストプロンプトを入力して、AWS Bedrock上のStable Diffusionで画像を生成します。"
            "生成された画像は指定されたファイル形式で保存されます。"
        )
    )


if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()

実行

 ターミナルで以下を実行

python gradio_bedrock_stable_diffusion.py

 ターミナルに表示されたURLにアクセスすると以下の画面が表示

image.png

 左側の入力について説明すると

  • モデル・・・使用するモデルを3つのうちから選択
  • プロンプト・・・出力する画像を文章で入力
  • ネガティブプロンプト・・・画像を生成する際に除外したい要素
  • アスペクト比・・・出力するアスペクト比を選択
  • シード値・・・AIが画像を生成する際に使用するランダムな要素を制御するための数値、同じSeedを使用すると、同じテキストプロンプトでも同じ画像が生成されます
  • 出力ファイル形式・・・pngかjpegを選択

 ※ネガティブプロンプトに関してはこちらのサイトを参考にデフォルトで基本的なものを入れてあります。

 Submitボタンを押すと画像が生成され./outputフォルダに作成した画像が保存されます。

実行結果

 プロンプトに「Dark night with a full moon shining, brown tiger cat walking on a tightrope across power lines, Japanese animation style.」(満月が輝く暗い夜、電線を綱渡りする茶トラ猫、日本のアニメーション風)を入力しSubmitボタンを押したときの各モデルの出力結果

※プロンプトは英語の方が良い結果が出やすいようです。

Stable Image Core

20240930_200603.png

Stable Diffusion 3 Large

20240930_200553.png

Stable Image Ultra

20240930_200534.png

さいごに

 上記の出力結果はやっぱり、Stable Image Ultraが一番いいかな(高いけど…)。もっとプロンプト、ネガティブプロンプトを勉強すれば良い結果が出ると思うので、これからも精進していきたいと思います。読んでいただきありがとうございました。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?