LoginSignup
13
8

More than 1 year has passed since last update.

YOLOv7はYOLOシリーズで最強か【第二版】

Last updated at Posted at 2022-09-25

第二版について

本記事を最初に投稿した後、実行環境として、V100 と A100 が追加で使用することができ、特に、V100 については、YOLOv7 の論文での掲載値と比較することが可能となりました。これらの推論時間の実測によって、あらたな知見も得られたため、第二版として、データの追加と考察内容の更新を行いました。

はじめに

YOLOv7 の論文(以下、論文といいます)が2022年7月に公開されましたが、そこでは、網羅的に YOLO シリーズの比較がなされています。論文著者は、それら比較対象の YOLO を V100 上で動作させ、YOLOv7 が最も高速で、かつ最も精度が高いことを示しています。
キャプチャ.JPG

これは十分に有用な情報ですが、これらの比較がどの程度条件をあわせて行ったものか論文ではわかりませんでした。実際、推論実行時に性能に関係すると思われる FP16 や Fusion (畳み込みとバッチノーマライゼーションを融合するもの) の扱いがそれぞれの YOLO で異なりますし、以前、ある論文を読んでよさそうに思ったモデルが実際にはそうでもなかったという残念な経験が筆者にはあります。そこで、ここでは、Google Colab/Google Colab Pro を使って、YOLOv7 に加え他の YOLO シリーズのモデルを動作させ、推論速度と精度の関係を確認していきます。クラウド環境は性能比較には最適ではないですが、傾向把握には問題ないと思いますので、この比較を通して Web 記事のタイトル「 YOLOv7 は YOLO シリーズで最強か」の答えを出そうと思います。

比較対象のモデル

論文 Table 2 では、他の YOLO との比較が記載されていますが、この中でPPYOLOE以外を対象として、Google Colab/Google Colab Pro で動作させます。YOLOv5YOLOR論文 Figure 1 でも良い値をだしていること、YOLOX は最近取り上げられることがよくあることが、これらモデルの選択理由です。Google Colab では Tesla T4、Google Colab Pro では Tesla A100/V100/P100 の GPU での推論時間測定を行いました。APval はこれらで共通に値に記載されているため、精度比較としてこれを使います。Table 2 では、FPS ( Frame per Second ) が記載されていますが、論文の Figure 1 にあわせて1フレームあたりの推論時間 ( FPS 逆数) に変換して使用します。

なお、YOLOv7-tiny の事前学習モデルは、活性化関数が LeakyReLU になっており、論文 の SiLU とは違っていますが、ここでは APval をそのまま流用します。

table2.png

対象モデルでの FP16 と Fusion の扱い

推論時間の実行は、各 YOLO タイプで提供されているコードを使いますが、デフォルトでは以下の状況です。

YOLOタイプ 対象コード FP16 の扱い Fusion の扱い
YOLOv5 detect.py 指定できる。 指定できない。また、detect.pyで使われているモデルではFusion未サポート。
YOLOv7 detect.py 指定できない。GPU搭載時に強制的に FP16。 指定できない。強制的にFusionする。
YOLOR detect.py 指定できない。GPU搭載時に強制的に FP16。 指定できない。Fusion機能はある。
YOLOR
(1280画像サイズ)
detect.py 指定できない。GPU搭載時に強制的に FP16。 指定できない。強制的にFusionする。
YOLOX tools/demo.py 指定できる。 指定できるが、FP16 と組み合わせるとエラーになる。

そこで、それぞれの YOLO タイプで FP16 と Fusion を指定できるよう修正しました。コードは補足に記載しました。

各モデルの推論時間

各モデルでの実行条件をあわせるため、上記の修正を行って、FP32 か FP16 か、Fusion するかしないかを選択できるようにし、YOLOv7YOLOv5YOLORYOLOX の単位で最も高速で実行できるパターンで比較することとします。

