LoginSignup
0
0

距離行列の作成とコスト計算ログ(不発版)

Last updated at Posted at 2023-06-28

はじめに

Q&A形式でのログです。

参考URL

[1] PyTorch公式サイト:torch.cdist — PyTorch 2.0 documentation
https://pytorch.org/docs/stable/generated/torch.cdist.html

[2] Twitter Yuta Suzuki:距離行列計算を自前からtorch.cdistに切り替えた話
https://twitter.com/resnant/status/1270998333671133185

[3] PyTorchのtorch.cdist関数は、2つの行列の全対ユークリッド(または任意のp-norm)距離を計算するのに便利なツールです。
https://runebook.dev/ja/docs/pytorch/generated/torch.cdist

[2][3]から外観を得て、[1]に沿って距離行列を作成した。

ソースコードは現在2か所からdownload可能です。

.ipynb形式:「morii_testcode_test_GithubGist_to_multi.ipynb」
・OneDriveから
https://1drv.ms/u/s!AncYIM5QjoV1gq4V6GtV2oVvmdm0Jg?e=2BSbVM
・GoogleDriveから
https://drive.google.com/file/d/1AMqcEJ9q4c6ZZ4UZDnJG0VKDN4McB_O8/view?usp=sharing

.py形式:「morii_testcode_test_GithubGist_to_multi.py」
・OneDriveから
https://1drv.ms/u/s!AncYIM5QjoV1gq4Wlew3mDDALentZw?e=K8tNBK
・GoogleDriveから
https://drive.google.com/file/d/1_T45kts2wSXVFsCUeyqd1RGWWiFaBt9G/view?usp=sharing

Q1. A_1 とB_1 間の距離行列C_1を求めてください.(出力は8×8のテンソルになるはずです)

import torch
A_1=torch.randn(8,3)
B_1=torch.randn(8,3)
print(f'A_1は\n{A_1}')
print(f'B_1は\n{B_1}')

A_1は
tensor([[-1.2416, -2.0468, 0.3050],
[ 0.0556, -0.3844, 0.6254],
[ 0.8345, -1.0627, -1.2592],
[-0.6727, 0.6312, 0.4476],
[-0.3870, -0.9692, -0.9783],
[-0.0361, 1.5460, -0.6113],
[ 0.4282, 0.8991, 0.4301],
[-1.4248, 0.5365, -0.6753]])

B_1は
tensor([[-1.7169, -0.2868, -0.1328],
[-2.0779, -0.2973, 1.1613],
[-0.5776, 0.4738, 0.6828],
[-1.6807, 0.4168, -0.4185],
[-0.8357, 0.9086, 1.1973],
[-0.7952, 0.7565, 1.5547],
[-0.1374, 0.0659, 0.0473],
[ 0.1833, 0.2078, 0.3087]])

C_1=torch.cdist(A_1, B_1, p=2)
print(f'C_1は\n{C_1}')

C_1は
tensor([[1.8749, 2.1198, 2.6339, 2.6049, 3.1137, 3.1015, 2.3978, 2.6672],
[1.9304, 2.2015, 1.0681, 2.1786, 1.6713, 1.6996, 0.7578, 0.6836],
[2.8949, 3.8635, 2.8507, 3.0367, 3.5651, 3.7260, 1.9812, 2.1205],
[1.5067, 1.8292, 0.2985, 1.3462, 0.8158, 1.1208, 0.8754, 0.9651],
[1.7173, 2.8087, 2.2086, 1.9769, 2.9087, 3.0920, 1.4784, 1.8349],
[2.5324, 3.2724, 1.7656, 2.0043, 2.0777, 2.4271, 1.6231, 1.6386],
[2.5149, 2.8717, 1.1209, 2.3239, 1.4786, 1.6678, 1.0773, 0.7433],
[1.0284, 2.1202, 1.6019, 0.3819, 1.9980, 2.3276, 1.5495, 1.9137]])

Q2. 行列A_1,B_1,C_1を使って、Sinkhoon アルゴリズムを実装してください。(出力は1次元のスカラーになります。)

from sinkhorn import SinkhornSolver
epsilon = 0.001
solver = SinkhornSolver(epsilon=epsilon, iterations=10000)
cost, pi = solver.forward(A_1, B_1)

Finished computing transport plan in 1 iterations

C_dist = solver._compute_cost_pytorch_cdist(A_1, B_1)
print(f'C_dist\n{C_dist}')
C_dist
tensor([[[1.8749, 2.1198, 2.6339, 2.6049, 3.1137, 3.1015, 2.3978, 2.6672]],

        [[1.9304, 2.2015, 1.0681, 2.1786, 1.6713, 1.6996, 0.7578, 0.6836]],

        [[2.8949, 3.8635, 2.8507, 3.0367, 3.5651, 3.7260, 1.9812, 2.1205]],

        [[1.5067, 1.8292, 0.2985, 1.3462, 0.8158, 1.1208, 0.8754, 0.9651]],

        [[1.7173, 2.8087, 2.2086, 1.9769, 2.9087, 3.0920, 1.4784, 1.8349]],

        [[2.5324, 3.2724, 1.7656, 2.0043, 2.0777, 2.4271, 1.6231, 1.6386]],

        [[2.5149, 2.8717, 1.1209, 2.3239, 1.4786, 1.6678, 1.0773, 0.7433]],

        [[1.0284, 2.1202, 1.6019, 0.3819, 1.9980, 2.3276, 1.5495, 1.9137]]])

print(f'cost\n{cost}')

cost
1.6988214254379272

Q3. AとBは10個のテンソルが集約されたテンソルです。距離行列Cを求めてください。(出力は10×8×8になります。)

