0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

xFormersをONNXに変換するときのエラー解決: Memory-efficient attention

Last updated at Posted at 2024-12-12

画像生成モデル(Diffusion model)をONNXに変換する際、モデル内のxFormersが原因でエラーが起きたので、解決方法を残します。

ONNXとは

ONNXは、TensorFlowやPyTorchの様々な機械学習フレームワーク間で、モデルを相互運用可能にするためのオープンフォーマットです。
詳細な説明は、以下を含めた他記事を参考にしてください。

xFormersとは

xFormersは、Transformersの研究と応用を加速することを目的に開発されたPyTorchベースのライブラリです。例えば、Stable Diffusionを用いた画像生成において、高速化やVRAM使用量の削減のために使われています。

エラーの原因

今回発生したエラーの原因は、xFormersのMemory-efficient attentionでした。そこで、調査のために、memory_efficient_attentionのみのモデルを対象として、ONNXへの変換を試みます。
実行環境は以下の通りです。

  • Python 3.10.12
  • CUDA 12.1
  • Pytorch 2.5.1
  • onnx 1.17.0
xformers_attention.py
import torch
import torch.nn as nn
from xformers.ops import memory_efficient_attention
torch.manual_seed(42)

class XformersAttentionModel(nn.Module):
    def __init__(self, dropout_p=0.0):
        super(XformersAttentionModel, self).__init__()
        self.dropout_p = dropout_p

    def forward(self, query, key, value, attn_bias=None):
        output = memory_efficient_attention(
            query, key, value, 
            attn_bias=attn_bias,
            p=self.dropout_p
        )
        return output

if __name__ == "__main__":
    # モデルインスタンス
    model_xformers = XformersAttentionModel().cuda()

    # ダミー入力
    dummy_query = torch.randn(1, 8, 128, 64, device="cuda")
    dummy_key = torch.randn(1, 8, 128, 64, device="cuda")
    dummy_value = torch.randn(1, 8, 128, 64, device="cuda")

    # ONNXモデルに出力
    torch.onnx.export(
        model_xformers,
        (dummy_query, dummy_key, dummy_value),
        "model_xformers.onnx",
        opset_version=17,
        input_names=["query", "key", "value"],
        output_names=["output"],
    )
...
RuntimeError: Found an unsupported argument type c10::SymInt in the JIT tracer. File a bug report.

やはりできません。こちらのissueでもあるように、基本的にxFormersはONNXをサポートしていないようです。

一方、xFormersのドキュメントでは、memory_efficient_attentionをスクラッチで書く方法が記載されていました。これを使えば、ONNXに変換できるかもしれません。

スクラッチで書く

実際に書きましょう。クラス以外は先ほどのコードとほぼ同じです。

equivalent_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)

class EquivalentAttentionModel(nn.Module):
    def __init__(self, dropout_p=0.0):
        super(EquivalentAttentionModel, self).__init__()
        self.dropout_p = dropout_p

    def forward(self, query, key, value, attn_bias=None):
        scale = 1.0 / query.shape[-1] ** 0.5
        query = query * scale
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, self.dropout_p)
        output = attn @ value
        return output.transpose(1, 2)

if __name__ == "__main__":
    # モデルインスタンス
    model_equivalent = EquivalentAttentionModel().cuda()

    # ダミー入力
    dummy_query = torch.randn(1, 8, 128, 64, device="cuda")
    dummy_key = torch.randn(1, 8, 128, 64, device="cuda")
    dummy_value = torch.randn(1, 8, 128, 64, device="cuda")

    # ONNXモデルに出力
    torch.onnx.export(
        model_equivalent,
        (dummy_query, dummy_key, dummy_value),
        "model_equivalent.onnx",
        opset_version=17,
        input_names=["query", "key", "value"],
        output_names=["output"],
    )

すると、無事にONNXに変換できました。Netronで確認します。

スクリーンショット 2024-12-12 15.16.08.png

精度の検証

