drmasato2001
@drmasato2001 (M M)


Gradio error


Closed

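I collected my own images and trained an image classifier with PyTorch. I want to display the inference results with Gradio, but I get an error. Which part should I fix?
I would appreciate it if you could tell me how to solve this.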

Problem / error

```python
import torch.nn.functional as F
from PIL import Image
import torch
from torchvision import transforms
import numpy as np
import gradio as gr

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data_dir = '/home/hogehoge/data'
cls_names = ['RA', 'MPA', 'normal']  # class names
n_cls = len(cls_names)  # number of classes

model_path = 'VGG16_0.01.pth'
model = torch.load(model_path, map_location=device).to(device)

model.eval()


def predict(inp):
    inp = Image.fromarray(inp.astype('uint8'), 'RGB')
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = F.softmax(model(inp)[0], dim=0)
    sorted_prediction, sorted_indices = torch.sort(prediction, descending=True)
    top_labels = [cls_names[i] for i in sorted_indices[:2]]
    top_probabilities = [float(sorted_prediction[i]) for i in sorted_indices[:2]]
    return {top_labels[i]: top_probabilities[i] for i in range(len(top_labels))}


inputs = gr.inputs.Image()
outputs = gr.outputs.Label()
interface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
interface.launch()
```

[Screenshot of the error message: gradio_error.png]

What I tried

```python
sorted_prediction, sorted_indices = torch.sort(prediction, descending=True)
top_labels = [cls_names[i] for i in sorted_indices[:2]]
top_probabilities = [float(sorted_prediction[i]) for i in sorted_indices[:2]]
```

I tried changing these lines in various ways, but that did not solve the problem.
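These three lines do contain a logic bug, although with three classes the indices stay in range, so on its own it probably produces wrong numbers rather than the error. `torch.sort` returns the values already sorted, so `sorted_prediction[i]` with `i` taken from `sorted_indices` looks up the wrong entries. The sketch below uses made-up probabilities (not from the question's model) to show the mismatch and two ways to keep each probability paired with its label: slicing the sorted values, or `torch.topk`.

```python
import torch

cls_names = ['RA', 'MPA', 'normal']
# Made-up softmax output for illustration only
prediction = torch.tensor([0.2, 0.1, 0.7])

sorted_prediction, sorted_indices = torch.sort(prediction, descending=True)
# sorted_prediction is [0.7, 0.2, 0.1], sorted_indices is [2, 0, 1]

# Original approach: indexes the *sorted* values with *original* class indices,
# i.e. sorted_prediction[2] and sorted_prediction[0] -> roughly [0.1, 0.7]
buggy = [float(sorted_prediction[i]) for i in sorted_indices[:2]]

# Fix 1: slice the sorted values directly -> roughly [0.7, 0.2]
fixed = [float(p) for p in sorted_prediction[:2]]

# Fix 2: torch.topk returns the k largest values and their indices together
values, indices = torch.topk(prediction, k=2)
result = {cls_names[i]: v for v, i in zip(values.tolist(), indices.tolist())}
# result maps 'normal' to about 0.7 and 'RA' to about 0.2 (float32 rounding)
```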
