0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

torch.quantization.quantize_dynamicによるモデルの量子化

Last updated at Posted at 2025-01-25

今更ながらモデルの量子化の実験をしてみます。PyTorchでモデルの量子化がどれくらい簡単にできるかというこの確認と、量子化前後で以下を比較してみます。

  • 推論結果
    • 本来はtest setを用意して定量的に精度を比較できると良かったですが、ここではサンプル画像を用意して定性的に推論結果を比較してみます。
  • 実行時間
    • 推論時間が短くなることを確認します。
  • モデルサイズ
    • モデルサイズが小さくなることを確認します。

frog.jpg

画像はAdobe Stockから適当な画像を取得しました。

モデルはResNet50を利用しました。

実験スクリプト

import sys
import time
from PIL import Image

import torch
import torchvision
from torchvision import transforms
from torchvision.models import ResNet50_Weights


def load_model(weights=ResNet50_Weights.DEFAULT):
    """モデルのロードと設定"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torchvision.models.resnet50(weights=weights)
    model.to(device)
    model.eval()
    classes = weights.meta["categories"]
    return model, classes, device


def preprocess_image(image_path, device):
    """画像の前処理"""
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                             std=[0.229, 0.224, 0.225])
    ])
    try:
        input_image = Image.open(image_path)
        input_tensor = preprocess(input_image)
        return input_tensor.unsqueeze(0).to(device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Error: {image_path} not found.")
    except Exception as e:
        raise Exception(f"An error occurred during image preprocessing: {str(e)}")


def predict(model, input_batch, classes, quantized=False):
    """モデルによる推論と結果の表示"""
    try:
        start_time = time.time()
        with torch.no_grad():
            output = model(input_batch)
        end_time = time.time()

        # 確率を計算
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # 上位5クラスを取得
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        model_type = "Quantized Model" if quantized else "Original Model"

        print(f"\n{model_type} Predictions:")
        print(f"Time taken: {end_time - start_time:.3f} seconds")
        for i in range(5):
            print(f"{classes[top5_idx[i]]:>20}: {top5_prob[i].item() * 100:.2f}%")
    except Exception as e:
        raise Exception(f"An error occurred during prediction: {str(e)}")


def quantize_model(model):
    """モデルを量子化 (int8)"""
    try:
        # GPUモデルをCPUに移動(量子化はCPUでのみ実行可能)
        model = model.cpu()
        quantized_model = torch.quantization.quantize_dynamic(
            model, 
            {torch.nn.Linear, torch.nn.Conv2d}, 
            dtype=torch.qint8
        )
        return quantized_model
    except Exception as e:
        raise Exception(f"An error occurred during model quantization: {str(e)}")


def get_model_size(model):
    """モデルのサイズを計算(MB単位)"""
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2
    return size_all_mb


# メイン処理
if __name__ == "__main__":
    IMAGE_PATH = './frog.jpg'

    try:
        # モデルとクラスラベルのロード
        model, classes, device = load_model()
        print(f"Using device: {device}")

        # 画像の前処理
        input_batch = preprocess_image(IMAGE_PATH, device)

        # オリジナルモデルで推論
        predict(model, input_batch, classes)
        print(f"Model size: {get_model_size(model):.2f} MB")

        # 量子化されたモデルで推論
        quantized_model = quantize_model(model)
        # 量子化モデルはCPU上で実行
        input_batch = input_batch.cpu()
        predict(quantized_model, input_batch, classes, quantized=True)
        print(f"Quantized model size: {get_model_size(quantized_model):.2f} MB")

    except FileNotFoundError as e:
        print(e)
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)}")

実行結果

Using device: cpu

Original Model Predictions:
Time taken: 0.247 seconds
           tree frog: 37.26%
         tailed frog: 4.85%
            bullfrog: 0.28%
European fire salamander: 0.16%
   African chameleon: 0.15%
Model size: 97.70 MB

Quantized Model Predictions:
Time taken: 0.228 seconds
           tree frog: 37.68%
         tailed frog: 4.81%
            bullfrog: 0.28%
European fire salamander: 0.16%
   African chameleon: 0.15%
Quantized model size: 89.88 MB
  • 推論結果
    • 各クラスに割り当てている確率は変化していますが、大きく変わったようには見えません。
  • 実行時間
    • 推論時間に変化はありませんでした。実は今回の検証はローカルのMac(M2)のCPUで実行しましたが、このCPUはint8ハードウェアアクセラレーションを持っていないため、高速化されないということがわかりました。
  • モデルサイズ
    • 雑に書いているのでメモリ使用量の計算が正確か自信がありませんが、モデルサイズが小さくなっていることは間違いなさそうです。

おまけ

その後GoogleColab(GPU)でautocastを利用して、FP32(デフォルト)とFP16で比較してみましたが、あまり差はありませんでした。

def predict(model, input_batch, classes, quantized=False):
    """モデルによる推論と結果の表示"""
    try:
        start_time = time.time()
        with torch.no_grad():
          with autocast(dtype=torch.float):
            output = model(input_batch)
        end_time = time.time()

        # 確率を計算
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # 上位5クラスを取得
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        model_type = "Quantized Model" if quantized else "Original Model"

        print(f"\n{model_type} Predictions:")
        print(f"Time taken: {end_time - start_time:.3f} seconds")
        for i in range(5):
            print(f"{classes[top5_idx[i]]:>20}: {top5_prob[i].item() * 100:.2f}%")
    except Exception as e:
        raise Exception(f"An error occurred during prediction: {str(e)}")
FP32: 0.123 seconds
FP16: 0.127 seconds

ちなみに

今回はbatch_size=1でしたが、適当に以下のように画像を32毎用意してbatch_size=32とした場合、処理時間がCPUは約32倍、GPUではほぼかわらずであることも(当然ですが)確認できました。

# 同じ画像を32枚用意
input_tensor = input_tensor.unsqueeze(0).repeat(32, 1, 1, 1).to(device)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?