はじめに
- pytorchの勉強中に多次元の行列から指定の次元を一行で取り出しており、この方法を知らなかったため記事とする。
- 状況としては、shape(3,2)の行列から0番目と2番目だけを抽出したい場合などだ。
- list in listでも需要がありそうだと思った。(リスト形式ではできないが...)
【pytorch】 list in listから指定の行を取り出す方法
a = [[0,1,2], [3,4,5],[6,7,8]]
a = torch.Tensor(a)
b = [0,2]
print(a[b])
>>> tensor([[0., 1., 2.],
[6., 7., 8.]])
numpyやarrayではできるか
- numpyはpytorchと同じようにできた。
- listはできなかった。リスト内表記やfor文を使えば書ける。
# numpy
a = [[0,1,2], [3,4,5],[6,7,8]]
a = np.array(a)
b = [0,2]
print(a[b])
>>> [[0. 1. 2.]
[6. 7. 8.]]
# list
a = [[0,1,2], [3,4,5],[6,7,8]]
b = [0,2]
print(a[b])
>>> TypeError: list indices must be integers or slices, not list
print([a[j] for j in b])
>>> [[0, 1, 2], [6, 7, 8]]
おわりに
- 多次元行列から指定の行だけ取り出したい場合はありそうなため、使っていきたい。
- listでもlist内表記やfor文を使うよりは、numpyに変換後、a[b]で書いたほうがスッキリ書けると思う。