以下に実行結果をまとめました。1フレームを処理するための推論時間 (ミリ秒) を小数点以下を切り捨てて記載しています。A100、V100、T4 では、FP16 / Fusion の組み合わせ、P100 では、FP32 / Fusion の組み合わせがそれぞれ総ての YOLO タイプで高速でしたので、この組み合わせを以下使用します。

測定結果2.JPG

先の Table 2 で対象とするモデルを抜き出し、上記の実測値をあわせて記載したものが以下です。推論時間の単位はミリ秒です。

比較テーブル.JPG

グラフ表示

V100 での論文掲載値と実測値

左グラフは論文に掲載された V100 での推論実行時間(FPS の逆数)、右グラフは筆者実測値(FP16/Fused の実測パターン)です。

V100比較.JPG

左上に向かうに従って速度も速く精度も良くなりますので、YOLOv7 は実測値でも他モデルに比較して左上に位置しているため優位であることは論文の主張と同じですが、画像サイズ640(グラフ内の青と赤の実線)のうちでも小さいモデルでは、YOLOv5YOLOv7 を上回っています。YOLOv5-S と YOLOv7-tiny、YOLOv5-L と YOLOv7 はそれぞれ以下のとおり、#Param. や FLOPs が近いため、今回の実測値のように近い値になっていることは納得性が高いと思います。このレベルの差は測定誤差の可能性もありますが、いずれにしても、論文で示されているほどの顕著な差はなさそうです。

v5s_v7tiny.JPG

推論時間と精度

次に、論文の Figure 1 のように推論時間と精度を軸としたグラフを A100、V100、P100、T4 それぞれについて作成します。サイズ 640 は実線、サイズ 1280 は破線で示しています。ここでは、A100、V100、T4 では、FP16 / Fusion の組み合わせ、P100 では、FP32 / Fusion の組み合わせでの測定値を使っています。

GPU比較.JPG

上述のとおり、YOLOv7 / YOLOv5-L あたりを基準とし、便宜的にそれ未満を下位モデル、それ以上を上位モデルとします。具体的に FLOPs でみると、100 GFLOPsが基準となりますので、下位モデルの対象は、YOLOX-M 以下、YOLOv5-M 以下、YOLOv5-S6 以下、YOLOv7-tiny です。 YOLOR は下位モデルの対象はありません。

下位モデル

V100 では、YOLOv5 優位、その他のGPUでは、YOLOv7 が優位です。YOLOX はいずれの GPU でも YOLOv7YOLOv5 に及びません。

上位モデル

いずれの GPU でも YOLOv7 が優位です。
二位以下は、YOLOv5YOLOR は拮抗しています。YOLOX はいずれの GPU でも YOLOv7YOLOv5YOLOR に及びません。

推論時間と計算量

推論時間と計算量の関係を、測定パターン(FP16 / Fused、FP32 / Fused、FP16 / Not Fused、FP32 / Not Fused)でグラフ化し、A100、V100、P100、T4 の順番に並べました。

LvsG.JPG

FP16 と Fusion の効果

今回、条件をそろえる対象とした FP16 と Fusion の推論速度に対する効果を GPU 毎にみたいと思います。
A100 と V100 では、FP16 と Fusion の効果がそれぞれでています。特に、V100 では、FP16 の効果が顕著にでています。
T4 でも FP16 の効果が顕著ですが、Fusion はさほどではありません。
P100 では、FP16 か FP32 か、Fusion するかしないかは他のGPUと比較して顕著な違いとしてはあらわれていません。この特性のために、他のGPU では、FP16 / Fusion の組み合わせが最も良い数値だったにも関らず、P100 のみ FP32 / Fusion の組み合わせがベストとなったようです。

スケーラビリティ

A100 と V100 は計算量が大きくてもうまくスケールしています。FP16 を使うことで、T4 でも P100 以上にスケールする様子がわかります。以下、参考までに、800 GFLOPs を超えるモデルの測定値を抜き出しています。

