Guide to Deploying HuggingFace Diffusers on SageMaker (2/3): Deploying Stable Diffusion XL + ControlNet Depth


Background

  • In the previous article, Guide to Deploying HuggingFace Diffusers on SageMaker (1/3): Deploying Stable Diffusion XL, we deployed HuggingFace's Stable Diffusion XL to SageMaker

  • This time, we deploy an endpoint that generates images with Stable Diffusion XL x ControlNet Depth

  • Using ControlNet Depth requires writing custom predict_fn logic, so the HuggingFaceModel approach from the previous article is not sufficient

  • Instead, we use the PyTorch container to deploy Stable Diffusion XL + ControlNet Depth

  • The base model is SDXL

Adding the required libraries with requirements.txt

To run Stable Diffusion on the PyTorch container, libraries such as diffusers and transformers need to be installed. If you place a requirements.txt under code/, the libraries are installed automatically when the model is deployed to SageMaker. The expected directory layout is:

.
├── code
│   ├── inference.py
│   └── requirements.txt
└── model.tar.gz
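
For reference, a minimal requirements.txt might look like the following. The exact package set and version pins are assumptions, not taken from the original setup; choose versions that support SDXL + ControlNet and that match the PyTorch version of the container.

diffusers>=0.24.0        # SDXL + ControlNet pipelines
transformers>=4.35.0     # DPT depth estimator
accelerate>=0.24.0       # needed for enable_model_cpu_offload()
safetensors              # needed for use_safetensors=True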

Implementation

inference.py

import torch
import numpy as np
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, DPMSolverSinglestepScheduler
from diffusers.utils import load_image
import base64
import io
import json
import boto3
import tempfile
import logging


# Logging configuration
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


def model_fn(model_dir):
    try:
        depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0",
            variant="fp16",
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnet,
            vae=vae,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
        pipe.enable_model_cpu_offload()
        return {
            "depth_estimator": depth_estimator,
            "feature_extractor": feature_extractor,
            "pipe": pipe,
        }
    except Exception as e:
        logger.error(f"Error in model_fn: {e}")
        raise e


def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        try:
            request = json.loads(request_body)
            prompt = request["prompt"]
            image_url = request["image_url"]
            steps = request["steps"]
            sampler = request["sampler"]
            cfg_scale = request["cfg_scale"]
            seed = request["seed"]
            height = request["height"]
            width = request["width"]
            guidance_start = request["guidance_start"]
            guidance_end = request["guidance_end"]
            weight = request["weight"]
            loop_count = request["loop_count"]
            num_images_per_prompt = request["num_images_per_prompt"]
            if not isinstance(prompt, str):
                raise ValueError("`prompt` has to be of type `str` but is {}".format(type(prompt)))
            return prompt, image_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt
        except Exception as e:
            logger.error(f"Error in input_fn: {e}")
            raise e
    raise ValueError("Unsupported content type: {}".format(request_content_type))


def predict_fn(input_data, model):
    try:
        prompt, image_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt = input_data
        s3 = boto3.client('s3')
        bucket_name = image_url.split('/')[2]
        object_key = '/'.join(image_url.split('/')[3:])
        with tempfile.NamedTemporaryFile(suffix=".png") as tmp_file:
            s3.download_file(bucket_name, object_key, tmp_file.name)
            image = load_image(tmp_file.name)

        depth_estimator = model["depth_estimator"].to("cuda")
        feature_extractor = model["feature_extractor"]
        pipe = model["pipe"]

        if sampler == "Eular_a":
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "Eular":
            pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++2M":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2_a":
            pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++SDE":
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
        else:
            raise ValueError(f"Unsupported sampler: {sampler}")

        def get_depth_map(image, target_height, target_width):
            image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
            with torch.no_grad(), torch.autocast("cuda"):
                depth_map = depth_estimator(image).predicted_depth

            depth_map = torch.nn.functional.interpolate(
                depth_map.unsqueeze(1),
                size=(target_height, target_width),
                mode="bicubic",
                align_corners=False,
            )
            depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_map = (depth_map - depth_min) / (depth_max - depth_min)
            image = torch.cat([depth_map] * 3, dim=1)

            image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
            image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
            return image

        depth_image = get_depth_map(image, height, width)
        depth_estimator.to("cpu")  # メモリを解放するためにモデルをCPUに戻す
        torch.cuda.empty_cache()

        pipe = pipe.to("cuda")
        controlnet_conditioning_scale = 0.5  # recommended for good generalization

        generator = torch.manual_seed(seed)

        result = []
        with torch.cuda.amp.autocast():
            for _ in range(loop_count):
                images = pipe(
                    prompt,
                    image=depth_image,
                    num_inference_steps=steps,
                    height=height,
                    width=width,
                    control_guidance_start=guidance_start,
                    control_guidance_end=guidance_end,
                    guidance_scale=cfg_scale,
                    generator=generator,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    num_images_per_prompt=num_images_per_prompt
                ).images
                result = result + images
        pipe.to("cpu")  # メモリを解放するためにモデルをCPUに戻す
        torch.cuda.empty_cache()

        # Encode the generated images as base64
        image_strs = []
        for image in result:
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            image_strs.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

        return {"images": image_strs}

    except Exception as e:
        logger.error(f"Error in predict_fn: {e}")
        raise e


def output_fn(prediction, content_type):
    if content_type == "application/json":
        return json.dumps(prediction)
    raise ValueError("Unsupported content type: {}".format(content_type))

Summary

We saw that the PyTorch image gives a fairly high degree of freedom when writing inference code, and that such code can be deployed to SageMaker. Next, I plan to write samples for Canny and inpainting as well.
