1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorch:訓練モードと推論モードでモデルの挙動を変える

Last updated at Posted at 2024-06-23

self.trainingを使う

nn.Moduleを継承して作ったクラスのインスタンスはtraining属性を持っている。training属性は訓練モードであれば「True」、推論モードであれば「False」になる。

import torch
import torch.nn as nn

# nn.Moduleを継承してクラスを定義
class Network(nn.Module):
    pass

# Networkクラスのインスタンスを作成
model = Network()

↓訓練モードであれば「True」が出力される。

model.train()
print(model.training)
# 出力: True

↓推論モードであれば「False」が出力される。

model.eval()
print(model.training)
# 出力: False

これを使えば訓練モードと推論モードでモデルの挙動を変えることが出来る。特に意味はないが推論モードのときだけニューラルネットワークの出力を1000倍するようにしてみる。

# ニューラルネットワークを定義
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(5, 10)
        self.fc2 = nn.Linear(10, 1)
        self.activation = nn.ReLU()

    def forward(self, x):
        out = self.fc1(x)
        out = self.activation(out)
        # 訓練時と推論時で条件分岐
        if self.training:
            out = self.fc2(out)
        else:
            out = self.fc2(out)*1000
        return out

model = Network()

# モデルへの入力を適当に乱数で作成
input = torch.randn(5).view(1, -1)

↓訓練モードのときは「-0.3638」が出力される。

model.train()
print(model(input))
# 出力: tensor([[-0.3638]], grad_fn=<AddmmBackward0>)

↓推論モードのときは「-363.7650」が出力される。推論モードのときだけ出力が1000倍されている。

model.eval()
print(model(input))
# 出力: tensor([[-363.7650]], grad_fn=<MulBackward0>)
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?