6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

FLUX.1をdiffusers環境下VRAM 16GB以下のGPUで使用する

Last updated at Posted at 2024-08-11

概要

つい最近Stable Diffusion開発者の方達が発表したFLUX.1ですが、容量が大きく大体のPCではVRAM不足で動きません。
色々調査を行った所、量子化を行うことで動作することが判明しましたので、記載していこうと思います。

目次

diffusersの量子化

今回VRAMの使用量を抑えるために量子化を行います。量子化ですが、huggingfaceのライブラリである、Optimum Quantoを使用します。

Optimum Quantoについて

Pytorchの量子化バックエンドライブラリです。
transformersとdiffusersに対応しており、数行のコードを追加するだけでモデルを量子化して読み込むことができます。
また、CUDAやMPSなど複数デバイスでのロードに対応しているようです。
詳しくはリポジトリや以下を参考にしてください。
Memory-efficient Diffusion Transformers with Quanto and Diffusers

動作環境

  • Ubuntu 22.04
  • Intel i7-13700K
  • RTX A4000
  • RAM:48GB

環境構築

pipで以下をインストールします。

$ pip install torch torchvision transformers sentencepiece protobuf accelerate diffusers optimum-quanto huggingface_hub

※diffusersのバージョンが古い場合、FLUX.1の環境がないため、バージョンを0.30.0以上にバージョンアップして下さい

コード

今回動かすのはFLUX.1-schnellになります。

flux1.py
import time

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

dtype = torch.bfloat16

bfl_repo = "black-forest-labs/FLUX.1-schnell"

# モデルの読み込み
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    bfl_repo, subfolder="scheduler"
)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype
)
vae = AutoencoderKL.from_pretrained(
    bfl_repo, subfolder="vae", torch_dtype=dtype
)
transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo, subfolder="transformer", torch_dtype=dtype
)

# 8bit量子化
quantize(transformer, weights=qfloat8)
freeze(transformer)

quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# FluxPipelineの設定
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

# シードの固定
generator = torch.Generator().manual_seed(0)

# 画像の生成
image = pipe(
    prompt="A cat holding a sign that says hello world",
    width=1024,
    height=1024,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=generator,
    guidance_scale=0.0,
).images[0]
image.save("flux-dev.png")

上記コードで以下のような画像が生成されます。
image.png

モデル読み込み速度の改善

これでFLUX.1が使用できるようになりました!お疲れ様です。

ここからはモデルの読み込み速度について改善したいと思います。
上記コードでFLUX.1が使用できるようになりましたが、モデルの読み込みや量子化がかなり遅いことについて気になりました。(生成速度はそこまで遅くないので、初回だけですが…)
調べたところ、量子化したモデルを保存することができそうなので、コードを記載します。

量子化モデルの保存

save.py
import json

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantization_map, quantize
from transformers import T5EncoderModel

dtype = torch.bfloat16

bfl_repo = "black-forest-labs/FLUX.1-schnell"

text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype
)
transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo, subfolder="transformer", torch_dtype=dtype
)

# FluxTransformerの保存
quantize(transformer, weights=qfloat8)
freeze(transformer)
transformer.save_pretrained(
    "./FLUX.1-schnell-distilled/transformer"
)
with open(
    "./FLUX.1-schnell-distilled/transformer/quanto_qmap.json",
    "w",
) as f:
    json.dump(quantization_map(transformer), f)

# T5Encoderの保存
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
text_encoder_2.save_pretrained(
    "/models/StableDiffution/Diffusers/FLUX.1-schnell-distilled/text_encoder_2"
)
with open(
    "/models/StableDiffution/Diffusers/FLUX.1-schnell-distilled/text_encoder_2/quanto_qmap.json",
    "w",
) as f:
    json.dump(quantization_map(text_encoder_2), f)

量子化モデルの読み込み

flux1_load.py
import time

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

dtype = torch.bfloat16

bfl_repo = "black-forest-labs/FLUX.1-schnell"

# モデルの読み込み
# 量子化モデルを読み込むモデルに関してコメント化
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    bfl_repo, subfolder="scheduler"
)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=dtype
)
# text_encoder_2 = T5EncoderModel.from_pretrained(
#     bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype
# )
tokenizer_2 = T5TokenizerFast.from_pretrained(
    bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype
)
vae = AutoencoderKL.from_pretrained(
    bfl_repo, subfolder="vae", torch_dtype=dtype
)
# transformer = FluxTransformer2DModel.from_pretrained(
#     bfl_repo, subfolder="transformer", torch_dtype=dtype
# )

# 量子化モデルの読み込み
class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel


transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "./FLUX.1-schnell-distilled/transformer",
).to(dtype=dtype)

T5EncoderModel.from_config = lambda c: T5EncoderModel(c)
text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "./FLUX.1-schnell-distilled/text_encoder_2"
).to(dtype=dtype)

# 8bit量子化
# 事前に量子化済みのため、コメント化
# quantize(transformer, weights=qfloat8)
# freeze(transformer)

# quantize(text_encoder_2, weights=qfloat8)
# freeze(text_encoder_2)

# FluxPipelineの設定
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

# シードの固定
generator = torch.Generator().manual_seed(0)

# 画像の生成
image = pipe(
    prompt="A cat holding a sign that says hello world",
    width=1024,
    height=1024,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=generator,
    guidance_scale=0.0,
).images[0]
image.save("flux-dev.png")

最初のコードと違い、以下で量子化モデルを読み込んでおります。

# 量子化モデルの読み込み
class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel


transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "./FLUX.1-schnell-distilled/transformer",
).to(dtype=dtype)

T5EncoderModel.from_config = lambda c: T5EncoderModel(c)
text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "./FLUX.1-schnell-distilled/text_encoder_2"
).to(dtype=dtype)

# 8bit量子化
# 事前に量子化済みのため、コメント化
# quantize(transformer, weights=qfloat8)
# freeze(transformer)

# quantize(text_encoder_2, weights=qfloat8)
# freeze(text_encoder_2)

速度比較

量子化モデルを保存・読み込みを行うことで以下のような読み込み速度改善が確認できました。

種類 時間(秒)
モデルを読み込み量子化してから生成 273.61
事前に量子化したモデルを読み込み生成 100.35

まとめ

今回はdiffusersで量子化を行いVRAM 16GBでFLUX.1の生成が行えるようにしました。
アプリに組み込むなど、diffusersでFLUX.1を使用する場合の参考になれば幸いです。
※ UIベースであれば、ComfyUIにFP8モデルがあるようです。

参考・出展リンク

6
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?