はじめに
以前にColabのTPU v2でStable Diffusionを利用した画像生成をしました。
このときはCompVis/stable-diffusion-v1-4
を利用したのですが、safetensors形式のファイルで公開されているモデルを利用したいことがあり、その方法を調べました。
この情報が参考になりました。
ColabのランタイムにTPU v2を選択して、以下のコードを実行していきます。
パッケージのインストール
!pip uninstall -y tensorflow && pip install tensorflow-cpu
!pip install --upgrade jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade jaxlib flax transformers ftfy diffusers
モデルの準備
HuggingFaceからreal-dream-8.safetensorsをダウンロードして、形式を変換して保存します。
!wget https://huggingface.co/luisrguerra/sd-1.5-real-dream/resolve/main/real-dream-8.safetensors
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_single_file("real-dream-8.safetensors")
pipeline.save_pretrained("real-dream-8", safe_serialization=False)
パイプラインのセットアップ
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"real-dream-8",
from_pt=True,
revision="bf16",
dtype=jnp.bfloat16,
safety_checker=None
)
p_params = replicate(params)
画像生成
プロンプトを指定して画像を生成します。最初の実行は3分ほどかかりますが、2回目以降は10秒程度で画像が生成できるようになります。
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import random
from IPython.display import display
prompt = "photo, best quality, summer, 1girl" # @param {type:"string"}
prompts = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
rng = jax.random.PRNGKey(random.randrange(1000000))
rng = jax.random.split(rng, jax.device_count())
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
for image in images:
display(image)
このような画像が一度の実行で8枚生成されました。
おわりに
ColabのTPU v2を利用してsafetensorsで公開されているモデルの画像を生成することができました。細かいパラメーターの違いは確認していないのですが、ColabのT4 GPUのときよりも高速に画像を生成することができました。
今回、私はFlaxStableDiffusionPipeline
を使うことでJAXやFlaxにも興味を持ち、今年の冬休みに勉強したいと思っています。