scalability.JPG

まとめ

以上のように、ここでは FP16 使用有無、Fusion 有無や、その他対象とする画像、各種パラメータをなるべく同じにそろえて各モデルを比較しました。その結果、画像サイズ640 の下位モデルでは YOLOv5 と拮抗するものの、全般的には YOLOv7 が最も推論速度と精度のバランスに優れていることが確認できました。

ということで、論文に示されている精度が正しい場合、YOLO シリーズの現時点での代表ともいえるYOLOv5YOLORYOLOX に対して優位性があることが確認できましたので、本 Web 記事のタイトル「YOLOv7 は YOLO シリーズで最強か」は、現時点では概ね正しいと思います。ただ、他モデルとの差は、論文に記載されているほどには大きくはないように思います。
なお、YOLOv7YOLOv5 に比較して下位モデルが手薄なので、もう少しラインアップが加わると使い勝手もさらに良くなるのはないか、と思いました。

補足:コード

実行は Google Colab、Google Colab Pro を使って行いました。それぞれのYOLOモデルのインストールやコード修正、推論実行をご参考までに以下に記載します。

YOLOv5

YOLOv5のインストール

os.chdir('/content')
!git clone https://github.com/ultralytics/yolov5  # clone
os.chdir('yolov5')
!pip install -r requirements.txt  # install

detect.pyの変更

筆者が本 Web 記事を執筆中も、頻繁にコードがアップデートされています。2022年9月25日現在のコードに対する差分を以下記載します。

importの追加

オリジナルコードではDetectMultiBackend() でモデルを生成しますが、fusion がサポートされていないので、fusion が使える DetectionModel() をかわりに使います。また、モデル生成時に yaml を使うのでここで import します。

from models.yolo import DetectionModel
import yaml
def run()の引数追加

引数の最後に以下を追加します。

        fuse=False,  # use fusing
model の生成

オリジナルでは以下でモデルを生成します。

    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt

これを以下に変更します。

    # model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    # stride, names, pt = model.stride, model.names, model.pt
    model_name = weights[0].split('.')[0]
    addon_path = 'hub/' if model_name[-1] == '6' else ''
    yaml_path = '/content/yolov5/models/' + addon_path + model_name + '.yaml'
    with open(yaml_path, 'r') as yml:
        config = yaml.safe_load(yml)
    model = DetectionModel(config)
    model.to(device)
    if fuse:
        model.fuse()
        print('Model fusing')
    if half:
        model.half()
        print('Model half')
    stride, names, pt = int(max(model.stride)), model.names, True
model.warmup() の代替え

DetectionModel() には warmup() が実装されていないので、以下のコードに置き換えます。

    # model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    if device.type != 'cpu':
        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
画像変換

DetectionModel() には、device 属性、fp16 属性がないので、以下のコードに置き換えます。

            # im = torch.from_numpy(im).to(model.device)
            # im = im.half() if model.fp16 else im.float()
            im = torch.from_numpy(im).to(device)
            im = im.half() if half else im.float()  # uint8 to fp16/32
推論時間の取得

推論時間は、run() の最後で t[1] で入手できます。セーブしても良いですし、ここでは簡単にプリントしたコードです。

    print('Inference:',t[1])
def parse_opt() への追加

parser.add_argument の最後に以下を追加して fuse を指定できるようにします。

    parser.add_argument('--fuse', action='store_true', help='fusing')

推論実行

--source には入力画像(ビデオ)を指定します。--weightsに事前学習モデルを指定します。fp16 使用時は --halffusion--fuse でそれぞれ指定します。
画像サイズを 1280 で実行する場合、--img 1280 を指定します。

!python detect.py --source images/kite.jpg --weights yolov5s.pt --conf 0.25 --iou 0.45

YOLOv7

YOLOv7のインストール

