LoginSignup
1

タイタニック問題 - 基本編 -

Last updated at Posted at 2022-10-19

タイタニック問題 (Pytorch)

タイタニック問題を解くモデルを作ります。
まずは、深いこと考えずにとりあえず形にします。
作成したソースコードはこちら

データセットの用意

データセットはこちらなどからダウンロードできると思います。
データの形式などはいろいろネット上に載っているため割愛。

データセットの読み込み

まずは、データセットを読み込みPytorchのDatasetクラスで読み込めるようにします。
Datasetクラスを継承した、「TitanicDataSet」クラスを用意します。

titanic.py
class TitanicDataSet(Dataset):
    """
    タイタニック問題データセットクラス
    """
    __data = []
    __is_test = False

    def __init__(self, data_filepath, is_test=False):
        self.__data = pd.read_csv(data_filepath)
        # 欠損データを補完
        self.__data["Age"] = self.__data["Age"].fillna(self.__data["Age"].median())
        self.__data["Embarked"] = self.__data["Embarked"].fillna("S")

        self.__is_test = is_test

    def __len__(self):
        return len(self.__data)

    def __getitem__(self, idx):
        passId = self.__data.at[idx, "PassengerId"]
        sex = self.__trans_sex(self.__data.at[idx, "Sex"])
        age = self.__data.at[idx, "Age"]
        related_person = self.__data.at[idx, "SibSp"] + self.__data.at[idx, "Parch"]  # 乗車した家族の数
        pclass = self.__data.at[idx, "Pclass"]  # チケットクラス (上級: 1 , 中級: 2, 下級: 3)
        if self.__is_test:
            label = passId
        else:
            label = self.__data.at[idx, "Survived"]  # 0:死亡, 1:生存
        data = [sex, age, related_person,pclass]
        tdata = torch.tensor(data,dtype=torch.float32)

        if self.__is_test:
            return tdata, label

        return tdata, torch.tensor(label, dtype=torch.float32)

    def __trans_embarked(self,value):
        if value == "S":
            return 0
        elif value == "C":
            return 1
        elif value == "Q":
            return 2
        else:
            print("Unknown Value")
            return -1

    def __trans_sex(self,value):
        if value == "male":
            return 0
        elif value == "female":
            return 1
        else:
            print("Unknown Value")
            return -1

今回は、データセットの内、以下のデータを利用します。
本来であれば、各値を統計的に確認しますが、今回はスキップ。

性別
乗船者の性別
年齢
乗船者の年齢
関係者の数
(乗船していた兄弟、配偶者の数) + (乗船していた両親、子供の数)
チケットクラス
チケットクラス (上級, 中級, 下級)

また、データは一部に欠損データがあるため、補完します。
Pytorchで扱うために、Tensor型にしておきましょう。

モデル作成

今回は、適当に以下のようなモデルを作成します。
(ドロップアウトなどは一旦なしで)

活性化関数はランプ関数を利用します。

titanic.py
class TitanicModel(nn.Module):
    def __init__(self):
        super(TitanicModel, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(4, 25),
            nn.ReLU(),
            nn.Linear(25, 20),
            nn.ReLU(),
            nn.Linear(20, 2)
        )

    def forward(self, x):
        logits = self.network(x)
        return logits

学習させる

上記で用意したデータセットクラスとモデルで学習します。

まずは、学習データのを読み込み、学習に利用するデータを、学習用のデータと検証データに分割します。
検証用のデータは、未学習のデータに対してのモデルの性能を確認するために利用します。

titanic.py
 # データセット読み込み
    train_data = TitanicDataSet(TRAIN_DATA_PATH, False)
    # テストデータ読み込み
    test = TitanicDataSet(TEST_DATA_PATH, True)

    train, valid = torch.utils.data.random_split(
        train_data,
        [SPLIT_SIZE, TOTAL_TRAIN_DATA_SIZE-SPLIT_SIZE]
    )

また、今回は、損失関数に「CrossEntropy」、最適化は「確率的勾配降下法」を利用します。

titanic.py
    # 損失関数設定
    loss_fn = nn.CrossEntropyLoss()
    # オプティマイザー設定
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

次に学習部分を書いていきます。

titanic.py
def exec_train(dataloader, model, loss_fn, optimizer):
    """
    モデルに対して学習を実行します
    :param dataloader: データローダー(学習用)
    :param model: モデルインスタンス
    :param loss_fn: 損失関数
    :param optimizer: オプティマイザー
    :return: 損失値
    """
    for batch, (data, label) in enumerate(dataloader):
        label = label.float()
        X, label = data.to(device), label.to(device)

        # 損失誤差を計算
        pred = model(X)

        loss = loss_fn(pred, label)

        # バックプロパゲーション
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item()

あとは、上記の学習用コードをぶん回します。

titanic.py
for t in range(EPOCH_NUM):
    print("epoch : " + str(t))
    loss = exec_train(train_dataloader, model, loss_fn, optimizer)
    loss, correct_rate = exec_validation(valid_dataloader, model, loss_fn)

検証データにおけるlossと正答率は以下のようになりました。(とりあえず1000回ぶん回しました)

正答率は75%程度ですかね。

以上です。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
1