Introduction
This is a log written in Q&A format.
Reference URLs
[1] PyTorch official site: torch.cdist — PyTorch 2.0 documentation
https://pytorch.org/docs/stable/generated/torch.cdist.html
[2] Twitter, Yuta Suzuki: switching distance-matrix computation from a hand-written implementation to torch.cdist
https://twitter.com/resnant/status/1270998333671133185
[3] runebook.dev: PyTorch's torch.cdist function is a convenient tool for computing all-pairs Euclidean (or any p-norm) distances between two matrices.
https://runebook.dev/ja/docs/pytorch/generated/torch.cdist
I got an overview from [2] and [3], then built the distance matrix following [1].
The source code can currently be downloaded from two locations.
.ipynb format: "morii_testcode_test_GithubGist_to_multi.ipynb"
・From OneDrive
https://1drv.ms/u/s!AncYIM5QjoV1gq4V6GtV2oVvmdm0Jg?e=2BSbVM
・From Google Drive
https://drive.google.com/file/d/1AMqcEJ9q4c6ZZ4UZDnJG0VKDN4McB_O8/view?usp=sharing
.py format: "morii_testcode_test_GithubGist_to_multi.py"
・From OneDrive
https://1drv.ms/u/s!AncYIM5QjoV1gq4Wlew3mDDALentZw?e=K8tNBK
・From Google Drive
https://drive.google.com/file/d/1_T45kts2wSXVFsCUeyqd1RGWWiFaBt9G/view?usp=sharing
Q1. Compute the distance matrix C_1 between A_1 and B_1. (The output should be an 8×8 tensor.)
import torch
A_1=torch.randn(8,3)
B_1=torch.randn(8,3)
print(f'A_1 =\n{A_1}')
print(f'B_1 =\n{B_1}')
A_1 =
tensor([[-1.2416, -2.0468, 0.3050],
[ 0.0556, -0.3844, 0.6254],
[ 0.8345, -1.0627, -1.2592],
[-0.6727, 0.6312, 0.4476],
[-0.3870, -0.9692, -0.9783],
[-0.0361, 1.5460, -0.6113],
[ 0.4282, 0.8991, 0.4301],
[-1.4248, 0.5365, -0.6753]])
B_1 =
tensor([[-1.7169, -0.2868, -0.1328],
[-2.0779, -0.2973, 1.1613],
[-0.5776, 0.4738, 0.6828],
[-1.6807, 0.4168, -0.4185],
[-0.8357, 0.9086, 1.1973],
[-0.7952, 0.7565, 1.5547],
[-0.1374, 0.0659, 0.0473],
[ 0.1833, 0.2078, 0.3087]])
C_1=torch.cdist(A_1, B_1, p=2)
print(f'C_1 =\n{C_1}')
C_1 =
tensor([[1.8749, 2.1198, 2.6339, 2.6049, 3.1137, 3.1015, 2.3978, 2.6672],
[1.9304, 2.2015, 1.0681, 2.1786, 1.6713, 1.6996, 0.7578, 0.6836],
[2.8949, 3.8635, 2.8507, 3.0367, 3.5651, 3.7260, 1.9812, 2.1205],
[1.5067, 1.8292, 0.2985, 1.3462, 0.8158, 1.1208, 0.8754, 0.9651],
[1.7173, 2.8087, 2.2086, 1.9769, 2.9087, 3.0920, 1.4784, 1.8349],
[2.5324, 3.2724, 1.7656, 2.0043, 2.0777, 2.4271, 1.6231, 1.6386],
[2.5149, 2.8717, 1.1209, 2.3239, 1.4786, 1.6678, 1.0773, 0.7433],
[1.0284, 2.1202, 1.6019, 0.3819, 1.9980, 2.3276, 1.5495, 1.9137]])
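Reference [2] is about replacing a hand-written distance computation with torch.cdist. As a sanity check, here is a minimal sketch of such a hand-written version, assuming the usual broadcasting formula C[i, j] = (Σ_k |A[i, k] - B[j, k]|^p)^(1/p); the helper name pairwise_distance_manual is hypothetical and not from the original code. Both should agree up to floating-point rounding.

```python
import torch


def pairwise_distance_manual(x: torch.Tensor, y: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    """Hypothetical hand-written p-norm distance matrix for x (..., P, M) and y (..., R, M).

    C[..., i, j] = (sum_k |x[..., i, k] - y[..., j, k]| ** p) ** (1 / p)
    """
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)   # (..., P, R, M) via broadcasting
    return diff.abs().pow(p).sum(dim=-1).pow(1.0 / p)


A_1 = torch.randn(8, 3)
B_1 = torch.randn(8, 3)
C_manual = pairwise_distance_manual(A_1, B_1)
C_1 = torch.cdist(A_1, B_1, p=2)

print(C_1.shape)                                  # torch.Size([8, 8])
print(torch.allclose(C_1, C_manual, atol=1e-5))   # True
```

One practical difference: the manual version materializes a P×R×M intermediate tensor (here 8×8×3), which grows quickly for large inputs, while torch.cdist avoids exposing that step.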
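Q2. Using the matrices A_1, B_1, and C_1, implement the Sinkhorn algorithm. (The output is a single scalar.)
from sinkhorn import SinkhornSolver  # local module sinkhorn.py; SinkhornSolver is not part of PyTorch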
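For reference, this is the textbook Sinkhorn iteration for entropically regularized optimal transport, assuming uniform marginals over the n rows and m columns of the cost matrix C (for Q2, n = m = 8 and C = C_1) and a regularization strength ε. The internals of the sinkhorn module used below are not shown in this log, so this is the standard formulation rather than a description of that code.

```latex
\begin{aligned}
&a_i = \frac{1}{n}, \qquad b_j = \frac{1}{m}, \qquad
K_{ij} = \exp\!\left(-\frac{C_{ij}}{\varepsilon}\right) \\
&u^{(l+1)}_i = \frac{a_i}{\sum_{j} K_{ij}\, v^{(l)}_j}, \qquad
v^{(l+1)}_j = \frac{b_j}{\sum_{i} K_{ij}\, u^{(l+1)}_i} \\
&\pi_{ij} = u_i\, K_{ij}\, v_j, \qquad
\mathrm{cost} = \sum_{i,j} \pi_{ij}\, C_{ij}
\end{aligned}
```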
epsilon = 0.001
solver = SinkhornSolver(epsilon=epsilon, iterations=10000)
cost, pi = solver.forward(A_1, B_1)
Finished computing transport plan in 1 iterations
C_dist = solver._compute_cost_pytorch_cdist(A_1, B_1)
print(f'C_dist\n{C_dist}')
C_dist
tensor([[[1.8749, 2.1198, 2.6339, 2.6049, 3.1137, 3.1015, 2.3978, 2.6672]],
[[1.9304, 2.2015, 1.0681, 2.1786, 1.6713, 1.6996, 0.7578, 0.6836]],
[[2.8949, 3.8635, 2.8507, 3.0367, 3.5651, 3.7260, 1.9812, 2.1205]],
[[1.5067, 1.8292, 0.2985, 1.3462, 0.8158, 1.1208, 0.8754, 0.9651]],
[[1.7173, 2.8087, 2.2086, 1.9769, 2.9087, 3.0920, 1.4784, 1.8349]],
[[2.5324, 3.2724, 1.7656, 2.0043, 2.0777, 2.4271, 1.6231, 1.6386]],
[[2.5149, 2.8717, 1.1209, 2.3239, 1.4786, 1.6678, 1.0773, 0.7433]],
[[1.0284, 2.1202, 1.6019, 0.3819, 1.9980, 2.3276, 1.5495, 1.9137]]])
print(f'cost\n{cost}')
cost
1.6988214254379272
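SinkhornSolver computes its own cost from A_1 and B_1. If you want to feed the precomputed C_1 directly, here is a minimal log-domain Sinkhorn sketch under stated assumptions: uniform marginals, a stopping rule on the change of the dual potentials, and a hypothetical function name sinkhorn_cost. It implements the update rules u = a / (K v) and v = b / (Kᵀ u) via torch.logsumexp, so exp(-C/ε) is never formed explicitly (it would underflow for small ε). Because SinkhornSolver's ground metric, ε convention, and stopping rule are not shown here, its output is not expected to reproduce the printed 1.6988214254379272.

```python
import math
import torch


def sinkhorn_cost(C: torch.Tensor, epsilon: float = 0.1,
                  iterations: int = 1000, tol: float = 1e-6) -> torch.Tensor:
    """Entropic OT cost <pi, C> for C of shape (..., n, m) with uniform marginals.

    Hypothetical helper (not the sinkhorn module's API). The potentials
    f = eps*log(u) and g = eps*log(v) are updated in the log domain.
    """
    n, m = C.shape[-2], C.shape[-1]
    batch = tuple(C.shape[:-2])
    log_a = torch.full(batch + (n,), -math.log(n), dtype=C.dtype, device=C.device)
    log_b = torch.full(batch + (m,), -math.log(m), dtype=C.dtype, device=C.device)
    f = torch.zeros_like(log_a)
    g = torch.zeros_like(log_b)
    for _ in range(iterations):
        f_prev = f
        # u = a / (K v)   ->  f_i = eps*log a_i - eps*LSE_j((g_j - C_ij) / eps)
        f = epsilon * log_a - epsilon * torch.logsumexp((g.unsqueeze(-2) - C) / epsilon, dim=-1)
        # v = b / (K^T u) ->  g_j = eps*log b_j - eps*LSE_i((f_i - C_ij) / eps)
        g = epsilon * log_b - epsilon * torch.logsumexp((f.unsqueeze(-1) - C) / epsilon, dim=-2)
        if (f - f_prev).abs().max() < tol:
            break
    # Transport plan pi_ij = u_i K_ij v_j = exp((f_i + g_j - C_ij) / eps)
    pi = torch.exp((f.unsqueeze(-1) + g.unsqueeze(-2) - C) / epsilon)
    return (pi * C).sum(dim=(-2, -1))


A_1 = torch.randn(8, 3)
B_1 = torch.randn(8, 3)
C_1 = torch.cdist(A_1, B_1, p=2)
print(sinkhorn_cost(C_1, epsilon=0.05))            # 0-dim tensor: a single scalar
print(sinkhorn_cost(torch.rand(10, 8, 8)).shape)   # torch.Size([10])
```

Since the function reduces only over the last two dimensions, the same sketch also accepts a batched cost tensor such as the 10×8×8 C of Q3.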
Q3. A and B are tensors that each bundle 10 tensors of shape 8×3. Compute the distance matrix C. (The output will be 10×8×8.)
A=torch.randn(10,8,3)
B=torch.randn(10,8,3)
C=torch.cdist(A, B, p=2)
print(C)
tensor([[[3.3452, 3.5884, 2.6145, 3.5548, 2.6811, 1.7614, 4.8377, 3.2827],
[2.7055, 2.8352, 3.2904, 3.4297, 3.2852, 3.5755, 3.8415, 3.5525],
[2.8125, 3.0205, 1.4525, 2.8069, 2.1878, 0.3542, 4.1632, 2.1684],
[2.7713, 3.0012, 2.2180, 3.0812, 2.3985, 1.7458, 4.2431, 2.8260],
[2.8290, 2.6322, 2.2476, 3.3099, 4.0334, 2.9723, 2.6648, 1.8386],
[1.8975, 1.9757, 2.3589, 1.1350, 2.3386, 3.7782, 2.4262, 2.0510],
[0.9523, 0.7452, 2.4652, 1.8895, 2.9092, 3.7592, 1.3614, 2.1590],
[2.3381, 2.3968, 2.0447, 1.6194, 2.5042, 3.3260, 2.8266, 1.7473]],
[[1.3115, 1.6532, 3.0833, 1.0828, 2.1806, 2.0015, 1.7661, 0.9011],
[1.5730, 2.5654, 3.9633, 0.3884, 2.4030, 1.8086, 2.4826, 0.4883],
[1.9457, 2.7556, 3.9811, 0.8660, 1.8056, 1.8166, 2.4628, 0.9498],
[2.3601, 4.1735, 4.9313, 2.3770, 2.9711, 4.1579, 4.1577, 3.0857],
[2.7012, 2.3534, 3.7331, 1.7409, 2.1794, 0.5135, 1.8746, 0.9426],
[2.1261, 1.5404, 2.7603, 1.6741, 1.2132, 1.6354, 1.0849, 1.2387],
[1.7516, 2.0816, 2.8682, 2.7619, 3.5830, 3.9498, 2.8420, 2.9162],
[1.8000, 1.4290, 1.8932, 2.7812, 2.7963, 3.7957, 2.1804, 2.8766]],
[[3.0585, 2.9258, 2.9229, 1.4963, 2.1865, 1.5238, 2.9494, 1.1623],
[2.4140, 2.8110, 2.6675, 2.7249, 3.4203, 2.4400, 1.6703, 2.7754],
[2.4229, 2.9906, 2.1859, 1.9546, 2.4580, 1.9566, 2.0477, 1.5294],
[1.5616, 2.3468, 1.9568, 1.8826, 2.4227, 2.0639, 1.0216, 2.6526],
[1.3946, 3.6217, 0.4893, 2.3777, 1.6518, 3.2633, 1.6216, 3.1761],
[1.9462, 4.6370, 1.2176, 3.8646, 3.2498, 4.5615, 1.9073, 4.4238],
[1.3844, 2.4626, 1.9043, 1.2222, 0.7378, 2.4471, 1.8722, 3.2822],
...
[2.0949, 2.8157, 4.1192, 2.0881, 2.6393, 1.0169, 1.6152, 1.5368],
[1.8283, 2.0573, 2.3762, 1.9086, 3.7325, 2.7268, 2.0280, 3.4565],
[0.9737, 2.8433, 1.9362, 2.7126, 3.1949, 2.4475, 2.1220, 3.6938],
[3.0246, 2.2267, 3.0546, 2.2309, 4.6683, 3.7015, 2.8659, 4.0444]]])
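torch.cdist treats the leading dimension as a batch dimension, which is why Q3 needs no loop. A small sketch to confirm this: computing each of the 10 distance matrices separately and stacking them should give the same 10×8×8 result as the single batched call.

```python
import torch

A = torch.randn(10, 8, 3)
B = torch.randn(10, 8, 3)
C = torch.cdist(A, B, p=2)   # batch dimension 10 is carried through: (10, 8, 8)

# Same result computed one batch element at a time, then stacked
C_loop = torch.stack([torch.cdist(A[k], B[k], p=2) for k in range(A.shape[0])])

print(C.shape)                               # torch.Size([10, 8, 8])
print(torch.allclose(C, C_loop, atol=1e-6))  # True
```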
Q4. Using A, B, and C, implement the Sinkhorn algorithm. (The output will be a tensor collecting one scalar per batch element, i.e. 10 values.)
epsilon = 0.00001
solver = SinkhornSolver(epsilon=epsilon, iterations=100000)
cost, pi = solver.forward(A, B)
print(f'cost\n{cost}')
cost
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
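The log does not say why every entry comes out as nan. One well-known way a Sinkhorn implementation produces exactly this pattern is underflow in the non-log-domain kernel: with ε = 0.00001 and distances around 0.3 to 5, exp(-C/ε) is 0.0 in floating point, so K v = 0, u = a/0 = inf, and then Kᵀ u = 0 · inf = nan. The sketch below, with a hypothetical helper naive_sinkhorn_cost, only demonstrates that mechanism; it is a possible explanation, not a diagnosis of the sinkhorn module's code.

```python
import torch


def naive_sinkhorn_cost(C: torch.Tensor, epsilon: float, iterations: int = 100) -> torch.Tensor:
    """Plain (non-log-domain) Sinkhorn for C of shape (..., n, m), uniform marginals.

    Hypothetical helper for illustration only; not the sinkhorn module's code.
    """
    n, m = C.shape[-2], C.shape[-1]
    a = torch.full(tuple(C.shape[:-1]), 1.0 / n, dtype=C.dtype)          # (..., n)
    b = torch.full(tuple(C.shape[:-2]) + (m,), 1.0 / m, dtype=C.dtype)   # (..., m)
    K = torch.exp(-C / epsilon)   # Gibbs kernel: underflows to 0.0 when C / epsilon is huge
    v = torch.ones_like(b)
    for _ in range(iterations):
        u = a / (K @ v.unsqueeze(-1)).squeeze(-1)                    # u = a / (K v): inf if K v == 0
        v = b / (K.transpose(-2, -1) @ u.unsqueeze(-1)).squeeze(-1)  # v = b / (K^T u): 0 * inf = nan
    pi = u.unsqueeze(-1) * K * v.unsqueeze(-2)   # transport plan diag(u) K diag(v)
    return (pi * C).sum(dim=(-2, -1))


A = torch.randn(10, 8, 3)
B = torch.randn(10, 8, 3)
C = torch.cdist(A, B, p=2)

print(torch.exp(-C / 1e-5).max())             # the kernel has underflowed to zero
print(naive_sinkhorn_cost(C, epsilon=1e-5))   # all 10 entries are nan
print(naive_sinkhorn_cost(C, epsilon=1.0))    # finite values with a larger epsilon
```

A log-domain formulation (updating ε·log u and ε·log v with torch.logsumexp) avoids forming K explicitly and is the usual remedy when ε must be small.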
C_dist = solver._compute_cost_pytorch_cdist(A, B)
print(C_dist)
tensor([[[[3.3452, 3.5884, 2.6145, 3.5548, 2.6811, 1.7614, 4.8377, 3.2827]],
[[2.7055, 2.8352, 3.2904, 3.4297, 3.2852, 3.5755, 3.8415, 3.5525]],
[[2.8125, 3.0205, 1.4525, 2.8069, 2.1878, 0.3542, 4.1632, 2.1684]],
[[2.7713, 3.0012, 2.2180, 3.0812, 2.3985, 1.7458, 4.2431, 2.8260]],
[[2.8290, 2.6322, 2.2476, 3.3099, 4.0334, 2.9723, 2.6648, 1.8386]],
[[1.8975, 1.9757, 2.3589, 1.1350, 2.3386, 3.7782, 2.4262, 2.0510]],
[[0.9523, 0.7452, 2.4652, 1.8895, 2.9092, 3.7592, 1.3614, 2.1590]],
[[2.3381, 2.3968, 2.0447, 1.6194, 2.5042, 3.3260, 2.8266, 1.7473]]],
[[[1.3115, 1.6532, 3.0833, 1.0828, 2.1806, 2.0015, 1.7661, 0.9011]],
[[1.5730, 2.5654, 3.9633, 0.3884, 2.4030, 1.8086, 2.4826, 0.4883]],
[[1.9457, 2.7556, 3.9811, 0.8660, 1.8056, 1.8166, 2.4628, 0.9498]],
[[2.3601, 4.1735, 4.9313, 2.3770, 2.9711, 4.1579, 4.1577, 3.0857]],
...
[[0.9737, 2.8433, 1.9362, 2.7126, 3.1949, 2.4475, 2.1220, 3.6938]],
[[3.0246, 2.2267, 3.0546, 2.2309, 4.6683, 3.7015, 2.8659, 4.0444]]]])
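C_dist from _compute_cost_pytorch_cdist has shape 10×8×1×8 (and 8×1×8 in Q2) instead of the 10×8×8 of Q3: the values match, but there is an extra singleton dimension. The source of _compute_cost_pytorch_cdist is not shown, so the following is a hypothesis: passing x.unsqueeze(-2) and y.unsqueeze(-3) (the pattern used for broadcasting-based cost matrices) to torch.cdist reproduces exactly this shape, because torch.cdist broadcasts the leading batch dimensions. squeeze(-2) then recovers the ordinary distance matrix.

```python
import torch

A = torch.randn(10, 8, 3)
B = torch.randn(10, 8, 3)

# Hypothetical reconstruction: unsqueeze as for broadcasting, then call cdist.
# A.unsqueeze(-2): (10, 8, 1, 3) -> batch dims (10, 8), one point each
# B.unsqueeze(-3): (10, 1, 8, 3) -> batch dims (10, 1), eight points each
C_extra = torch.cdist(A.unsqueeze(-2), B.unsqueeze(-3), p=2)
print(C_extra.shape)   # torch.Size([10, 8, 1, 8])

# Dropping the singleton dimension gives the ordinary batched distance matrix
C = torch.cdist(A, B, p=2)
print(C_extra.squeeze(-2).shape)                          # torch.Size([10, 8, 8])
print(torch.allclose(C_extra.squeeze(-2), C, atol=1e-6))  # True
```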