Background
- In the previous post, SageMakerでのHuggingFace Diffusersデプロイガイド(1/3): Stable Diffusion XLのデプロイ, we deployed HuggingFace's Stable Diffusion XL to SageMaker.
- This time, we will deploy an endpoint that generates images with Stable Diffusion XL x ControlNet Depth.
- Using ControlNet Depth requires writing a custom predict_fn, so the HuggingFaceModel approach from the previous post does not work here.
- Instead, we will use the PyTorch container to deploy Stable Diffusion XL + ControlNet Depth.
- The base model is SDXL.
Adding the required libraries with requirements.txt
When running Stable Diffusion on the PyTorch container, libraries such as diffusers and transformers need to be installed. If you place a requirements.txt under code/, the libraries appear to be installed automatically when the model is deployed to SageMaker.
.
├── code
│   ├── inference.py
│   └── requirements.txt
└── model.tar.gz
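For reference, a minimal requirements.txt sketch for this code might look like the following. torch, numpy, Pillow, and boto3 should already ship with the SageMaker PyTorch container, so only the extras are listed; leaving the versions unpinned is an assumption here, and pinning to the versions you have validated is recommended (accelerate is required by enable_model_cpu_offload):

diffusers
transformers
accelerate
safetensors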
Implementation
inference.py
import torch
import numpy as np
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.schedulers import (
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    DPMSolverSinglestepScheduler,
)
from diffusers.utils import load_image
import base64
import io
import json
import boto3
import tempfile
import logging

# Logging setup
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
def model_fn(model_dir):
    try:
        # Depth estimator and feature extractor used to build the depth map for ControlNet
        depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
        controlnet = ControlNetModel.from_pretrained(
            "diffusers/controlnet-depth-sdxl-1.0",
            variant="fp16",
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
        # fp16-safe VAE for SDXL
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnet,
            vae=vae,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
        pipe.enable_model_cpu_offload()
        return {
            "depth_estimator": depth_estimator,
            "feature_extractor": feature_extractor,
            "pipe": pipe,
        }
    except Exception as e:
        logger.error(f"Error in model_fn: {e}")
        raise e
def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        try:
            request = json.loads(request_body)
            prompt = request["prompt"]
            image_url = request["image_url"]
            steps = request["steps"]
            sampler = request["sampler"]
            cfg_scale = request["cfg_scale"]
            seed = request["seed"]
            height = request["height"]
            width = request["width"]
            guidance_start = request["guidance_start"]
            guidance_end = request["guidance_end"]
            weight = request["weight"]
            loop_count = request["loop_count"]
            num_images_per_prompt = request["num_images_per_prompt"]
            if not isinstance(prompt, str):
                raise ValueError("`prompt` has to be of type `str` but is {}".format(type(prompt)))
            return prompt, image_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt
        except Exception as e:
            logger.error(f"Error in input_fn: {e}")
            raise e
    raise ValueError("Unsupported content type: {}".format(request_content_type))
def predict_fn(input_data, model):
    try:
        prompt, image_url, steps, sampler, cfg_scale, seed, height, width, guidance_start, guidance_end, weight, loop_count, num_images_per_prompt = input_data
        # Download the conditioning image from S3 (image_url is an s3://bucket/key URL)
        s3 = boto3.client('s3')
        bucket_name = image_url.split('/')[2]
        object_key = '/'.join(image_url.split('/')[3:])
        with tempfile.NamedTemporaryFile(suffix=".png") as tmp_file:
            s3.download_file(bucket_name, object_key, tmp_file.name)
            image = load_image(tmp_file.name)
        depth_estimator = model["depth_estimator"].to("cuda")
        feature_extractor = model["feature_extractor"]
        pipe = model["pipe"]
        # Swap in the scheduler requested by the client
        if sampler == "Euler_a":
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "Euler":
            pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++2M":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2":
            pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM2_a":
            pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        elif sampler == "DPM++SDE":
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
        else:
            raise ValueError(f"sampler: {sampler} is not supported")
        def get_depth_map(image, target_height, target_width):
            image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
            with torch.no_grad(), torch.autocast("cuda"):
                depth_map = depth_estimator(image).predicted_depth
            depth_map = torch.nn.functional.interpolate(
                depth_map.unsqueeze(1),
                size=(target_height, target_width),
                mode="bicubic",
                align_corners=False,
            )
            # Normalize the depth map to [0, 1] and convert it to a 3-channel PIL image
            depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_map = (depth_map - depth_min) / (depth_max - depth_min)
            image = torch.cat([depth_map] * 3, dim=1)
            image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
            image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
            return image
        depth_image = get_depth_map(image, height, width)
        depth_estimator.to("cpu")  # Move the model back to CPU to free GPU memory
        torch.cuda.empty_cache()
        pipe = pipe.to("cuda")
        controlnet_conditioning_scale = 0.5  # recommended for good generalization
        generator = torch.manual_seed(seed)
        result = []
        with torch.cuda.amp.autocast():
            for _ in range(loop_count):
                images = pipe(
                    prompt,
                    image=depth_image,
                    num_inference_steps=steps,
                    height=height,
                    width=width,
                    control_guidance_start=guidance_start,
                    control_guidance_end=guidance_end,
                    guidance_scale=cfg_scale,
                    generator=generator,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    num_images_per_prompt=num_images_per_prompt,
                ).images
                result = result + images
        pipe.to("cpu")  # Move the model back to CPU to free GPU memory
        torch.cuda.empty_cache()
        # Base64-encode the generated images
        image_strs = []
        for image in result:
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            image_strs.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
        return {"images": image_strs}
    except Exception as e:
        logger.error(f"Error in predict_fn: {e}")
        raise e
def output_fn(prediction, content_type):
    if content_type == "application/json":
        return json.dumps(prediction)
    raise ValueError("Unsupported content type: {}".format(content_type))
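For completeness, here is a minimal deployment and invocation sketch using the SageMaker Python SDK. This is an assumption-laden example rather than code from the original post: the S3 paths, instance type, and framework/Python versions are placeholders to adapt to your environment, and model.tar.gz is assumed to be uploaded to S3 beforehand.

import base64
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

role = sagemaker.get_execution_role()

# The code/ directory shown above is passed via source_dir and
# repacked into the model archive at deploy time
pytorch_model = PyTorchModel(
    model_data="s3://<your-bucket>/model/model.tar.gz",
    role=role,
    entry_point="inference.py",
    source_dir="code",
    framework_version="2.0",
    py_version="py310",
)

predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",  # a GPU instance is required
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

# The keys match what input_fn expects; the values are illustrative
response = predictor.predict({
    "prompt": "a futuristic living room, best quality, highly detailed",
    "image_url": "s3://<your-bucket>/inputs/room.png",
    "steps": 30,
    "sampler": "Euler_a",
    "cfg_scale": 7.0,
    "seed": 42,
    "height": 1024,
    "width": 1024,
    "guidance_start": 0.0,
    "guidance_end": 1.0,
    "weight": 0.5,
    "loop_count": 1,
    "num_images_per_prompt": 1,
})

# Each element of response["images"] is a base64-encoded PNG
with open("output.png", "wb") as f:
    f.write(base64.b64decode(response["images"][0]))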
Summary
We saw that the PyTorch image lets you write inference code with a fair amount of freedom and still deploy it to SageMaker. Next time, I plan to write samples for Canny and inpainting as well.