3
4

めちゃくちゃ速くなるんでだまされたと思ってやってみて

シンプルに画像分類モデルでやってみましょう。

結論から言うと、

runtime time(sec)
torch 0.009
torchscript 0.003
torch_tensorrt 0.001

outputは全部同じで
class: 'tiger cat'
confidence: 35.1

使用方法

インストール torch-tensorrt

pip install torch-tensorrt

モデルの変換

import torch
import torchvision.models as models
import torch_tensorrt

model = models.resnet18(pretrained=True)
model.eval().cuda()
example_input = torch.randn(1, 3, 224, 224).cuda()
optimized_model = torch.compile(model, backend="tensorrt")

推論

with torch.no_grad():
    output = optimized_model(batch_t)

1回目は遅くて、2回目から高速です。

入力の前処理

import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
 transforms.Resize(224),
 transforms.ToTensor(),
 transforms.Normalize(
 mean=[0.485, 0.456, 0.406],
 std=[0.229, 0.224, 0.225]
 )])

img = Image.open("cat.jpg")
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0).cuda()

結果の後処理

import urllib
label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
class_labels = urllib.request.urlopen(label_url).read().splitlines()
class_labels = class_labels[1:] # remove the first class which is background

_, index = torch.max(output, 1)
percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100


print(class_labels[index[0]], percentage[index[0]].item())

🐣

フリーランスエンジニアです。
AIについて色々記事を書いていますのでよかったらプロフィールを見てみてください。

もし以下のようなご要望をお持ちでしたらお気軽にご相談ください。
AIサービスを開発したい、ビジネスにAIを組み込んで効率化したい、AIを使ったスマホアプリを開発したい、
ARを使ったアプリケーションを作りたい、スマホアプリを作りたいけどどこに相談したらいいかわからない…

いずれも中間コストを省いたリーズナブルな価格でお請けできます。

お仕事のご相談はこちらまで
rockyshikoku@gmail.com

機械学習やAR技術を使ったアプリケーションを作っています。
機械学習/AR関連の情報を発信しています。

X
Medium
GitHub

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4