0
2

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

深層学習(PyTorch)を用いた、Kaggle Titanic実践 PART 9 (学習のCode)

Posted at

学習のCode

import os
import pickle

import category_encoders as ce
import numpy as np
import pandas as pd
import torch
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch import nn

from data_preprocessing import data_preprocessing
from model import Model
from earlystopping import EarlyStopping
# Star import kept last so its names bind exactly as before.
from data_loader import *

# Hyperparameters
LEARNING_RATE = 1e-4      # step size passed to the AdamW optimizer
TRAIN_BATCH_SIZE = 10     # mini-batch size of the training DataLoader
VALID_BATCH_SIZE = 5      # mini-batch size of the validation DataLoader
EPOCHS = 1000             # upper bound on epochs; early stopping may end sooner


data = pd.read_csv("titanic/train.csv")

target, X, ce_ohe = data_preprocessing(data)

# Split BEFORE normalizing. The split depends only on the number of rows and
# random_state, so the partition is identical to scaling first; but fitting
# the scaler on the training rows only keeps validation data from leaking
# into the min/max statistics.
X_train, X_valid, y_train, y_valid = train_test_split(
    X, target, test_size=0.33, random_state=42)

# Min-max normalization, fitted on the training split and applied to both
# splits. `mm` is pickled at the end of the script.
mm = preprocessing.MinMaxScaler()
X_train = mm.fit_transform(X_train)
X_valid = mm.transform(X_valid)

model = Model()

earlystopping = EarlyStopping()

# NOTE(review): NLLLoss expects log-probabilities, so this assumes Model's
# forward ends in log_softmax — TODO confirm against model.py.
criterion = nn.NLLLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

trainset = DataSet(X_train, y_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
validset = DataSet(X_valid, y_valid)
# Evaluation does not depend on order; a fixed order keeps it reproducible.
validloader = torch.utils.data.DataLoader(
    validset, batch_size=VALID_BATCH_SIZE, shuffle=False)

# Train for up to EPOCHS epochs. Each epoch makes one optimization pass over
# the training set, then one evaluation pass over the validation set. The
# per-sample mean validation loss is passed to EarlyStopping, which may end
# training early.
for epoch in range(EPOCHS):

    # ---- training pass ----
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    # Loop variable is `batch` (not `data`/`target`) so the module-level
    # DataFrame and labels are not shadowed.
    for batch in trainloader:
        inputs = batch[0].float()
        labels = batch[1]

        optimizer.zero_grad()
        output = model(inputs)
        predicted = torch.argmax(output, dim=1)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion returns the batch MEAN. Weight it by the batch size so
        # that dividing by the sample count gives a true per-sample mean,
        # including for a smaller final batch.
        train_loss += loss.item() * batch_size
        train_correct += (predicted == labels).sum().item()
        train_total += batch_size

    train_loss = train_loss / train_total
    train_acc = train_correct / train_total

    print("train_acc:", str(train_acc))

    # ---- validation pass (no gradients needed) ----
    model.eval()
    vali_loss = 0.0
    vali_correct = 0
    vali_total = 0

    with torch.no_grad():
        for batch in validloader:
            inputs = batch[0].float()
            labels = batch[1]

            output = model(inputs)
            predicted = torch.argmax(output, dim=1)

            batch_size = labels.size(0)
            vali_loss += criterion(output, labels).item() * batch_size
            vali_correct += (predicted == labels).sum().item()
            vali_total += batch_size

    vali_loss = vali_loss / vali_total
    vali_acc = vali_correct / vali_total

    print("valid_acc:", str(vali_acc))

    earlystopping(vali_loss, model)
    if earlystopping.early_stop:
        print("Early stopping")
        break

# Persist the fitted encoder and scaler, presumably so inference can reuse
# the exact same preprocessing. `open(..., 'wb')` raises FileNotFoundError
# when the directory is missing, so create it first.
os.makedirs('models', exist_ok=True)
with open('models/ohe.pkl', 'wb') as f:
    pickle.dump(ce_ohe, f)
with open('models/mm.pkl', 'wb') as f:
    pickle.dump(mm, f)
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
2

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?