os.chdir('/content')
!git clone https://github.com/WongKinYiu/yolov7.git > /dev/null
os.chdir('yolov7')
#%cd yolov7
!pip3 install -U pip && pip3 install -r requirements.txt > /dev/null
!pip3 install -v -e . > /dev/null # or  python3 setup.py develop

detect.pyの変更

引数受領

detect() の最初で opt から引数を取得しますが、half と fuse を追加で取得します。

    # source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace
    source, weights, view_img, save_txt, imgsz, trace, half, fuse = \
    opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace, opt.half, opt.fuse
half の反映

half の定義に引数の値を反映します。

    # half = device.type != 'cpu'  # half precision only supported on CUDA
    half = device.type != 'cpu' and half # half precision only supported on CUDA

また、少し後ろで、model.half() を実行する箇所がありますが、実行がわかるようにプリント文を追加します。

    if half:
        model.half()  # to FP16
        print('Model half') # 追加
fuse の反映

原実装では必ず fusion するので、fuse の指定が反映できるよう、attempt_load() の引数に追加します。また、fuse の指定があった場合にプリントします。

    #model = attempt_load(weights, map_location=device)  # load FP32 model
    model = attempt_load(weights, map_location=device, fuse=fuse)
    if fuse:
         print('Model fused')
結果の出力

ここでは、すでに実装されている t1、t2 を使って推論時間を出力します。detect() の最後に追加します。なお、推論対象の画像が複数枚あったとしても、このコードでは最後の画像の推論時間の出力になります。

    print('Inference:',t2 - t1)
parser.add_argument() 追加

half と fuse を指定できるよう、以下のコードを parser.add_argument() の最後に追加します。

    parser.add_argument('--half', action='store_true', help='half precision')
    parser.add_argument('--fuse', action='store_true', help='fusion')

models/experimental.py の変更

fuse の指定を反映します。

def attempt_load(weights, map_location=None, fuse=True):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        attempt_download(w)
        ckpt = torch.load(w, map_location=map_location)  # load
        if fuse:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
        else:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # FP32 model
(snip)

推論実行

--img-size はモデルに応じて 640 あるいは 1280 を指定します。fp16 使用時は --halffusion--fuse でそれぞれ指定します。

!python detect.py --weights yolov7.pt --conf 0.25 --iou 0.45 --img-size 640 --source images/kite.jpg --no-trace

YOLOR (yolor_csp、yolor_csp_star、yolor_p6)

YOLOR のインストール

os.chdir('/content')
!git clone https://github.com/WongKinYiu/yolor

os.chdir('yolor')

# pip install required packages
!pip install -qr requirements.txt

# install mish-cuda if you want to use mish activation
# https://github.com/thomasbrandon/mish-cuda
# https://github.com/JunnYu/mish-cuda
!git clone https://github.com/JunnYu/mish-cuda
os.chdir('mish-cuda')
!python setup.py build install
os.chdir('..')

# install pytorch_wavelets if you want to use dwt down-sampling module
# https://github.com/fbcotter/pytorch_wavelets
!git clone https://github.com/fbcotter/pytorch_wavelets
os.chdir('pytorch_wavelets')
!pip install .
os.chdir('..')

事前学習モデルの取得

画像サイズ640 向けの事前学習モデル(yolor_csp.pt、yolor_csp_star.pt)と画像サイズ1280向けの事前学習モデルのうち yolor_p6.ptは、ここからダウンロードできます。

detect.pyの変更

引数受領

detect() の最初で opt から引数を取得しますが、half と fuse を追加で取得します。

    # out, source, weights, view_img, save_txt, imgsz, cfg, names = \
    #     opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, opt.cfg, opt.names
    out, source, weights, view_img, save_txt, imgsz, cfg, names, half, fuse = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, opt.cfg, opt.names, opt.half, opt.fuse
half と fuse の反映

half の定義に引数の値を反映します。

    # half = device.type != 'cpu'  # half precision only supported on CUDA
    half = device.type != 'cpu' and half # half precision only supported on CUDA

