1
1

More than 3 years have passed since last update.

PyTorchメモ(次元管理)

Posted at

はじめに

主は深層学習初心者です。
間違いがありましたら優しく教えて頂けると幸いです。

tensorの定義

>>> import torch
>>> tensor = torch.randn(2, 3, 3)
tensor([[[ 1.5399, -0.8363,  0.3968],
         [ 0.0699,  1.1410,  0.7154],
         [ 0.4368,  0.9433, -0.8077]],

        [[ 1.1562, -1.3698,  0.6734],
         [-0.6762,  0.1539, -0.1286],
         [-0.4542,  0.3858, -1.6197]]])

tensorを次元ごとに操作

n次元目の総和を求める(次元を圧縮)

>>> sum_tensor = tensor.sum(2, keepdim=False)
tensor([[ 1.1004,  1.9262,  0.5725],
        [ 0.4599, -0.6509, -1.6881]])

torch.Size([2, 3])

n次元目に総和を求める(次元を圧縮しない)

>>> sum_tensor = tensor.sum(2, keepdim=True)
tensor([[[ 1.1004],
         [ 1.9262],
         [ 0.5725]],

        [[ 0.4599],
         [-0.6509],
         [-1.6881]]])

torch.Size([2, 3, 1])

次元を拡張

GPUを使用する場合以下を用いると高速化できます

>>> tensor.sum(2, keepdim=True).expand([3, 2, 3, 1])

tensor([[[[ 1.1004],
          [ 1.9262],
          [ 0.5725]],

         [[ 0.4599],
          [-0.6509],
          [-1.6881]]],


        [[[ 1.1004],
          [ 1.9262],
          [ 0.5725]],

         [[ 0.4599],
          [-0.6509],
          [-1.6881]]],


        [[[ 1.1004],
          [ 1.9262],
          [ 0.5725]],

         [[ 0.4599],
          [-0.6509],
          [-1.6881]]]])

まとめ

主が困ったらまた追記すると思います

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1