2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【pytorch備忘録】multi-target not supported at なんちゃら/ClassNLLCriterion.cu:15というCrossEntropyLoss()のエラー対策

Last updated at Posted at 2021-09-16

絶賛pytorch勉強中の人間の備忘録です。
あるデータに対しては動いていたはずの学習/検証のための関数を別のデータに適用したらCrossEntropyLoss()の部分がmulti-target not supported at なんちゃら/ClassNLLCriterion.cu:15というエラーを吐くようになりました。

修正前

(不慣れで冗長なコードですすみません・・・・)

train_valid_loop.pyの一部
def train_valid_loop(
    train_loader, valid_loader, valid_data_tensor, valid_label_tensor, model,
    n_epoch, optimizer):
    train_acc_list = []
    train_loss_list = []
    valid_acc_list = []
    valid_loss_list = []
    auc_score_list = []

    for epoch in range(n_epoch):
        train_loss = 0
        train_acc = 0
        valid_loss = 0
        valid_acc = 0
        best_auc_score = 0

        model.train()
        for xt, yt in train_loader:
            xt = xt.to(device)
            yt = yt.to(device)

            y_pred = model.forward(xt)
            loss = loss_fn(y_pred, yt)
            # errorの原因はここ↑
            # どういうわけかこの第2引数をyvのまま入力すると
            # multi-target(?)だと怒られる

            train_loss += loss.item() * xt.size(0)
            train_acc += (y_pred.max(1)[1] == yt).sum().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_train_loss = train_loss / len(train_loader.dataset)
        avg_train_acc = train_acc / len(train_loader.dataset)

        model.eval()
        with torch.no_grad():
            for xv, yv in valid_loader:
                xv = xv.to(device)
                yv = yv.to(device)

                y_pred = model.forward(xv)
                loss = loss_fn(y_pred, yv)
                # 同上のエラー

                valid_loss += loss.item() * xv.size(0)
                valid_acc += (y_pred.max(1)[1] == yv).sum().item()

            avg_valid_loss = valid_loss / len(valid_loader.dataset)
            avg_valid_acc = valid_acc / len(valid_loader.dataset)

        if epoch == 0 or (epoch + 1 ) % 10 == 0:
            train_acc_list.append(avg_train_acc)
            train_loss_list.append(avg_train_loss)
            valid_acc_list.append(avg_valid_acc)
            valid_loss_list.append(avg_valid_loss)
        # 出来たモデルでauc scoreを計算
        _,prediction = torch.max(
            model.forward(valid_data_tensor.to(device)),dim=1)#fold全体の予測値
        # tensor配列からnumpy配列に戻すときはdetach()を挟む必要アリ
        auc_score = roc_auc_score(valid_label_tensor.detach().numpy().copy(),prediction.to('cpu').detach().numpy().copy())
        if auc_score > best_auc_score:
            model_path = 'model_pth'
            torch.save(model.state_dict(), model_path)
        auc_score_list.append(auc_score)
    return train_acc_list, train_loss_list, valid_acc_list, valid_loss_list, auc_score_list

解決策

元のpd.DataFrameの訓練用ラベルをtensorに変換する手順で次元に冗長な部分があったっぽい?です。データの整形が甘かったのかも。
よってただ単に損失関数を計算する前に正解ラベルyt,yvsqueeze_()にぶち込む。

# 該当箇所1
yt = yt.squeeze_()
loss = loss_fn(y_pred, yt)

# 該当箇所2
yv = yv.squeeze_()
loss = loss_fn(y_pred, yv)

これで正常に動いてくれました。

参考ページ

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?