Introduction
This is a log written in Q&A format.
Reference URLs
[1] Approximating Wasserstein distances with PyTorch - Daniel Daza
https://dfdazac.github.io/sinkhorn.html
[2] GitHub - Approximating Wasserstein distances with PyTorch
https://github.com/dfdazac/wassdistance
I got an overview of the method from [1] and made use of the source code from [2].
Q. A and B are batched tensors, each stacking 10 tensors. Compute the distance matrix C. (The output has shape 10×8×8.)
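For reference, the distance matrix can also be written directly with broadcasting, without the library. This minimal sketch assumes the squared Euclidean cost, which, as far as I can tell, is what [2] computes by default (|x_i - y_j|^p with p = 2). The function name `cost_matrix` and the random inputs are illustrative, and `torch.cdist` is used only as a cross-check.

```python
import torch

def cost_matrix(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Batched squared Euclidean cost: C[b, i, j] = ||x[b, i] - y[b, j]||^2.

    x: (batch, n, dim), y: (batch, m, dim) -> C: (batch, n, m)
    """
    # (batch, n, 1, dim) - (batch, 1, m, dim) broadcasts to (batch, n, m, dim)
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)
    return (diff ** 2).sum(dim=-1)

# Illustrative input: 10 pairs of point clouds, 8 points each in 3-D
A = torch.randn(10, 8, 3)
B = torch.randn(10, 8, 3)
C = cost_matrix(A, B)
print(C.shape)  # torch.Size([10, 8, 8])

# Cross-check: torch.cdist returns (non-squared) Euclidean distances
assert torch.allclose(C, torch.cdist(A, B) ** 2, atol=1e-4)
```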
Q. Also, using A, B, and C, implement the Sinkhorn algorithm. (The output is a tensor of shape (10,) holding one scalar for each pair of tensors.)
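The Sinkhorn iterations themselves fit in a few lines. The following is a minimal log-domain sketch in the spirit of [1] and [2], not the code of [2] itself: it assumes uniform weights 1/n and 1/m on the points and a squared Euclidean cost, and the function name `sinkhorn_batched`, the tolerance `tol`, and the stopping test on the row potential are illustrative choices. It returns one scalar per pair, i.e. a tensor of shape (10,) for a batch of 10.

```python
import math
import torch

def sinkhorn_batched(C: torch.Tensor, eps: float = 0.1, max_iter: int = 1000, tol: float = 1e-6):
    """Log-domain Sinkhorn for a batch of cost matrices C of shape (batch, n, m).

    Assumes uniform marginals mu = 1/n and nu = 1/m.
    Returns (dist, P): dist has shape (batch,), P has shape (batch, n, m).
    """
    batch, n, m = C.shape
    log_mu = torch.full((batch, n), -math.log(n), dtype=C.dtype, device=C.device)
    log_nu = torch.full((batch, m), -math.log(m), dtype=C.dtype, device=C.device)
    f = torch.zeros_like(log_mu)  # dual potential for the rows
    g = torch.zeros_like(log_nu)  # dual potential for the columns
    for _ in range(max_iter):
        f_prev = f
        # Rescale rows so that P.sum(-1) equals mu, given the current g
        f = eps * (log_mu - torch.logsumexp((g.unsqueeze(-2) - C) / eps, dim=-1))
        # Rescale columns so that P.sum(-2) equals nu, given the new f
        g = eps * (log_nu - torch.logsumexp((f.unsqueeze(-1) - C) / eps, dim=-2))
        if (f - f_prev).abs().max() < tol:
            break
    # Transport plan P_ij = exp((f_i + g_j - C_ij) / eps)
    P = torch.exp((f.unsqueeze(-1) + g.unsqueeze(-2) - C) / eps)
    # Transport cost <P, C> for each pair in the batch
    return (P * C).sum(dim=(-2, -1)), P

A = torch.randn(10, 8, 3)
B = torch.randn(10, 8, 3)
C = ((A.unsqueeze(-2) - B.unsqueeze(-3)) ** 2).sum(-1)  # (10, 8, 8)
dist, P = sinkhorn_batched(C)
print(dist.shape)  # torch.Size([10])
```

Working with the potentials f and g through `torch.logsumexp` avoids the underflow that computing exp(-C / eps) directly would cause for small eps; [2] likewise uses logarithmic updates.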
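import torch
# SinkhornDistance comes from layers.py in the wassdistance repository [2]
from layers import SinkhornDistance
# Point clouds of shape (batch, points, dim); this run uses a batch of 5000
# (with a batch of 10 as in the question, C is 10×8×8 and dist has 10 elements)
A = torch.randn(5000, 8, 3)
B = torch.randn(5000, 8, 3)
# eps: regularization coefficient, max_iter: maximum number of Sinkhorn iterations
sinkhorn = SinkhornDistance(eps=0.1, max_iter=1000)
# dist: (5000,) distances, P: (5000, 8, 8) transport plans, C: (5000, 8, 8) squared Euclidean costs
dist, P, C = sinkhorn(A, B)
print("Cost:", C)
print("Distance:", dist)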
Cost: tensor([[[ 2.3851, 4.6197, 6.1333, ..., 8.1374, 0.0954, 10.4732],
[ 2.0316, 3.5151, 7.5441, ..., 16.0305, 5.8873, 1.9347],
[ 2.7555, 5.4268, 8.6332, ..., 15.5158, 10.2188, 0.1254],
...,
[ 6.5027, 4.5277, 1.2859, ..., 9.0828, 8.6347, 17.3766],
[ 2.9556, 4.4240, 8.7556, ..., 16.0003, 3.1574, 5.7351],
[ 4.4386, 9.3323, 11.8327, ..., 9.0434, 1.1514, 11.5495]],
[[ 3.9119, 2.5504, 16.9356, ..., 0.5636, 2.8750, 2.3992],
[ 0.7460, 1.6460, 6.8128, ..., 3.4929, 1.6600, 2.0364],
[ 2.9862, 1.4779, 6.7706, ..., 4.1522, 0.6850, 5.9932],
...,
[ 1.4956, 1.4795, 5.9848, ..., 3.4524, 1.0665, 3.1406],
[ 9.0050, 15.3521, 1.6283, ..., 19.0885, 14.6266, 8.7985],
[ 6.1278, 2.8331, 26.2087, ..., 7.9378, 4.2873, 14.4930]],
[[ 6.5756, 1.6820, 2.3338, ..., 0.0749, 12.2661, 4.8055],
[ 5.5638, 2.3422, 5.7793, ..., 5.9754, 2.9283, 0.5275],
[ 8.5439, 2.2426, 2.3104, ..., 0.5533, 14.1425, 5.5368],
...,
[ 1.7040, 10.5317, 8.9386, ..., 8.9560, 5.0410, 10.4021],
[12.1518, 1.4655, 9.7433, ..., 6.5903, 10.1003, 0.9371],
[ 1.2980, 7.9191, 7.1047, ..., 6.7856, 4.4195, 7.9719]],
...,
[[ 0.3512, 6.7879, 1.0256, ..., 1.1962, 6.1207, 3.1671],
[ 1.3076, 12.6213, 1.2475, ..., 4.0746, 9.3167, 5.8667],
[ 6.8642, 11.7086, 14.6859, ..., 15.9577, 5.3421, 2.2187],
...,
[ 4.8603, 22.7114, 8.9389, ..., 11.1563, 15.9644, 4.2383],
[ 0.5606, 4.8859, 1.8659, ..., 1.5179, 4.5857, 2.5412],
[ 9.8800, 5.5608, 13.2379, ..., 17.1964, 1.4715, 11.7824]],
[[ 2.8790, 5.0972, 7.1846, ..., 3.3449, 13.3123, 8.6139],
[ 5.2767, 1.2954, 5.3868, ..., 3.5217, 1.5166, 0.1790],
[ 8.9345, 5.4933, 13.6296, ..., 8.7750, 7.9416, 1.5231],
...,
[ 7.8432, 7.8624, 15.5582, ..., 8.9348, 15.2594, 6.0558],
[ 2.2326, 2.1241, 0.5018, ..., 0.5880, 5.4888, 7.2493],
[ 2.3860, 2.9177, 4.9353, ..., 4.0292, 6.4979, 4.7113]],
[[ 8.0692, 3.8253, 1.9432, ..., 7.2783, 3.0775, 12.3579],
[ 1.9585, 4.3045, 2.8538, ..., 2.8123, 11.6873, 13.2805],
[ 4.3140, 5.0203, 6.8623, ..., 3.9087, 15.3001, 13.2837],
...,
[12.6561, 4.3500, 4.0669, ..., 10.0359, 1.3259, 11.6383],
[10.6419, 1.2867, 3.7827, ..., 4.8285, 1.3232, 3.9288],
[ 5.9991, 3.8937, 0.7738, ..., 5.1932, 3.8849, 11.1763]]])
Distance: tensor([2.4712, 1.6183, 1.8624, ..., 1.9867, 1.9327, 2.3939])
This yields the cost matrices C and, for each pair of point clouds, the corresponding distance value.
(This time I did not write the implementation myself; I used the source code from [2].)
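As a rough sanity check of the distances, the exact (unregularized) optimal transport cost can be brute-forced for problems this small. With 8 equally weighted points on each side, an optimal plan can be taken to be a permutation matrix divided by 8 (Birkhoff-von Neumann theorem), so it is enough to search all 8! = 40320 permutations. `exact_ot_uniform` is a hypothetical helper written for square cost matrices only; with eps = 0.1 the Sinkhorn values should typically be close to, but not exactly equal to, these exact values.

```python
import itertools
import torch

def exact_ot_uniform(C: torch.Tensor) -> torch.Tensor:
    """Exact OT cost for uniform weights, given square cost matrices C of shape (batch, n, n).

    Enumerates all n! permutations, so it is only feasible for small n (n = 8 gives 40320).
    """
    batch, n, m = C.shape
    assert n == m, "brute force over permutations needs square cost matrices"
    perms = torch.tensor(list(itertools.permutations(range(n))))  # (n!, n)
    rows = torch.arange(n)
    # totals[b, k] = sum_i C[b, i, perms[k, i]]
    totals = C[:, rows, perms].sum(dim=-1)  # (batch, n!)
    # Each matched pair carries mass 1/n
    return totals.min(dim=-1).values / n

A = torch.randn(10, 8, 3)
B = torch.randn(10, 8, 3)
C = ((A.unsqueeze(-2) - B.unsqueeze(-3)) ** 2).sum(-1)  # squared Euclidean costs
print(exact_ot_uniform(C))
```

With the `C` and `dist` from the code in this article, `exact_ot_uniform(C[:10])` can be compared with `dist[:10]`. Running it on all 5000 pairs at once would need a (5000, 40320, 8) intermediate tensor, so a slice is safer.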