スクラッチで書いたことによって、どの程度精度に影響が出るか確認します。
xformersを使ったモデルとONNXモデルそれぞれに対して、同じ入力を与えて、出力の差の最大値を表示します。

diff_val.py
import numpy as np
import torch
import onnxruntime
from xformers_attention import XformersAttentionModel
torch.manual_seed(42)

model_xformers = XformersAttentionModel().cuda()
sess = onnxruntime.InferenceSession("model_equivalent.onnx")

# ダミー入力
dummy_query = torch.rand(1, 8, 128, 64, device="cuda")
dummy_key = torch.rand(1, 8, 128, 64, device="cuda")
dummy_value = torch.rand(1, 8, 128, 64, device="cuda")

# 出力を計算
output_xformers = model_xformers(dummy_query, dummy_key, dummy_value)
output_onnx = sess.run(
    None,
    {
        "query": dummy_query.cpu().numpy(),
        "key": dummy_key.cpu().numpy(),
        "value": dummy_value.cpu().numpy(),
    },
)[0]

# 差分を表示
diff = output_xformers.cpu().numpy() - output_onnx
max_difference = np.abs(diff).max()
print("Max difference:", max_difference)
Max difference: 2.9802322e-07

差は最大でも10の-7乗オーダーなので、ほぼ気にしなくていいでしょう。

おわりに

今回のコードは以下にまとめてあります。

また、本記事を書くにあたり、以下の記事を参考にさせていただきました。

おまけ

  • xFormersを使ったモデル(GPU)
  • xFormersをスクラッチに書き換えたモデル(GPU)
  • ONNXモデル(CPU)

で実行速度を比較しました。

diff_time.py
import time
import torch
import onnxruntime
from xformers_attention import XformersAttentionModel
from equivalent_attention import EquivalentAttentionModel
torch.manual_seed(42)

model_xformers = XformersAttentionModel().cuda()
model_equivalent = EquivalentAttentionModel().cuda()
sess = onnxruntime.InferenceSession("model_equivalent.onnx")

# 計測用のイベント
start_event_xformers = torch.cuda.Event(enable_timing=True)
end_event_xformers = torch.cuda.Event(enable_timing=True)
start_event_equivalent = torch.cuda.Event(enable_timing=True)
end_event_equivalent = torch.cuda.Event(enable_timing=True)

# ダミー入力
dummy_query = torch.rand(1, 8, 128, 64, device="cuda")
dummy_key = torch.rand(1, 8, 128, 64, device="cuda")
dummy_value = torch.rand(1, 8, 128, 64, device="cuda")

# Xformersモデルの計算時間を測定
start_event_xformers.record()
output_xformers = model_xformers(dummy_query, dummy_key, dummy_value)
end_event_xformers.record()
torch.cuda.synchronize()

# Equivalentモデルの計算時間を測定
start_event_equivalent.record()
output_equivalent = model_equivalent(dummy_query, dummy_key, dummy_value)
end_event_equivalent.record()
torch.cuda.synchronize()

# ONNXモデルの計算時間を測定
start = time.time()
output_onnx = sess.run(
    None,
    {
        "query": dummy_query.cpu().numpy(),
        "key": dummy_key.cpu().numpy(),
        "value": dummy_value.cpu().numpy(),
    },
)[0]
end = time.time()

# 計算時間を取得
time_xformers = start_event_xformers.elapsed_time(end_event_xformers)
time_equivalent = start_event_equivalent.elapsed_time(end_event_equivalent)
time_onnx = (end - start) * 1000

# 計算時間の差を表示
print("Xformers Model Time (ms):", time_xformers)
print("Equivalent Model Time (ms):", time_equivalent)
print("ONNX Model Time (ms):", time_onnx)
Xformers Model Time (ms): 0.5304319858551025
Equivalent Model Time (ms): 713.9911499023438
ONNX Model Time (ms): 1.3279914855957031

ONNXがいい速さしてます。まあ元のモデルがAttention一つだけなので、比較になるのかは怪しいですが...
あと、スクラッチにするとこんなにも遅くなるんですね。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?