LoginSignup
12
12

More than 5 years have passed since last update.

CSVファイルからChainer用のデータセットを作成する

Last updated at Posted at 2017-01-05

概要

ChainerでローカルのCSVファイルを扱う方法です。
Chainerの公式サイトに載っているMNISTのサンプルはライブラリから直接データを取得していたので、CSVファイルを扱う方法が分からず調べた次第です。

環境

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.11.6
BuildVersion:   15G1212
$ python --version
Python 3.5.2 :: Anaconda 4.1.1 (x86_64)
$ pip list --format=legacy | grep chainer
chainer (1.18.0)

学習に使用するデータ

kaggleのチュートリアル的存在である、タイタニック号の乗客の生死を予測するコンペのデータを使いました。

Titanic: Machine Learning from Disaster

「train.csv」と「test.csv」が公開されてますが、「test.csv」は生死の結果が載っていないので「train.csv」のみを使いました。
CSVファイルを扱うことが目的なので精度は求めません。

コード全体

import numpy as np
import pandas as pd
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

# csvファイルの読み込み
data_f = pd.read_csv('train.csv', header=0)

# 関係ありそうなPClass,Sex,Ageのみを使う
data_f = data_f[["Pclass", "Sex", "Age", "Survived"]]

# Ageの欠損値を中央値で補完
data_f["Age"] = data_f["Age"].fillna(data_f["Age"].median())
# maleは1, femaleは0に置換
data_f["Sex"] = data_f["Sex"].replace("male", 1)
data_f["Sex"] = data_f["Sex"].replace("female", 0)

data_array = data_f.as_matrix()

X = []
Y = []
for x in data_array:
    x_split = np.hsplit(x, [3,4])
    X.append(x_split[0].astype(np.float32))
    Y.append(x_split[1].astype(np.int32))

X = np.array(X)
Y = np.ndarray.flatten(np.array(Y))

# 891個のデータのうち623個(7割)を訓練用データ、残りをテスト用データにする
train, test = datasets.split_dataset_random(datasets.TupleDataset(X, Y), 623)
train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

class MLP(Chain):
    def __init__(self):
        super(MLP, self).__init__(
            l1=L.Linear(3, 100),
            l2=L.Linear(100, 100),
            l3=L.Linear(100, 2),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        y = self.l3(h2)
        return y

model = L.Classifier(MLP())
optimizer = optimizers.SGD()
optimizer.setup(model)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (30, 'epoch'), out='result')

trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())

trainer.run()

学習結果

一応、学習結果です。

$ python mlp.py
epoch       main/accuracy  validation/main/accuracy
1           0.58           0.708235
2           0.593333       0.708235
3           0.611667       0.704902
4           0.596667       0.787843
5           0.652857       0.688431
6           0.653333       0.676863
7           0.651667       0.676863
8           0.626667       0.680196
9           0.642857       0.773137
10          0.728333       0.673529
11          0.691667       0.734902
12          0.611667       0.746667
13          0.685          0.637255
14          0.641429       0.685098
15          0.643333       0.402549
16          0.631667       0.726667
17          0.661667       0.705294
18          0.628571       0.695294
19          0.656667       0.615686
20          0.645          0.676863
21          0.578333       0.751569
22          0.635714       0.769804
23          0.661667       0.672157
24          0.611667       0.774706
25          0.69           0.676863
26          0.628333       0.585882
27          0.634286       0.708627
28          0.611667       0.417451
29          0.663333       0.753137
30          0.68           0.743137

やはりイマイチです。

12
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
12