More than 3 years have passed since last update.

Kaggle Titanic in Practice with Deep Learning (PyTorch), PART 10 (Test Code)


Code for Prediction

import pickle

import numpy as np
import pandas as pd
import torch

# Local modules from this series. category_encoders and scikit-learn must
# still be installed, because the pickled encoder and scaler depend on them.
from data_preprocessing import data_preprocessing_test
from model import Model

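# Rebuild the network and load the trained weights onto the CPU.
model = Model()
model.load_state_dict(torch.load("models/model.pth", map_location="cpu"))
model.eval()  # put layers such as dropout and batch norm into inference mode

# Load the fitted one-hot encoder and the fitted scaler (mm) saved during training.
with open("models/ohe.pkl", "rb") as f:
    ce_ohe = pickle.load(f)
with open("models/mm.pkl", "rb") as f:
    mm = pickle.load(f)

# Apply the same preprocessing and scaling that were applied to the training data.
data = pd.read_csv("titanic/test.csv")
X_test = data_preprocessing_test(data, ce_ohe)
X_test = mm.transform(X_test)

# Predict every row in one batch; gradients are not needed for inference.
with torch.no_grad():
    res = model(torch.tensor(np.asarray(X_test), dtype=torch.float32))

# exp() is monotonic, so argmax on the raw output gives the same label
# as argmax on the exponentiated output used originally.
pred = res.argmax(dim=1).numpy()

# Build the submission directly from the IDs and the predicted labels.
sub_df = pd.DataFrame({"PassengerId": data["PassengerId"].values, "Survived": pred})
print(sub_df.head())
sub_df.to_csv("sub_dl.csv", index=False)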
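
Before uploading, it can be worth checking that the written file has the shape Kaggle expects: a `PassengerId,Survived` header followed by one integer label (0 or 1) per passenger. The following is a minimal sketch of such a check using only the standard library. The helper name `validate_submission` is hypothetical and is not part of the original code in this series.

```python
import csv


def validate_submission(path):
    """Return the number of data rows if the file looks like a valid submission.

    Hypothetical helper for illustration; not part of the original article.
    """
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        # The header must be exactly the two columns written by the script.
        if header != ["PassengerId", "Survived"]:
            raise ValueError(f"unexpected header: {header}")
        n_rows = 0
        for row in reader:
            # Unpacking raises ValueError if a row does not have two fields.
            passenger_id, survived = row
            int(passenger_id)  # raises ValueError if the ID is not an integer
            # Labels must be integer class indices, not floats such as "1.0".
            if survived not in ("0", "1"):
                raise ValueError(f"unexpected label: {survived!r}")
            n_rows += 1
    return n_rows
```

Calling `validate_submission("sub_dl.csv")` after the script has run returns the row count, which should match the number of rows in `titanic/test.csv` (418 for the standard Titanic test set). The check for `"0"` and `"1"` is there because a float-typed `Survived` column is an easy mistake to make when assembling the DataFrame.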