
Deploying HuggingFace Diffusers on SageMaker (3/3): Stable Diffusion XL + ControlNet Canny & Inpaint


Background

Adding the required libraries with requirements.txt

As in the previous posts, running Stable Diffusion on the PyTorch container requires installing libraries such as diffusers and transformers. Placing a requirements.txt under code/ gets those libraries installed automatically when the model is deployed to SageMaker.

.
├── code
│   ├── inference.py
│   └── requirements.txt
└── model.tar.gz
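
For reference, a minimal requirements.txt for the code in this article could look like the following. The exact package list is an assumption inferred from the imports in the handlers below (accelerate is needed for enable_model_cpu_offload(), and opencv-python-headless only for the Canny handler); pin versions to match your PyTorch container.

requirements.txt

diffusers
transformers
accelerate
opencv-python-headless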

Implementing Canny

inference.py

import torch
import numpy as np
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, DPMSolverSinglestepScheduler
from diffusers.utils import load_image
import base64
import io
import json
import boto3
import tempfile
import logging
import cv2


# ログ設定
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


def model_fn(model_dir):
    try:
        print("======model_dir==========")
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-canny-sdxl-1.0",
            torch_dtype=torch.float16
        )
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnet,
            vae=vae,
            torch_dtype=torch.float16,
        )
        pipe.enable_model_cpu_offload()
        return {"pipe": pipe}
    except Exception as e:
        logger.error(f"Error in model_fn: {e}")
        raise e


def input_fn(request_body, request_content_type):
    print("=====input_fn=======")
    if request_content_type == "application/json":
        try:
            request = json.loads(request_body)
            prompt = request["prompt"]
            image_url = request["image_url"]
            steps = request["steps"]
            sampler = request["sampler"]
            cfg_scale = request["cfg_scale"]
            seed = request["seed"]
            height = request["height"]
            width = request["width"]
            guidance_start = request["guidance_start"]
            guidance_end = request["guidance_end"]
            weight = request["weight"]
            loop_count = request["loop_count"]
            num_images_per_prompt = request["num_images_per_prompt"]
            if not isinstance(prompt, str):
                raise ValueError("`prompt` has to be of type `str` but is {}".format(type(prompt)))
            return prompt, image_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt
        except Exception as e:
            logger.error(f"Error in input_fn: {e}")
            raise e
    raise ValueError("Unsupported content type: {}".format(request_content_type))


def predict_fn(input_data, model):
    print("===predict_fn===")
    try:
        prompt, image_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt = input_data
        s3 = boto3.client('s3')
        # parse an "s3://bucket/key" style URL into bucket name and object key
        bucket_name = image_url.split('/')[2]
        object_key = '/'.join(image_url.split('/')[3:])
        with tempfile.NamedTemporaryFile(suffix=".png") as tmp_file:
            s3.download_file(bucket_name, object_key, tmp_file.name)
            image = load_image(tmp_file.name).resize((1024, 1024))

        pipe = model["pipe"]

        if sampler == "Euler_a":
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "Euler":
            pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++2M":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2_a":
            pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++SDE":
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
        else:
            raise ValueError(f"sampler: {sampler} is not supported")

        # Canny edge detection: extract edges, then replicate the single channel
        # into the 3-channel conditioning image the ControlNet expects
        image = np.array(image)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)

        pipe = pipe.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(seed)

        print("===before_pipe===")
        result = []
        with torch.cuda.amp.autocast():
            for _ in range(loop_count):
                images = pipe(
                    prompt,
                    image=image,
                    num_inference_steps=steps,
                    height=height,
                    width=width,
                    controlnet_conditioning_scale=weight/2,
                    guidance_scale=cfg_scale,
                    generator=generator,
                    num_images_per_prompt=num_images_per_prompt,
                    control_guidance_start=guidance_start,
                    control_guidance_end=guidance_end,
                ).images
                result = result + images
        pipe.to("cpu")  # メモリを解放するためにモデルをCPUに戻す
        torch.cuda.empty_cache()
        print("===after_pipe===")

        # 画像をbase64エンコード
        image_strs = []
        for image in result:
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            image_strs.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

        return {"images": image_strs}

    except Exception as e:
        logger.error(f"Error in predict_fn: {e}")
        raise e


def output_fn(prediction, content_type):
    if content_type == "application/json":
        return json.dumps(prediction)
    raise ValueError("Unsupported content type: {}".format(content_type))
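
With model.tar.gz and the code/ directory uploaded to S3, the handler above can be deployed and invoked roughly as follows. This is a minimal sketch, not the article's exact deployment code: the model_data path, IAM role, framework versions, and instance type are placeholder assumptions, and the payload keys simply mirror what input_fn expects.

import json
import boto3
from sagemaker.pytorch import PyTorchModel

# Deploy the packaged model (paths, role, versions, and instance type are placeholders)
model = PyTorchModel(
    model_data="s3://YOUR_BUCKET/model.tar.gz",
    role="YOUR_SAGEMAKER_EXECUTION_ROLE",
    framework_version="2.0",
    py_version="py310",
    entry_point="inference.py",
    source_dir="code",
)
predictor = model.deploy(initial_instance_count=1, instance_type="ml.g5.2xlarge")

# Invoke the endpoint with a JSON payload matching input_fn
payload = {
    "prompt": "a futuristic cityscape, highly detailed",
    "image_url": "s3://YOUR_BUCKET/inputs/source.png",
    "steps": 30,
    "sampler": "Euler_a",
    "cfg_scale": 7.0,
    "seed": 42,
    "height": 1024,
    "width": 1024,
    "guidance_start": 0.0,
    "guidance_end": 1.0,
    "weight": 1.0,
    "loop_count": 1,
    "num_images_per_prompt": 1,
}
runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)
images_b64 = json.loads(response["Body"].read())["images"]  # base64-encoded PNGs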

Implementing Inpaint

inference.py

import torch
import numpy as np
from PIL import Image
from diffusers import AutoPipelineForInpainting
from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, DPMSolverSinglestepScheduler
from diffusers.utils import load_image
import base64
import io
import json
import boto3
import tempfile
import logging


# ログ設定
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


def model_fn(model_dir):
    try:
        print("======model_dir==========")
        pipe = AutoPipelineForInpainting.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        # enable_model_cpu_offload() handles device placement itself, so the
        # pipeline is not moved to CUDA explicitly here
        pipe.enable_model_cpu_offload()
        return {"pipe": pipe}
    except Exception as e:
        logger.error(f"Error in model_fn: {e}")
        raise e


def input_fn(request_body, request_content_type):
    print("=====input_fn=======")
    if request_content_type == "application/json":
        try:
            request = json.loads(request_body)
            prompt = request["prompt"]
            image_url = request["image_url"]
            mask_url = request["mask_url"]
            steps = request["steps"]
            sampler = request["sampler"]
            cfg_scale = request["cfg_scale"]
            seed = request["seed"]
            height = request["height"]
            width = request["width"]
            guidance_start = request["guidance_start"]
            guidance_end = request["guidance_end"]
            weight = request["weight"]
            loop_count = request["loop_count"]
            num_images_per_prompt = request["num_images_per_prompt"]
            if not isinstance(prompt, str):
                raise ValueError("`prompt` has to be of type `str` but is {}".format(type(prompt)))
            return prompt, image_url, mask_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt
        except Exception as e:
            logger.error(f"Error in input_fn: {e}")
            raise e
    raise ValueError("Unsupported content type: {}".format(request_content_type))


def predict_fn(input_data, model):
    print("===predict_fn===")
    try:
        prompt, image_url, mask_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt = input_data
        s3 = boto3.client('s3')
        
        bucket_name = image_url.split('/')[2]
        object_key = '/'.join(image_url.split('/')[3:])
        with tempfile.NamedTemporaryFile(suffix=".png") as tmp_file:
            s3.download_file(bucket_name, object_key, tmp_file.name)
            image = load_image(tmp_file.name).resize((1024, 1024))

        bucket_name = mask_url.split('/')[2]
        object_key = '/'.join(mask_url.split('/')[3:])
        with tempfile.NamedTemporaryFile(suffix=".png") as tmp_file:
            s3.download_file(bucket_name, object_key, tmp_file.name)
            mask_image = load_image(tmp_file.name).resize((1024, 1024))


        pipe = model["pipe"]

        if sampler == "Euler_a":
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "Euler":
            pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++2M":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2_a":
            pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++SDE":
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
        else:
            raise ValueError(f"sampler: {sampler} is not supported")

        pipe = pipe.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(seed)

        print("===before_pipe===")
        print(steps)
        result = []
        with torch.cuda.amp.autocast():
            for _ in range(loop_count):
                # the plain inpainting pipeline does not accept the ControlNet-specific
                # arguments (controlnet_conditioning_scale, control_guidance_start/end),
                # so weight / guidance_start / guidance_end from the request are unused here
                images = pipe(
                    prompt,
                    image=image,
                    mask_image=mask_image,
                    num_inference_steps=steps,
                    height=height,
                    width=width,
                    guidance_scale=cfg_scale,
                    generator=generator,
                    num_images_per_prompt=num_images_per_prompt,
                ).images
                result = result + images
        pipe.to("cpu")  # メモリを解放するためにモデルをCPUに戻す
        torch.cuda.empty_cache()
        print("===after_pipe===")

        # 画像をbase64エンコード
        image_strs = []
        for image in result:
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            image_strs.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

        return {"images": image_strs}

    except Exception as e:
        logger.error(f"Error in predict_fn: {e}")
        raise e


def output_fn(prediction, content_type):
    if content_type == "application/json":
        return json.dumps(prediction)
    raise ValueError("Unsupported content type: {}".format(content_type))
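
The inpaint handler expects both a source image and a mask image in S3; in the SDXL inpainting pipeline, white pixels in the mask mark the region to repaint and black pixels are kept. As a rough illustration of preparing and uploading such a mask (bucket and key names are placeholders):

import boto3
import numpy as np
from PIL import Image

# Build a simple rectangular mask: white = repaint, black = keep
mask = np.zeros((1024, 1024), dtype=np.uint8)
mask[256:768, 256:768] = 255  # repaint the center square
Image.fromarray(mask).save("mask.png")

# Upload it so the endpoint can fetch it via "mask_url"
s3 = boto3.client("s3")
s3.upload_file("mask.png", "YOUR_BUCKET", "inputs/mask.png")
# -> pass "s3://YOUR_BUCKET/inputs/mask.png" as "mask_url" in the request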

Summary

That wraps it up: in this post we walked through inference code for Canny and Inpaint, using the same deployment approach as in the previous posts.
