0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pytorch dataloader で陥ったこと

Last updated at Posted at 2023-12-12

状況

・3回目のkaggle挑戦で画像分類に初挑戦
・学習は完了したが、テストデータのevalでエラー発生
・備忘として記録

現在、Editではうまくsubmission.csvまで出力できるのに、submitするとエラーが出て採点してもらえない状況。。。なぜ?(原因わかって解決できれば投稿したいと思います)

陥ったこと

kaggleのコンペ「UBC Ovarian Cancer Subtype Classification and Outlier Detection (UBC-OCEAN)」に挑戦している時、テスト画像の予測中に下記エラーが発生した。

RuntimeError: Given groups=1, weight of size [3, 6, 1, 1], expected input[1, 3, 1024, 512] to have 6 channels, but got 3 channels instead

エラー内容と原因

エラーの内容はtorchの形状の不一致のようだった。学習ではバッチサイズ5で実施していたが、
テストの時に画像読み込みが1枚しかなく、形状が異なっていることが原因のようだ。

モデルのコード

class Model_CNN(nn.Module):
    def __init__(self, filters) -> None:
        super(Model_CNN, self).__init__()
        
        self.main = nn.Sequential(
            nn.Conv2d(3, 512, kernel_size=19, stride=11, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 256, kernel_size=13, stride=11, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 128, kernel_size=4, stride=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.classifier = nn.Sequential(nn.Linear(128, num_class))
                                
        
    def forward(self, x):
        out = self.main(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

モデルでエラー発生箇所は下記の予測のところで発生。

モデルを使った予測

pred_list = []

model.load_state_dict(torch.load(w_save_path))
model.eval()
with torch.no_grad():
    for imgs in testdataloader:  # 予測
        imgs = Variable(imgs.to(device))
        pred = model(imgs)
        pred_list.append(int(pred.max(1)[1]))

回避策

いい方法がわからなかったので、画像1枚の時に形状を変更することにした

pred_list = []

model.load_state_dict(torch.load(w_save_path))
model.eval()
with torch.no_grad():
    for imgs in testdataloader.dataset:
        if len(imgs) == 3: # ファイルが1つしかない場合
            imgs = imgs[np.newaxis, :, :, :] # 形状を一致させる
        imgs = Variable(imgs.to(device))
        pred = model(imgs)
        pred_list.append(int(pred.max(1)[1]))

その他

・他にいい方法があれば教えていただけると、助かります。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?