EarlyStoppingを用いた学習
EarlyStoppingを用いた学習の概要は、上記のようになります。
Code
# Train for up to EPOCHS epochs. After each epoch, evaluate on the validation
# set and stop early once `earlystopping` sets its `early_stop` flag.
# `model`, `trainloader`, `validloader`, `optimizer`, `criterion`,
# `earlystopping` and `EPOCHS` are defined outside this snippet.
for epoch in range(EPOCHS):
    # ---- training phase ----
    train_loss = 0
    train_acc = 0
    total = 0
    model.train()
    for data in trainloader:
        target = data[1]
        n_batch_samples = target.size(0)

        optimizer.zero_grad()
        output = model(data[0].float())
        predicted = torch.argmax(output, dim=1)
        loss = criterion(output, target)
        # Weight the batch loss by its size so that dividing by `total` below
        # gives a true per-sample mean, even when the last batch is smaller.
        # NOTE(review): assumes criterion uses reduction="mean" (the default
        # for torch.nn loss modules) -- confirm where criterion is built.
        train_loss += loss.item() * n_batch_samples
        train_acc += (predicted == target).sum().item()
        total += n_batch_samples
        loss.backward()
        optimizer.step()
    train_loss = train_loss / total
    train_acc = train_acc / total
    print("train_acc:", str(train_acc))

    # ---- validation phase ----
    model.eval()
    vali_total = 0
    vali_loss = 0
    vali_acc = 0
    # One no_grad context covers the whole pass: no gradients are needed here.
    with torch.no_grad():
        for data in validloader:
            target = data[1]
            n_batch_samples = target.size(0)

            # Call the module itself (not .forward()) so registered hooks run,
            # matching how the training phase invokes the model.
            out = model(data[0].float())
            predicted = torch.argmax(out, dim=1)
            loss = criterion(out, target)
            vali_loss += loss.item() * n_batch_samples
            vali_acc += (predicted == target).sum().item()
            vali_total += n_batch_samples
    vali_loss = vali_loss / vali_total
    vali_acc = vali_acc / vali_total
    print("valid_acc:", str(vali_acc))

    # Pass this epoch's validation loss (and the model) to the early-stopping
    # helper defined elsewhere. It is expected to set `early_stop` when
    # training should end (presumably after some patience without improvement
    # -- confirm against its implementation).
    earlystopping(vali_loss, model)
    if earlystopping.early_stop:
        print("Early stopping")
        break