import pickle
import pandas as pd
import numpy as np
import category_encoders as ce
import torch
from torch import nn
from data_preprocessing import data_preprocessing
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from model import Model
from earlystopping import EarlyStopping
from data_loader import *
# HyperParameter
LEARNING_RATE = 0.0001  # step size for the AdamW optimizer
TRAIN_BATCH_SIZE = 10   # mini-batch size for the training DataLoader
VALID_BATCH_SIZE = 5    # mini-batch size for the validation DataLoader
EPOCHS = 1000           # upper bound on epochs; early stopping may end sooner
# ---- Data ingestion & feature preparation ----
data = pd.read_csv("titanic/train.csv")
target, X, ce_ohe = data_preprocessing(data)

# Normalize every feature into [0, 1]; the fitted scaler is saved later
# so the same transform can be applied at inference time.
mm = preprocessing.MinMaxScaler()
X = mm.fit_transform(X)

# Hold out a third of the samples for validation.
X_train, X_valid, y_train, y_valid = train_test_split(
    X, target, test_size=0.33, random_state=42)

# ---- Model, loss, optimizer, early stopping ----
model = Model()
earlystopping = EarlyStopping()
criterion = nn.NLLLoss()  # assumes the model emits log-probabilities — TODO confirm in Model
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# ---- Mini-batch loaders ----
trainset = DataSet(X_train, y_train)
validset = DataSet(X_valid, y_valid)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
validloader = torch.utils.data.DataLoader(
    validset, batch_size=VALID_BATCH_SIZE, shuffle=True)
# Training loop: one optimization pass plus one validation pass per epoch,
# stopping early once the validation loss stops improving.
for epoch in range(EPOCHS):
    # ---- training phase ----
    train_loss = 0
    train_acc = 0
    total = 0
    model.train()
    # Unpack batches as (features, labels); the original reused the names
    # `data`/`target`, clobbering the module-level DataFrame and label series.
    for features, labels in trainloader:
        optimizer.zero_grad()
        output = model(features.float())
        predicted = torch.argmax(output, dim=1)
        loss = criterion(output, labels)
        train_loss += loss.item()
        train_acc += (predicted == labels).sum().item()
        total += labels.size(0)
        loss.backward()
        optimizer.step()
    # NOTE(review): loss.item() is already a per-batch mean, so dividing the
    # running sum by the *sample* count under-reports the average loss.
    # Kept as-is for parity with the original metric; accuracy is exact.
    train_loss = train_loss / total
    train_acc = train_acc / total
    print("train_acc:", str(train_acc))

    # ---- validation phase ----
    model.eval()
    vali_total = 0
    vali_loss = 0
    vali_acc = 0
    # no_grad hoisted around the whole pass; use model(...) rather than
    # model.forward(...) so registered hooks still run.
    with torch.no_grad():
        for features, labels in validloader:
            out = model(features.float())
            predicted = torch.argmax(out, dim=1)
            loss = criterion(out, labels)
            vali_loss += loss.item()
            vali_acc += (predicted == labels).sum().item()
            vali_total += labels.size(0)
    vali_loss = vali_loss / vali_total
    vali_acc = vali_acc / vali_total
    print("valid_acc:", str(vali_acc))

    # EarlyStopping tracks the best validation loss (and presumably
    # checkpoints the model — verify against earlystopping.py).
    earlystopping(vali_loss, model)
    if earlystopping.early_stop:
        print("Early stopping")
        break
# Persist the fitted preprocessing artifacts (one-hot encoder and
# min-max scaler) so inference can reuse the exact same transforms.
for path, artifact in (('models/ohe.pkl', ce_ohe), ('models/mm.pkl', mm)):
    with open(path, 'wb') as f:
        pickle.dump(artifact, f)