また、少し後ろで、model.to(device).eval() となっている箇所を以下のようにします。model.half() を実行する箇所では、実行がわかるようにプリント文を追加します。

    # model.to(device).eval()
    if fuse: # 追加
        model.fuse() # 追加
        print('Model Fusing') # 追加
    if half:
        model.half()  # to FP16
        print('Model half') # 追加
    model.eval() # 追加
推論実行後の時間採取
        pred = model(img, augment=opt.augment)[0]
        t11 = time_synchronized() # 追加
結果の出力

t1 と新たに追加した t11 を使って推論時間を出力します。detect() の最後に追加します。なお、YOLOv7 同様、推論対象の画像が複数枚あったとしても、このコードでは最後の画像の推論時間の出力になります。

    print('Inference:', t11 - t1)
parser.add_argument() 追加

half と fuse を指定できるよう、以下のコードを parser.add_argument() の最後に追加します。

    parser.add_argument('--half', action='store_true', help='half precision')
    parser.add_argument('--fuse', action='store_true', help='fusion')

推論実行

--img-size はモデルに応じて 640 あるいは 1280 を指定します。fp16 使用時は --halffusion--fuse でそれぞれ指定します。

!python detect.py --source images/kite.jpg --cfg cfg/yolor_csp.cfg --weights yolor_csp.pt --output image_out/yolor_csp --conf 0.25 --iou 0.45 --img-size 640 --device 0

YOLOR (yolor-w6, yolor-e6, yolor-d6)

YOLOR のインストール

os.chdir('/content/')
!git clone -b paper https://github.com/WongKinYiu/yolor

os.chdir('yolor')

# pip install required packages
!pip install -qr requirements.txt

# install mish-cuda if you want to use mish activation
# https://github.com/thomasbrandon/mish-cuda
# https://github.com/JunnYu/mish-cuda
!git clone https://github.com/JunnYu/mish-cuda
os.chdir('mish-cuda')
!python setup.py build install
os.chdir('..')

# install pytorch_wavelets if you want to use dwt down-sampling module
# https://github.com/fbcotter/pytorch_wavelets
!git clone https://github.com/fbcotter/pytorch_wavelets
os.chdir('pytorch_wavelets')
!pip install .
os.chdir('..')

事前学習モデルの取得

yolor-w6、yolor-e6、yolor-d6 の事前学習モデルはここからダウンロードできます。

detect.pyの変更

引数受領

detect() の最初で opt から引数を取得しますが、half と fuse を追加で取得します。

    # out, source, weights, view_img, save_txt, imgsz, cfg, names = \
    #     opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, opt.cfg, opt.names
    out, source, weights, view_img, save_txt, imgsz, cfg, names, half, fuse = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, opt.cfg, opt.names, opt.half, opt.fuse
half と fuse の反映

half の定義に引数の値を反映します。

    # half = device.type != 'cpu'  # half precision only supported on CUDA
    half = device.type != 'cpu' and half # half precision only supported on CUDA

また、model を attempt_load() で生成しますが、そこでは fusion を強制しますので、指定ができるように引数で fuse をわたします。また、model.half() を実行する箇所では、実行がわかるようにプリント文を追加します。

    model = attempt_load(weights, map_location=device, fuse=fuse)  # 変更
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    if half:
        model.half()  # to FP16
        print("Model half") # 追加
推論実行後の時間採取
        pred = model(img, augment=opt.augment)[0]
        t11 = time_synchronized() # 追加
結果の出力

t1 と新たに追加した t11 を使って推論時間を出力します。detect() の最後に追加します。なお、推論対象の画像が複数枚あったとしても、このコードでは最後の画像の推論時間の出力になります。

    print('Inference:', t11 - t1)
parser.add_argument() 追加

