55
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【PyTorch入門】第2回 autograd:自動微分

Last updated at Posted at 2019-02-23

はじめに

PyTorch入門第1回では,pytorchの基本操作を行いました.ここら辺は,numpyやmatlabを触ったことがあるので行列操作は問題ありませんでした.頑張るのはここからです.DeepLearning初心者にとっては厳しいですね….

autograd

PyTorchのニューラルネットワークはautogradパッケージが中心になっています.autogradは自動微分機能を提供します.つまり,このパッケージを使うと勝手に微分の計算を行ってくれると言うことです.
これはdefine-by-runフレームワークです.define-by-runについてはここを参照(まとめると,順伝播のコードを書くだけで逆伝播が定義できると言うことらしいです).

勾配の計算

autogradについては言葉で説明するより,実際のコードを見た方が理解が早いと思うので,簡単な式の勾配を計算するコードを参考に説明します.

##コードの全体像

import torch

# テンソルを作成
# requires_grad=Trueで自動微分対象を指定
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

# 計算グラフを構築
# y = 2 * x + 3
y = w * x + b

# 勾配を計算
y.backward()

# 勾配を表示
print(x.grad)  # dy/dx = w = 2
print(w.grad)  # dy/dw = x = 1
print(b.grad)  # dy/db = 1
実行結果
tensor(2.)
tensor(1.)
tensor(1.)

##コードの詳細

##計算グラフの構築
計算グラフは, requires_grad=Trueで微分対象に指定されたtensorと,Functionパッケージの2つで構築されます.

# テンソルを作成
# requires_grad=Trueで自動微分対象を指定
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

# 計算グラフを構築
# y = 2 * x + 3
y = w * x + b

###requires_grad
requires_gradでは,requires_grad=Trueとすることで微分対象のtensorを指定します.デフォルトではrequires_grad=Falseとなっており,このままだと微分の対象にはならず,勾配にはNoneが返ります.requires_grad=FalseはFine-tuningで層のパラメータを固定したいときに便利だそうです.
実際に,requires_grad=Trueと指定すると,以下のように表示されます.

x = torch.tensor(1.0, requires_grad=True)
print(x)
実行結果
tensor(1., requires_grad=True)

###grad_fn
autogradにはFunctionと言うパッケージがあります.requires_grad=Trueで指定されたtensorとFunctionは内部で繋がっており,この2つで計算グラフが構築されています.この計算グラフに計算の記録が全て残ります.生成されたtensorのそれぞれに.grad_fnという属性があり,この属性によってどのFunctionによってtensorが生成されたのかを参照できます.ただし,ユーザによって作られたtensorの場合grad_fnはNoneとなります.

a = torch.tensor(1.0, requires_grad=True)
print(a.grad_fn)

b = a+2 # bは足し算(add)によって形成
print(b.grad_fn)
実行例
None
<AddBackward object at 0x7f740ffc0eb8>

##勾配の算出

# 勾配を計算
y.backward()

# 勾配を表示
print(x.grad)  # dy/dx = w = 2
print(w.grad)  # dy/dw = x = 1
print(b.grad)  # dy/db = 1

backward()を実行すると,グラフを構築する勾配を計算し,各変数の.gradと言う属性にその勾配が入ります.

55
32
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
55
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?