1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorchでの半精度におけるコサイン類似度の問題(OpenAI CLIP)

Last updated at Posted at 2022-12-22

Stable Diffusionなどの画像生成モデルに使われているOpenAI CLIPでは画像とテキスト間の類似度にコサイン類似度を使用していますが、PyTorch 1.13現在PyTorchのコサイン類似度(Cosine Similarity)には半精度の問題あるためここに書いておきます。

なお、ここでは分かりやすく一つだけのコサイン類似度を扱うため半精度の必要性が分かりにくいと思いますが、実際に使っているコードはマトリクスで多対多のコサイン類似度を計算していてVRAMが厳しかったりします…。

まずはCLIPにおける「room」のテキスト埋め込みの自己コサイン類似度が精度の問題で0になってしまうケースを見てみましょう。CLIP の設定は Stable Diffusion 1.x に準ずるものとします。

import transformers
text = "room"

# トークナイザーでテキストをトークン化してNVIDIA GPUへ送る
tokenizer = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tokens = tokenizer(text, padding="max_length", max_length=77, return_attention_mask=False, return_tensors="pt").input_ids.cuda()
# >>> tokens
# tensor([[49406,  1530, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
#         49407, 49407, 49407, 49407, 49407, 49407, 49407]], device='cuda:0')

# テキストから変換したトークンをテキスト埋め込みへと変換してNVIDIA GPUへ送って半精度に落とす
text_encoder = transformers.CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",)
text_embedded = text_encoder(tokens)[0].flatten(0).cuda().half()
# >>> text_embedded.shape
# torch.Size([59136])

# トークンから変換したテキスト埋め込みの自己コサイン類似度を取る
result = torch.nn.functional.cosine_similarity(text_embedded, text_embedded, dim=0, eps=1e-6)
# >>> result
# tensor(0., device='cuda:0', dtype=torch.float16, grad_fn=<SumBackward1>) # 0 になった!?

見て分かる通り、自己コサイン類似度が0になっちゃいました。そこで実装を見ていきます。PyTorchでのコサイン類似度の実装は単純化してPython化すると以下のようになっています。

def cosine_similarity(x1, x2, eps): # dimは単純化のため省略
    w12 = torch.sum(x1 * x2)
    w1 = torch.sum(x1 * x1)
    w2 = torch.sum(x2 * x2)
    n12 = (w1 * w2).clamp_min_(eps * eps).sqrt_()
    return w12 / n12

精度の問題を確認したところsumで問題が起きてるので、そこの精度を上げる必要があるようです。

def cosine_similarity_fix(x1, x2, eps):
    w12 = torch.sum(x1 * x2, dtype=torch.float)
    w1 = torch.sum(x1 * x1, dtype=torch.float)
    w2 = torch.sum(x2 * x2, dtype=torch.float)
    return w12 * (w1 * w2).clamp_min_(eps * eps).rsqrt()

でも半精度のままで計算するならこっちの形が正解かもですね。

def cosine_similarity_fix_halfonly(x1, x2, eps):
    t1 = x1 / torch.norm(x1).clamp_min_(eps * eps)
    t2 = x2 / torch.norm(x2).clamp_min_(eps * eps)
    return torch.sum(t1 * t2)

この関数を使って前述の自己コサイン類似度を計算するとちゃんと1.0となります。

>>> cosine_similarity_fix_halfonly(text_embedded, text_embedded, 1e-6)
tensor(1., device='cuda:0', dtype=torch.float16, grad_fn=<SumBackward0>)
…あれ、ベクトルノルムも似たような計算をしてるはず…?何か勘違いしてそう
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?