half と fuse を指定できるよう、以下のコードを parser.add_argument() の最後に追加します。

    parser.add_argument('--half', action='store_true', help='half precision')
    parser.add_argument('--fuse', action='store_true', help='fusion')

models/experimental.py の変更

fuse の指定を反映します。

def attempt_load(weights, map_location=None, fuse=True):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    if fuse:
        print("Model fusing in attempt_load()")
    for w in weights if isinstance(weights, list) else [weights]:
        attempt_download(w)
        if fuse:
            model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
        else:
            model.append(torch.load(w, map_location=map_location)['model'].float().eval())  # load FP32 model
(snip)

推論実行

--img-size はモデルに応じて 640 あるいは 1280 を指定します。fp16 使用時は --halffusion--fuse でそれぞれ指定します。

!python detect.py --source images/kite.jpg --cfg cfg/yolor_csp.cfg --weights yolor_csp.pt --output image_out/yolor_csp --conf 0.25 --iou 0.45 --img-size 640 --device 0

YOLOX

YOLOXのインストール

os.chdir('/content')
!git clone https://github.com/Megvii-BaseDetection/YOLOX.git > /dev/null
os.chdir('YOLOX')
!pip3 install -U pip && pip3 install -r requirements.txt > /dev/null
!pip3 install -v -e . > /dev/null # or  python3 setup.py develop

事前学習モデルの取得

YOLOX github の Benchmark セクションにモデルの表があります。各モデルに対応する事前学習モデルを weights 欄から取得します。yolox_s の事前学習モデルのダウンロード例です。

!wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth

tools/demo.py の変更

import の追加

時間計測のため、以下を import を行うファイルの先頭部分の最後に追加します。

from yolox.utils import time_synchronized
推論時間の計測

inference() コードの後半部はオリジナルでは以下のコードです。

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

時間の計測部分を time_synchronized() を使うように変更し、コーラーに結果を戻すようにします。

        with torch.no_grad():
            t0 = time_synchronized() # 変更
            outputs = self.model(img)
            t1 = time_synchronized() # 追加
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(t1 - t0)) # 変更
        return outputs, img_info, t1 - t0 # 変更

なお、推論対象の画像が複数枚あったとしても、このコードでは最後の画像の推論時間しかコーラーに戻りません。

image_demo() の変更

predictor.inference() の戻りで上記の推論時間を受けるように変更し、さらに return でコーラーに返すようにします。

        # outputs, img_info = predictor.inference(image_name)
        outputs, img_info, latency = predictor.inference(image_name) # 変更
        (snip)
    return latency # 変更
main() の変更

オリジナルコードのままでは --fp16--fuse を同時に指定すると、model.fuse() でエラーになるため、実行順序を変更します。model = exp.get_model() から if args.trt: の直前までを以下のように変更します。

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    # if args.device == "gpu":
    #     model.cuda()
    #     if args.fp16:
    #         model.half()  # to FP16
    # model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.device == "gpu":               # 追加
        model.cuda()                       # 追加

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16 and args.device == "gpu": # 追加
        model.half()  # to FP16            # 追加
        logger.info("\tModel half")        # 追加
        
    model.eval()                           # 追加

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        (snip)

また、推論時間をプリントするため、image_demo() の戻りを使います。

    if args.demo == "image":
        inference_time = image_demo(predictor, vis_folder, args.path, current_time, args.save_result) # 変更
        print('Inference:', inference_time) # 追加

推論実行

fp16 使用時は --fp16fusion--fuse でそれぞれ指定します。

!python tools/demo.py image -n yolox-s -c yolox_s.pth --path images/kite.jpg --save_result --device gpu --conf 0.25 --nms 0.45

A100 の対応

A100 では、YOLOR などインストール・環境設定時に以下のエラーがでます。

A100-SXM4-40GB with CUDA capability sm_80 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75.
If you want to use the A100-SXM4-40GB GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

Web 上の記事を参考に、推論実行前に torch を import し直すことでエラーを回避しました。

!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
13
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
8