LoginSignup
0
0

距離行列の作成とコスト計算ログ(暫定版)

Posted at

はじめに

Q&A形式でのログです。

参考URL

[1] Approximating Wasserstein distances with PyTorch - Daniel Daza
https://dfdazac.github.io/sinkhorn.html

[2] Github Approximating Wasserstein distances with PyTorch
https://github.com/dfdazac/wassdistance

[1]で外観を得て、[2]のソースコードを活用させて頂いた。

Q. AとBは10個のテンソルが集約されたテンソルです。距離行列Cを求めてください。(出力は10×8×8になります。)

Q. また、A,B,Cを使って、sinkhoonアルゴリズムを実装してください。(出力は各テンソルに対応するスカラーが集約された10次元のテンソルになります。)

import torch

from layers import SinkhornDistance

A = torch.randn(5000,8,3)
B = torch.randn(5000,8,3)

sinkhorn = SinkhornDistance(eps=0.1, max_iter=1000)

dist, P, C = sinkhorn(A, B)


print("Cost:",C)
print("Distance:", dist)

Cost: tensor([[[ 2.3851,  4.6197,  6.1333,  ...,  8.1374,  0.0954, 10.4732],
         [ 2.0316,  3.5151,  7.5441,  ..., 16.0305,  5.8873,  1.9347],
         [ 2.7555,  5.4268,  8.6332,  ..., 15.5158, 10.2188,  0.1254],
         ...,
         [ 6.5027,  4.5277,  1.2859,  ...,  9.0828,  8.6347, 17.3766],
         [ 2.9556,  4.4240,  8.7556,  ..., 16.0003,  3.1574,  5.7351],
         [ 4.4386,  9.3323, 11.8327,  ...,  9.0434,  1.1514, 11.5495]],

        [[ 3.9119,  2.5504, 16.9356,  ...,  0.5636,  2.8750,  2.3992],
         [ 0.7460,  1.6460,  6.8128,  ...,  3.4929,  1.6600,  2.0364],
         [ 2.9862,  1.4779,  6.7706,  ...,  4.1522,  0.6850,  5.9932],
         ...,
         [ 1.4956,  1.4795,  5.9848,  ...,  3.4524,  1.0665,  3.1406],
         [ 9.0050, 15.3521,  1.6283,  ..., 19.0885, 14.6266,  8.7985],
         [ 6.1278,  2.8331, 26.2087,  ...,  7.9378,  4.2873, 14.4930]],

        [[ 6.5756,  1.6820,  2.3338,  ...,  0.0749, 12.2661,  4.8055],
         [ 5.5638,  2.3422,  5.7793,  ...,  5.9754,  2.9283,  0.5275],
         [ 8.5439,  2.2426,  2.3104,  ...,  0.5533, 14.1425,  5.5368],
         ...,
         [ 1.7040, 10.5317,  8.9386,  ...,  8.9560,  5.0410, 10.4021],
         [12.1518,  1.4655,  9.7433,  ...,  6.5903, 10.1003,  0.9371],
         [ 1.2980,  7.9191,  7.1047,  ...,  6.7856,  4.4195,  7.9719]],

        ...,

        [[ 0.3512,  6.7879,  1.0256,  ...,  1.1962,  6.1207,  3.1671],
         [ 1.3076, 12.6213,  1.2475,  ...,  4.0746,  9.3167,  5.8667],
         [ 6.8642, 11.7086, 14.6859,  ..., 15.9577,  5.3421,  2.2187],
         ...,
         [ 4.8603, 22.7114,  8.9389,  ..., 11.1563, 15.9644,  4.2383],
         [ 0.5606,  4.8859,  1.8659,  ...,  1.5179,  4.5857,  2.5412],
         [ 9.8800,  5.5608, 13.2379,  ..., 17.1964,  1.4715, 11.7824]],

        [[ 2.8790,  5.0972,  7.1846,  ...,  3.3449, 13.3123,  8.6139],
         [ 5.2767,  1.2954,  5.3868,  ...,  3.5217,  1.5166,  0.1790],
         [ 8.9345,  5.4933, 13.6296,  ...,  8.7750,  7.9416,  1.5231],
         ...,
         [ 7.8432,  7.8624, 15.5582,  ...,  8.9348, 15.2594,  6.0558],
         [ 2.2326,  2.1241,  0.5018,  ...,  0.5880,  5.4888,  7.2493],
         [ 2.3860,  2.9177,  4.9353,  ...,  4.0292,  6.4979,  4.7113]],

        [[ 8.0692,  3.8253,  1.9432,  ...,  7.2783,  3.0775, 12.3579],
         [ 1.9585,  4.3045,  2.8538,  ...,  2.8123, 11.6873, 13.2805],
         [ 4.3140,  5.0203,  6.8623,  ...,  3.9087, 15.3001, 13.2837],
         ...,
         [12.6561,  4.3500,  4.0669,  ..., 10.0359,  1.3259, 11.6383],
         [10.6419,  1.2867,  3.7827,  ...,  4.8285,  1.3232,  3.9288],
         [ 5.9991,  3.8937,  0.7738,  ...,  5.1932,  3.8849, 11.1763]]])
Distance: tensor([2.4712, 1.6183, 1.8624,  ..., 1.9867, 1.9327, 2.3939])

これより、コスト値の集まりが求まった。
(今回は実装ではなく、[2]のソースコードを活用させていただいた。)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0