EarlyStoppingを用いた学習
EarlyStoppingを用いた学習の概要は上記のようになります。
Code
# Train for up to EPOCHS epochs. After every epoch the model is evaluated on
# the validation loader and the validation loss is handed to `earlystopping`;
# the loop exits as soon as `earlystopping.early_stop` is truthy.
for epoch in range(EPOCHS):
	# ---- training phase ----
	model.train()
	train_loss = 0
	train_acc = 0
	total = 0

	for data in trainloader:
		# data[0] holds the inputs (cast to float), data[1] the targets.
		inputs, target = data[0].float(), data[1]

		optimizer.zero_grad()
		output = model(inputs)
		loss = criterion(output, target)
		loss.backward()
		optimizer.step()

		# NOTE(review): loss.item() is summed per batch and later divided by
		# the number of samples. If `criterion` returns a per-batch mean this
		# is not a true per-sample average -- confirm the intended scaling.
		train_loss += loss.item()
		predicted = torch.argmax(output, dim=1)
		train_acc += (predicted == target).sum().item()
		total += target.size(0)

	train_loss = train_loss / total
	train_acc = train_acc / total
	print("train_acc:", train_acc)

	# ---- validation phase ----
	model.eval()
	vali_total = 0
	vali_loss = 0
	vali_acc = 0

	# No gradients are needed anywhere in validation, so the whole loop
	# (forward pass and loss) runs inside a single no_grad context.
	with torch.no_grad():
		for data in validloader:
			inputs, target = data[0].float(), data[1]

			# Call the module itself rather than .forward() so that any
			# registered hooks run, consistent with the training phase.
			out = model(inputs)
			loss = criterion(out, target)

			vali_loss += loss.item()
			predicted = torch.argmax(out, dim=1)
			vali_acc += (predicted == target).sum().item()
			vali_total += target.size(0)

	# Same normalisation as the training loss (see the review note above the
	# train_loss accumulation in the training phase of this loop).
	vali_loss = vali_loss / vali_total
	vali_acc = vali_acc / vali_total
	print("valid_acc:", vali_acc)

	# ---- early stopping ----
	# `earlystopping` is defined elsewhere; here it is called with the
	# validation loss and the model, then its `early_stop` flag is checked.
	earlystopping(vali_loss, model)
	if earlystopping.early_stop:
		print("Early stopping")
		break

