

How to add an axis


pytorch

Use unsqueeze(). It inserts a new axis of size 1 at the given position.

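import torch

a = torch.rand((3, 3))
a.size() # -> torch.Size([3, 3])

# new axis at position 0
b = a.unsqueeze(0)
b.size() # -> torch.Size([1, 3, 3])

# new axis at position 1 (applied to the original a, not to b)
c = a.unsqueeze(1)
c.size() # -> torch.Size([3, 1, 3])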
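
As a small supplementary sketch (not part of the original snippet), unsqueeze() also accepts negative positions, and indexing a tensor with None works the same way as np.newaxis does in NumPy. The variable names last, front, and both are only illustrative.

```python
import torch

a = torch.rand(3, 3)

# Negative dims count from the end, so -1 appends a trailing axis.
last = a.unsqueeze(-1)
print(tuple(last.shape))   # (3, 3, 1)

# Indexing with None inserts an axis, like np.newaxis in NumPy.
front = a[None]
print(tuple(front.shape))  # (1, 3, 3)

# Several None entries add several axes in a single indexing step.
both = a[:, None, :, None]
print(tuple(both.shape))   # (3, 1, 3, 1)
```

The None form is handy when more than one axis is needed, since unsqueeze() adds only one axis per call.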

numpy

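There are three ways: reshape, newaxis, and expand_dims.
With reshape or newaxis you can also add several axes at once.
reshape is tedious because the whole shape has to be written out, so I would probably go with newaxis.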

import numpy as np

a = np.random.normal(size=(3, 3))
a.shape # -> (3, 3)

# reshape
b = a.reshape(1, 3, 3)
b.shape # -> (1, 3, 3)

c = a.reshape(3, 1, 3)
c.shape # -> (3, 1, 3)

# same as b, without hard-coding the original shape
d = a.reshape(1, *a.shape)
d.shape # -> (1, 3, 3)

# newaxis
b = a[np.newaxis]
b.shape # -> (1, 3, 3)

c = a[:, np.newaxis]
c.shape # -> (3, 1, 3)

# expand_dims
b = np.expand_dims(a, 0)
b.shape # -> (1, 3, 3)

c = np.expand_dims(a, 1)
c.shape # -> (3, 1, 3)
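
Adding several axes at once, as mentioned for reshape and newaxis, could look like the sketch below. The last variant passes a tuple of axes to expand_dims, which assumes NumPy 1.18 or later; older versions only accept a single integer there.

```python
import numpy as np

a = np.random.normal(size=(3, 3))

# reshape: spell out the full target shape, with 1 for each new axis.
b = a.reshape(1, 3, 1, 3)
print(b.shape)  # (1, 3, 1, 3)

# newaxis: one np.newaxis per new axis, ':' for each existing axis.
c = a[np.newaxis, :, np.newaxis, :]
print(c.shape)  # (1, 3, 1, 3)

# expand_dims with a tuple of axes (assumption: NumPy >= 1.18).
# The tuple gives the positions of the new axes in the result.
d = np.expand_dims(a, (0, 2))
print(d.shape)  # (1, 3, 1, 3)
```

All three produce the same shape here, so the choice mostly comes down to which is easiest to read.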