3
2

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since the last update.

PyTorch学習テンプレート

3
Posted at

はじめに

PyTorchでデータセット、モデル、学習ループを記述するためのテンプレートです。
ほぼ自分用なので参考にならないかもしれません。

バージョン

Python 3.7.9
torch 1.6.0
torchvision 0.7.0

データセット作成

自作のデータセットの作成

load_data.py
import torch

def load_data(file_path):
    '''
    Load the raw data stored at ``file_path``.

    This is a template placeholder: replace the body with code that reads
    your own data source.

    Args:
        file_path: location of the data to read.

    Returns:
        tuple (input_data, output_data): two array-likes that ``torch.tensor``
        can convert. They are paired by index, so both must contain the same
        number of samples along their first axis.

    Raises:
        NotImplementedError: until the body is filled in.
    '''
    # The original stub returned undefined names, which failed with a
    # confusing NameError; make the "not implemented yet" state explicit.
    raise NotImplementedError(
        'load_data() is a template stub; implement reading data from {!r}'
        .format(file_path))

class myDataset(torch.utils.data.Dataset):
    '''
    Map-style dataset of (input, output) tensor pairs.

    The data is obtained from ``load_data(file_path)`` and both parts are
    converted to ``torch.float`` (float32) tensors. Sample ``idx`` is the pair
    ``(input[idx], output[idx])``.

    Raises:
        ValueError: if the input and output data have different numbers of
            samples along their first axis.
    '''

    def __init__(self, file_path):
        input_data, output_data = load_data(file_path)
        self.input = torch.tensor(input_data, dtype=torch.float)
        self.output = torch.tensor(output_data, dtype=torch.float)

        # Inputs and outputs are paired by index. A length mismatch would
        # otherwise surface as an IndexError mid-epoch (outputs too short)
        # or silently drop the extra outputs (outputs too long).
        if len(self.input) != len(self.output):
            raise ValueError(
                'input and output sample counts differ: {} != {}'.format(
                    len(self.input), len(self.output)))

        self.data_num = len(self.input)

    def __len__(self):
        '''Return the number of samples.'''
        return self.data_num

    def __getitem__(self, idx):
        '''Return the ``(input, output)`` tensor pair at position ``idx``.'''
        return self.input[idx], self.output[idx]

モデル

ニューラルネットワークのモデル記述

model.py
import torch
import torch.nn as nn

class myModel(nn.Module):
    '''
    Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the last dimension of the input tensor.
        fc_dim: width of the hidden layer.
        out_dim: size of the last dimension of the output tensor.
    '''

    def __init__(self, input_dim=10, fc_dim=300, out_dim=5):
        super().__init__()

        # Attribute names determine the state_dict keys (fc1.*, fc2.*),
        # so they must stay stable for saved checkpoints to load.
        self.fc1 = nn.Linear(input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, out_dim)
        self.act = nn.ReLU()

    def forward(self, input):
        '''Apply fc1, the ReLU activation, then fc2 to ``input``.'''
        return self.fc2(self.act(self.fc1(input)))

学習

train.py
import itertools
import random

import torch
import torch.nn as nn
import torch.optim as optim

from load_data import myDataset
from model import myModel

torch.manual_seed(12)
random.seed(12)

######### Hyperparameters ##########
# Model parameters
input_dim = 16
fc_dim = 1024
out_dim = 10

# Training parameters
batch_size = 8
learning_rate = 0.001
num_epochs = 100

# Run validation once every `val_step` epochs
val_step = 5

# Fraction of samples used for training; the remainder is used for validation
train_ratio = 0.8

######### Setup ##########
# Build the dataset
file_path = './dataset'
dataset = myDataset(file_path)

# Randomly split the dataset into training / validation subsets
n_samples = len(dataset)
index_list = list(range(n_samples))
random.shuffle(index_list)
n_train = int(n_samples * train_ratio)
train_index = index_list[:n_train]
val_index = index_list[n_train:]

train_dataset = torch.utils.data.dataset.Subset(dataset, train_index)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size, shuffle=True)

val_dataset = torch.utils.data.dataset.Subset(dataset, val_index)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size, shuffle=False)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}

# Use the first GPU when available, otherwise fall back to the CPU
device = torch.device(
    'cuda:0' if torch.cuda.is_available() else "cpu")

# Model
net = myModel(input_dim=input_dim, fc_dim=fc_dim, out_dim=out_dim)
net.to(device)

# Optimizer (SGD with momentum is an alternative)
# optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Loss function
# NOTE(review): in torch 1.6, nn.CrossEntropyLoss expects integer class-index
# targets (torch.long), but myDataset stores outputs as float tensors.
# Convert the targets (e.g. results.long()) or switch to MSELoss, depending on
# what the data actually contains.
# criterion = nn.MSELoss()
criterion = nn.CrossEntropyLoss()

######### Training ##########
# Per-epoch mean loss for each phase
loss_list = {'train': [], 'val': []}

for epoch, phase in itertools.product(range(num_epochs), ['train', 'val']):
    if phase == 'train':
        net.train()
    elif (epoch + 1) % val_step == 0:
        net.eval()
    else:
        # Skip validation on epochs that are not a multiple of val_step
        continue

    # Sum of per-sample losses over the epoch
    epoch_loss = 0.0

    # Wrap the loader in tqdm(...) here for a progress bar if tqdm is installed.
    for inputs, results in dataloaders[phase]:
        inputs_gpu = inputs.to(device)
        results_gpu = results.to(device)

        optimizer.zero_grad()

        # Track gradients only during training
        with torch.set_grad_enabled(phase == 'train'):
            outputs = net(inputs_gpu)
            loss = criterion(outputs, results_gpu)

            # Backpropagate and update the weights
            if phase == 'train':
                loss.backward()
                optimizer.step()

        # `loss` is the batch mean (default reduction='mean'), so weight it by
        # the number of samples in this batch: dim 0 of `inputs`. The last
        # batch may be smaller than batch_size.
        epoch_loss += loss.item() * inputs.size(0)

    n_phase_samples = len(dataloaders[phase].dataset)
    # Guard against an empty split (e.g. a tiny dataset with no validation data)
    epoch_loss = epoch_loss / n_phase_samples if n_phase_samples else float('nan')
    loss_list[phase].append(epoch_loss)

    # Report progress
    print('Epoch {}/{} | {:^5} | Loss: {}'.format(
        epoch + 1, num_epochs, phase, epoch_loss))

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
3
2

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?