12
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのSoftmax関数で軸を指定してみる

Last updated at Posted at 2021-01-09

#はじめに
掲題の件、調べたときのメモ。

#環境

  • pytorch 1.7.0

#軸の指定方法
nn.Softmax クラスのインスタンスを作成する際、引数dimで軸を指定すればよい。

#やってみよう
今回は以下の配列を例にやってみる。

input = torch.randn(2, 3)
print(input)
tensor([[-0.2562, -1.2630, -0.1973],
        [ 0.8285, -0.9981,  0.3171]])

##dimを指定しない場合

m = nn.Softmax()
print(m(input))

こんな風に怒られる。

/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

##dim=0を指定した場合

m = nn.Softmax(dim=0)
print(m(input))

列単位でSoftmaxをかけてくれる。

tensor([[0.2526, 0.4342, 0.3742],
        [0.7474, 0.5658, 0.6258]])

念のため列単位で集計をすると、各列合計が1になる。

torch.sum(m(input), axis=0)
tensor([1., 1., 1.])

##dim=1を指定した場合

m = nn.Softmax(dim=1)
print(m(input))

行単位でSoftmaxをかけてくれる。

tensor([[0.4122, 0.1506, 0.4372],
        [0.5680, 0.0914, 0.3406]])

念のため行単位で集計すると、各行合計が1になる。

torch.sum(m(input), axis=1)
tensor([1.0000, 1.0000])
12
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?