yamadai0127
@yamadai0127 (ダイスケ)

Are you sure you want to delete the question?

If your question is resolved, you may close it.

Leaving a resolved question undeleted may help others!

We hope you find it useful!

初めて自分でAIを作っていて解決できなかったエラー

解決したいこと

PyTorchを使ってフルーツの名前を当てるAIを作っているのですが、学習を開始するところでエラーが出ました。自分でいろいろ試しましたがエラーが直りません。

発生している問題・エラー

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-12-5d60f5e3d878> in <module>
      6         optimizer.zero_grad()
      7         outputs = net(inputs.to(device))
----> 8         loss = criterion(outputs, labels.to(device))
      9         loss.backward()
     10         optimizer.step()

AttributeError: 'tuple' object has no attribute 'to'

該当するソースコード

loss = criterion(outputs, labels.to(device))

自分で試したこと

まず、以下のコードが今回作っている「フルーツの名前を当てるAI」です。
画像のデータセットはこちらを使用しました。↓
https://www.kaggle.com/datasets/karimabdulnabi/fruit-classification10-class?resource=download

import os
import torch
import random
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torch.optim as optim
%matplotlib inline

device = 'cuda:0'
class FruitDataset(Dataset):
    def __init__(self, directory):
        self.directory = directory
        self.transform = self.transform()
        self.label, self.label_to_index = self.findClasses()
        self.img_path_and_label = self.createImgPathAndLabel()
    
    def __len__(self):
        return len(self.img_path_and_label)

    def __getitem__(self, index):
        img_path, label = self.img_path_and_label[index]
        img = Image.open(img_path)
            
        if self.transform:
            img = self.transform(img)
        
        return img, label
    
    #画像が保存されているフォルダーの名前に対して数字(index)を割り当てるコード
    def findClasses(self):
        classes = [d.name for d in os.scandir(self.directory)]
        classes.sort()
        class_to_index = {class_name: i for i, class_name in enumerate(classes)}
        return classes, class_to_index
    
    #画像とラベルを対にするコード
    def createImgPathAndLabel(self):
        if self.directory:
            img_path_and_labels = []
            directory = os.path.expanduser(self.directory)
            for target_label in sorted(self.label_to_index):
                label_index = self.label_to_index[target_label]
                target_dir = os.path.join(directory, target_label)

                for root, _, file_names in sorted(os.walk(target_dir, followlinks = True)):
                    for file_name in file_names:
                        img_path = os.path.join(root, file_name)
                        img_path_and_label = img_path, target_label
                        img_path_and_labels.append(img_path_and_label)
            
            random.shuffle(img_path_and_labels)

        return img_path_and_labels
    
    def transform(self):
        transform = transforms.Compose([transforms.CenterCrop(128),
                                        transforms.Grayscale(1),
                                        transforms.ToTensor(),
                                    ])
        return transform
train_directory = './drive/MyDrive/DeepLearning/fruit/train'
train_dataset = FruitDataset(train_directory)
print(train_dataset[0])

test_directory = './drive/MyDrive/DeepLearning/fruit/test'
test_dataset = FruitDataset(test_directory)
print(test_dataset[0])
criterion = nn.CrossEntropyLoss().to(device)
train_dataloader = DataLoader(train_dataset, batch_size=500, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=100, shuffle=True)
class FruitNet(nn.Module):
    def __init__(self):
        super(FruitNet, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 4, 2, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(16, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        self.lin = nn.Sequential(
            nn.Linear(8*8*256, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)            
        )
        
    
    def forward(self, x):
        x = self.cnn(x)
        x = x.view(-1, 8*8*256)
        x = self.lin(x)
        
        return x
net = FruitNet().to(device)
print(net)
print(summary(net, (1, 128, 128)))
loss_list = []
test_loss_list = []
acc_list = []
base_epoch = 0

optimizer = optim.Adam(params=net.parameters(), lr=0.001)
for epoch in range(30):
    net.train()
    total_loss = 0
    for data in train_dataloader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    loss_list.append(total_loss/len(train_dataset))

    correct = 0
    total = 0
    total_test_loss = 0
    net.eval()
    for data in test_dataloader:
        inputs, labels = data
        outputs = net(inputs.to(device))
        test_loss = criterion(outputs, labels.to(device))
        total_test_loss += test_loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum()
    
    test_loss_list.append(total_test_loss/total)
    acc_list.append(float(correct)/total)


    print('{}epoch : train_loss:{}, test_loss:{}, test_acc:{}'.format(
        base_epoch + epoch + 1,
        total_loss/len(train_dataset),
        total_test_loss/len(test_dataset),
        float(correct)/total
    ))

base_epoch += epoch + 1

試してみたこと

  • loss = criterion(outputs, labels.to(device))をloss = criterion(outputs, labels)にしてみたがエラー直らず。

  • loss = criterion(outputs, labels.to(device))をloss = criterion(outputs, torch.tensor(labels).to(device))にしてみたがエラー直らず。

  • あえてcriterion = nn.CrossEntropyLoss().to(device)をcriterion = nn.CrossEntropyLoss()にしてみたが変わらなかった。

0

2Answer

@yamadai0127
FruitDatasetの修正が必要ですね。
createImgPathAndLabelのラベルの型を <class 'torch.Tensor'>に変換すればできますよ。
一応プログラム置いておきます。

class FruitDataset(Dataset):

    def createImgPathAndLabel(self):
        if self.directory:
            img_path_and_labels = []
            directory = os.path.expanduser(self.directory)
            for target_label in sorted(self.label_to_index):
                label_index = self.label_to_index[target_label]
                target_dir = os.path.join(directory, target_label)
      target_label_torch = torch.tensor( target_label)
                for root, _, file_names in sorted(os.walk(target_dir, followlinks = True)):
                    for file_name in file_names:
                        img_path = os.path.join(root, file_name)
                        img_path_and_label = img_path, target_label_torch 
                        img_path_and_labels.append(img_path_and_label)
            
            random.shuffle(img_path_and_labels)

        return img_path_and_labels

2Like

Comments

  1. `target_label_torch = torch.tensor( target_label )`インデントバグってますね。。
    `target_dir = os.path.join(directory, target_label)`と同じところにインデントしてください〜
  2. @yamadai0127

    Questioner

    ありがとうございます!!

細かく見ていませんが。。。

AttributeError: 'tuple' object has no attribute 'to'

labelstuple型だけれども、toなんて属性(メソッド)はないよ、ってことですね。

1Like

Comments

  1. @yamadai0127

    Questioner

    はい、どこを変更すればいいのかが分からなくて困ってます。
  2. フワフワっとした回答ですみません。

    `loss = criterion(outputs, labels.to(device))`

    の第2パラメータに何を指定したらいいか(型は何か)、これが肝かと思います。

    「loss = criterion(outputs, labels.to(device))をloss = criterion(outputs, labels)にしてみたがエラー直らず。」とありますが、その際のエラーメッセージがヒントになるんじゃないですかね。
  3. @yamadai0127

    Questioner

    なるほど、やってみます。親切にアドバイスしてくださり、ありがとうございます!

Your answer might help someone💌