0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

実務でよく使うPyTorchの機能まとめ④|パラメータ処理(torch.nn.Linear)

Posted at

torch.nn.Linear

PyTorchで線型回帰の演算を行う場合はtorch.nn.Linearクラスを用いると良いです。torch.nn.Linearクラスは下記のように用いることができます。

import torch

linear1 = torch.nn.Linear(10, 20)
x = torch.randn(128, 10)
output = linear1(x)

print(type(x))
print(x.shape)
print(type(linear1))
print(linear1)
print(linear1.weight.shape)
print(linear1.bias.shape)
print(linear1.bias)
print(type(output))
print(output.shape)

・実行結果

<class 'torch.Tensor'>
torch.Size([128, 10])
<class 'torch.nn.modules.linear.Linear'>
Linear(in_features=10, out_features=20, bias=True)
torch.Size([20, 10])
torch.Size([20])
Parameter containing:
tensor([ 0.0950,  0.2425, -0.3109, -0.0619, -0.2536, -0.2907,  0.1110,  0.0490,
         0.0461, -0.1994,  0.0555,  0.3100, -0.1634,  0.1735,  0.2929, -0.1472,
         0.2521,  0.0934,  0.0722, -0.0661], requires_grad=True)
<class 'torch.Tensor'>
torch.Size([128, 20])

実行結果より、$128 \times 10$の入力(x)にlinear1 = torch.nn.Linear(10, 20)によって定義されたlinear1を作用させることで、$128 \times 20$の出力(output)が得られることが確認できます。ここで特に抑えておくと良いのがtorch.nn.Linearで作成されたlinear1オブジェクトからlinear1.parameterslinear1.biasを用いることでパラメータを抽出することができるということです。

PyTorch_4_1.png
https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

また、linear1.parameterslinear1.biasの初期値の生成については上記を詳しく確認しておくと良いです。たとえばtorch.nn.Linear(10, 20)の場合のパラメータの値の分布のヒストグラムは下記を実行することで作成することができます。

import matplotlib.pyplot as plt
import torch
torch.manual_seed(10)

linear1 = torch.nn.Linear(10, 100)

plt.hist(linear1.weight.view(-1,1).detach().numpy(), bins=25)
plt.savefig("hist_Linear_init1.png")

・実行結果
hist_Linear_init1.png

$\sqrt{k} = 1/\sqrt{10} = 0.316 \cdots$であるので、$U(-\sqrt{k}, \sqrt{k})$に基づいて初期値のサンプリングが生成されていることが確認できます。また、結果に再現性を持たせるにあたってtorch.manual_seed(10)を追記しました。

$k = 1/\sqrt{\mathrm{infeatures}}$であるので、以下ではin_features=100の場合でも同様なコードで確認を行いました。

import matplotlib.pyplot as plt
import torch
torch.manual_seed(10)

linear1 = torch.nn.Linear(100, 10)

plt.hist(linear1.weight.view(-1,1).detach().numpy(), bins=25)
plt.savefig("hist_Linear_init2.png")

・実行結果
hist_Linear_init2.png

実行結果を確認すると、$k = 1/\sqrt{\mathrm{infeatures}} = 0.1$でサンプリングが行われていることが確認できます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?