import pickle
from pathlib import Path

import category_encoders as ce
import numpy as np
import pandas as pd
import torch
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch import nn

from data_loader import *
from data_preprocessing import data_preprocessing
from earlystopping import EarlyStopping
from model import Model
# Hyperparameters
LEARNING_RATE = 0.0001  # AdamW learning rate
TRAIN_BATCH_SIZE = 10
VALID_BATCH_SIZE = 5
EPOCHS = 1000  # upper bound; early stopping normally ends training sooner
# Load the Titanic training data and preprocess it into target labels,
# a feature matrix X, and the fitted one-hot encoder (saved later for inference).
data = pd.read_csv("titanic/train.csv")
target, X, ce_ohe = data_preprocessing(data)
# Normalize the data: scale every feature into [0, 1]
mm = preprocessing.MinMaxScaler()
X = mm.fit_transform(X)
# Hold out a third of the rows for validation; fixed seed for reproducibility.
X_train, X_valid, y_train, y_valid = train_test_split(X, target, test_size=0.33, random_state=42)
# Model, early-stopping monitor, loss, optimizer and data loaders.
model = Model()
earlystopping = EarlyStopping()
# NLLLoss expects log-probabilities — presumably Model ends in log_softmax;
# NOTE(review): verify in model.py, otherwise CrossEntropyLoss would be needed.
criterion = nn.NLLLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
trainset = DataSet(X_train, y_train)
trainloader = torch.utils.data.DataLoader(
	trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
validset = DataSet(X_valid, y_valid)
# shuffle=True on the validation loader is harmless but unnecessary for metrics.
validloader = torch.utils.data.DataLoader(
	validset, batch_size=VALID_BATCH_SIZE, shuffle=True)
# Training loop: one pass over the training set, then a validation pass,
# stopping early when the validation loss stops improving.
for epoch in range(EPOCHS):
	train_loss = 0
	train_acc = 0
	total = 0
	model.train()

	# Use distinct loop names (batch/features/labels) so the module-level
	# `data` DataFrame and `target` labels are not shadowed/clobbered.
	for batch in trainloader:
		features, labels = batch[0].float(), batch[1]
		optimizer.zero_grad()
		# Call model(...) rather than model.forward(...) so hooks run.
		output = model(features)
		predicted = torch.argmax(output, dim=1)
		loss = criterion(output, labels)
		train_loss += loss.item()
		train_acc += (predicted == labels).sum().item()
		total += labels.size(0)
		loss.backward()
		optimizer.step()
	# Average loss per sample and overall accuracy for the epoch.
	train_loss = train_loss / total
	train_acc = train_acc / total

	print("train_acc:", str(train_acc))
	model.eval()
	vali_total = 0
	vali_loss = 0
	vali_acc = 0
	# No gradients are needed for the whole validation pass.
	with torch.no_grad():
		for batch in validloader:
			features, labels = batch[0].float(), batch[1]
			out = model(features)
			predicted = torch.argmax(out, dim=1)
			loss = criterion(out, labels)
			vali_loss += loss.item()
			vali_acc += (predicted == labels).sum().item()
			vali_total += labels.size(0)
	vali_loss = vali_loss / vali_total
	vali_acc = vali_acc / vali_total

	print("valid_acc:", str(vali_acc))

	# EarlyStopping tracks the best validation loss and sets early_stop
	# once patience is exhausted (it presumably also checkpoints the model).
	earlystopping(vali_loss, model)
	if earlystopping.early_stop:
		print("Early stopping")
		break
# Persist the fitted preprocessing artifacts so inference can reuse them.
# Create the output directory first — the original crashed with
# FileNotFoundError when models/ did not already exist.
Path("models").mkdir(parents=True, exist_ok=True)
with open('models/ohe.pkl', 'wb') as f:
	pickle.dump(ce_ohe, f)  # fitted one-hot encoder from data_preprocessing
with open('models/mm.pkl', 'wb') as f:
	pickle.dump(mm, f)  # fitted MinMaxScaler