1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

画像生成Advent Calendar 2024

Day 8

ColabのTPU v2を利用してsafetensors形式で公開されているモデルの画像を生成

Last updated at Posted at 2024-12-07

はじめに

以前にColabのTPU v2でStable Diffusionを利用した画像生成をしました。

このときはCompVis/stable-diffusion-v1-4を利用したのですが、safetensors形式のファイルで公開されているモデルを利用したいことがあり、その方法を調べました。

この情報が参考になりました。

ColabのランタイムにTPU v2を選択して、以下のコードを実行していきます。

パッケージのインストール

!pip uninstall -y tensorflow && pip install tensorflow-cpu
!pip install --upgrade jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade jaxlib flax transformers ftfy diffusers

モデルの準備

HuggingFaceからreal-dream-8.safetensorsをダウンロードして、形式を変換して保存します。

!wget https://huggingface.co/luisrguerra/sd-1.5-real-dream/resolve/main/real-dream-8.safetensors

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_single_file("real-dream-8.safetensors")
pipeline.save_pretrained("real-dream-8", safe_serialization=False)

パイプラインのセットアップ

import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "real-dream-8",
    from_pt=True,
    revision="bf16",
    dtype=jnp.bfloat16,
    safety_checker=None
)

p_params = replicate(params)

画像生成

プロンプトを指定して画像を生成します。最初の実行は3分ほどかかりますが、2回目以降は10秒程度で画像が生成できるようになります。

from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import random
from IPython.display import display

prompt = "photo, best quality, summer, 1girl" # @param {type:"string"}
prompts = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

rng = jax.random.PRNGKey(random.randrange(1000000))
rng = jax.random.split(rng, jax.device_count())

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

for image in images:
    display(image)

このような画像が一度の実行で8枚生成されました。

download.png

おわりに

ColabのTPU v2を利用してsafetensorsで公開されているモデルの画像を生成することができました。細かいパラメーターの違いは確認していないのですが、ColabのT4 GPUのときよりも高速に画像を生成することができました。

今回、私はFlaxStableDiffusionPipelineを使うことでJAXやFlaxにも興味を持ち、今年の冬休みに勉強したいと思っています。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?