A=torch.randn(10,8,3)
B=torch.randn(10,8,3)
C=torch.cdist(A, B, p=2)
print(C)
tensor([[[3.3452, 3.5884, 2.6145, 3.5548, 2.6811, 1.7614, 4.8377, 3.2827],
         [2.7055, 2.8352, 3.2904, 3.4297, 3.2852, 3.5755, 3.8415, 3.5525],
         [2.8125, 3.0205, 1.4525, 2.8069, 2.1878, 0.3542, 4.1632, 2.1684],
         [2.7713, 3.0012, 2.2180, 3.0812, 2.3985, 1.7458, 4.2431, 2.8260],
         [2.8290, 2.6322, 2.2476, 3.3099, 4.0334, 2.9723, 2.6648, 1.8386],
         [1.8975, 1.9757, 2.3589, 1.1350, 2.3386, 3.7782, 2.4262, 2.0510],
         [0.9523, 0.7452, 2.4652, 1.8895, 2.9092, 3.7592, 1.3614, 2.1590],
         [2.3381, 2.3968, 2.0447, 1.6194, 2.5042, 3.3260, 2.8266, 1.7473]],

        [[1.3115, 1.6532, 3.0833, 1.0828, 2.1806, 2.0015, 1.7661, 0.9011],
         [1.5730, 2.5654, 3.9633, 0.3884, 2.4030, 1.8086, 2.4826, 0.4883],
         [1.9457, 2.7556, 3.9811, 0.8660, 1.8056, 1.8166, 2.4628, 0.9498],
         [2.3601, 4.1735, 4.9313, 2.3770, 2.9711, 4.1579, 4.1577, 3.0857],
         [2.7012, 2.3534, 3.7331, 1.7409, 2.1794, 0.5135, 1.8746, 0.9426],
         [2.1261, 1.5404, 2.7603, 1.6741, 1.2132, 1.6354, 1.0849, 1.2387],
         [1.7516, 2.0816, 2.8682, 2.7619, 3.5830, 3.9498, 2.8420, 2.9162],
         [1.8000, 1.4290, 1.8932, 2.7812, 2.7963, 3.7957, 2.1804, 2.8766]],

        [[3.0585, 2.9258, 2.9229, 1.4963, 2.1865, 1.5238, 2.9494, 1.1623],
         [2.4140, 2.8110, 2.6675, 2.7249, 3.4203, 2.4400, 1.6703, 2.7754],
         [2.4229, 2.9906, 2.1859, 1.9546, 2.4580, 1.9566, 2.0477, 1.5294],
         [1.5616, 2.3468, 1.9568, 1.8826, 2.4227, 2.0639, 1.0216, 2.6526],
         [1.3946, 3.6217, 0.4893, 2.3777, 1.6518, 3.2633, 1.6216, 3.1761],
         [1.9462, 4.6370, 1.2176, 3.8646, 3.2498, 4.5615, 1.9073, 4.4238],
         [1.3844, 2.4626, 1.9043, 1.2222, 0.7378, 2.4471, 1.8722, 3.2822],
...
         [2.0949, 2.8157, 4.1192, 2.0881, 2.6393, 1.0169, 1.6152, 1.5368],
         [1.8283, 2.0573, 2.3762, 1.9086, 3.7325, 2.7268, 2.0280, 3.4565],
         [0.9737, 2.8433, 1.9362, 2.7126, 3.1949, 2.4475, 2.1220, 3.6938],
         [3.0246, 2.2267, 3.0546, 2.2309, 4.6683, 3.7015, 2.8659, 4.0444]]])

Q4. A,B,Cを使って、sinkhoonアルゴリズムを実装してください。(出力は各テンソルに対応するスカラーが集約された10次元のテンソルになります。)

epsilon = 0.00001
solver = SinkhornSolver(epsilon=epsilon, iterations=100000)
cost, pi = solver.forward(A, B)
print(f'cost\n',cost)

cost
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])

C_dist = solver._compute_cost_pytorch_cdist(A, B)
print(C_dist)
tensor([[[[3.3452, 3.5884, 2.6145, 3.5548, 2.6811, 1.7614, 4.8377, 3.2827]],

         [[2.7055, 2.8352, 3.2904, 3.4297, 3.2852, 3.5755, 3.8415, 3.5525]],

         [[2.8125, 3.0205, 1.4525, 2.8069, 2.1878, 0.3542, 4.1632, 2.1684]],

         [[2.7713, 3.0012, 2.2180, 3.0812, 2.3985, 1.7458, 4.2431, 2.8260]],

         [[2.8290, 2.6322, 2.2476, 3.3099, 4.0334, 2.9723, 2.6648, 1.8386]],

         [[1.8975, 1.9757, 2.3589, 1.1350, 2.3386, 3.7782, 2.4262, 2.0510]],

         [[0.9523, 0.7452, 2.4652, 1.8895, 2.9092, 3.7592, 1.3614, 2.1590]],

         [[2.3381, 2.3968, 2.0447, 1.6194, 2.5042, 3.3260, 2.8266, 1.7473]]],


        [[[1.3115, 1.6532, 3.0833, 1.0828, 2.1806, 2.0015, 1.7661, 0.9011]],

         [[1.5730, 2.5654, 3.9633, 0.3884, 2.4030, 1.8086, 2.4826, 0.4883]],

         [[1.9457, 2.7556, 3.9811, 0.8660, 1.8056, 1.8166, 2.4628, 0.9498]],

         [[2.3601, 4.1735, 4.9313, 2.3770, 2.9711, 4.1579, 4.1577, 3.0857]],
...

         [[0.9737, 2.8433, 1.9362, 2.7126, 3.1949, 2.4475, 2.1220, 3.6938]],

         [[3.0246, 2.2267, 3.0546, 2.2309, 4.6683, 3.7015, 2.8659, 4.0444]]]])


0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0