
Trying out the quantization feature added in PyTorch 1.3

Posted at 2019-10-17

PyTorch 1.3 has been released.
In this post I try out quantization, one of the newly added features, although it is still experimental.
For details on the changes, see the official release notes.

#Quantization
I implemented the code with reference to the official documentation and tutorial.
This time, rather than training a quantized model (quantization-aware training), I quantize a model using weights that have already been trained (post-training static quantization).
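
Before diving in, here is a toy illustration of what quantization does to a single tensor. This is a minimal sketch assuming PyTorch 1.3's `torch.quantize_per_tensor` API; the scale and zero_point values here are chosen arbitrarily.

```python
import torch

# Map float values to int8 using a scale and zero_point (arbitrary values here)
x = torch.randn(4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(qx.int_repr())    # the underlying int8 representation
print(qx.dequantize())  # back to float, with rounding error
```
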
##Pre-training
First, we train the weights of a basic LeNet5.

```python:lenet5.py
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class LeNet5(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2, bias=bias),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5, bias=bias),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(84, 10, bias=bias)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

class QLeNet5(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2, bias=bias),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5, bias=bias),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(84, 10, bias=bias)
        )
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.classifier(x)
        x = self.dequant(x)

        return x

    def fuse_model(self):
        # Fuse each Conv2d/Linear with the ReLU that follows it so that the
        # pair is treated as a single module during quantization.
        for idx in range(len(self.features) - 1):
            if type(self.features[idx]) == nn.Conv2d:
                torch.quantization.fuse_modules(self.features, [str(idx), str(idx + 1)], inplace=True)

        for idx in range(len(self.classifier) - 1):
            if type(self.classifier[idx]) == nn.Linear:
                torch.quantization.fuse_modules(self.classifier, [str(idx), str(idx + 1)], inplace=True)
```
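
For reference, `torch.quantization.fuse_modules` replaces a `[Conv2d, ReLU]` (or `[Linear, ReLU]`) pair with a single fused module, leaving an `Identity` in the second slot. A standalone sketch:

```python
import torch
import torch.nn as nn
import torch.quantization

m = nn.Sequential(nn.Conv2d(1, 6, 5), nn.ReLU()).eval()
torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
print(m)  # e.g. Sequential((0): ConvReLU2d(...), (1): Identity())
```
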

```python:pre.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import tqdm

from lenet5 import LeNet5

def train(model, device, train_loader, optimizer, epoch):
    criterion = nn.CrossEntropyLoss()
    model.train()
    with tqdm.tqdm(train_loader) as pbar:
        for x, y in pbar:
            # Move the batch to the same device as the model
            x, y = x.to(device), y.to(device)
            scores = model(x)
            loss = criterion(scores, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description("Train Epoch {}".format(epoch))

def test(model, device, test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    test_loss = 0.0
    num_correct, num_samples = 0, len(test_loader.dataset)
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            scores = model(x)
            test_loss += criterion(scores, y).item()  # accumulate the summed loss
            _, preds = scores.max(1)
            num_correct += (preds == y).sum().item()

    test_loss /= num_samples  # average loss over the whole test set
    acc = float(num_correct) / num_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, num_correct, num_samples, 100. * acc))

def main():
    train_dataset = MNIST(root='./data/',train=True, download=True, transform=transforms.ToTensor())
    loader_train = DataLoader(train_dataset, batch_size=32, shuffle=True)

    test_dataset = MNIST(root='./data/', train=False, download=True, transform=transforms.ToTensor())
    loader_test = DataLoader(test_dataset, batch_size=100, shuffle=False)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'


    model = LeNet5().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

    epochs = 10

    for epoch in range(1, epochs + 1):
        train(model, device, loader_train, optimizer, epoch)
        test(model, device, loader_test)

    torch.save(model.state_dict(), "mnist.pkl")

if __name__ == '__main__':
    main()
```

###Results

```
Test set: Average loss: 0.0045, Accuracy: 9898/10000 (99%)
```

The code above trains the plain LeNet5, but since nothing is quantized at this stage, training with QLeNet5 instead should give roughly the same result.

##Quantizing the model
Using the trained weights, we check the accuracy of the quantized LeNet5 (QLeNet5).

```python:quant.py
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torch.quantization

from lenet5 import QLeNet5, LeNet5

def test(model, device, test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    test_loss = 0.0
    num_correct, num_samples = 0, len(test_loader.dataset)
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            scores = model(x)
            test_loss += criterion(scores, y).item()  # accumulate the summed loss
            _, preds = scores.max(1)
            num_correct += (preds == y).sum().item()

    test_loss /= num_samples  # average loss over the whole test set
    acc = float(num_correct) / num_samples

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, num_correct, num_samples, 100. * acc))

def print_size_of_model(model):
    # Serialize the state_dict to a temporary file to measure its on-disk size
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def main():
    test_dataset = MNIST(root='./data/', train=False, download=True, transform=transforms.ToTensor())
    loader_test = DataLoader(test_dataset, batch_size=100, shuffle=False)

    device = 'cpu'  # quantized operators run only on the CPU backend in PyTorch 1.3

    model = LeNet5().to(device)
    model.load_state_dict(torch.load('./mnist.pkl', map_location=device))

    print("Before Quantization")
    print_size_of_model(model)
    test(model, device, loader_test)

    qmodel = QLeNet5().to(device)
    qmodel.load_state_dict(torch.load('./mnist.pkl', map_location=device))
    qmodel.eval()

    qmodel.fuse_model()  # fuse Conv/Linear + ReLU pairs before quantization

    # Attach observers (default_qconfig: per-tensor MinMax observers) to each layer
    qmodel.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(qmodel, inplace=True)
    print("Prepare Model")
    # Running data through the prepared model calibrates the observers
    test(qmodel, device, loader_test)

    torch.quantization.convert(qmodel, inplace=True)  # swap each layer for its quantized version
    print("After Quantization")
    print_size_of_model(qmodel)
    test(qmodel, device, loader_test)

if __name__ == '__main__':
    main()
```

To quantize the model, fuse_model first fuses each Conv2d or Linear with the ReLU that follows it.
Next, qconfig determines which Observer to attach to each layer, and prepare attaches them. Passing data through the model once then gathers the statistics needed to compute each layer's scale and zero_point.
Finally, convert replaces each layer with its quantized counterpart.
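
As an aside, the qconfig is swappable. A minimal sketch, assuming an x86 machine with the fbgemm backend available, that uses per-channel weight quantization instead of the default per-tensor scheme:

```python
# Assumption: x86 with the fbgemm backend available
qmodel.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(qmodel, inplace=True)
test(qmodel, device, loader_test)  # calibration pass, as above
torch.quantization.convert(qmodel, inplace=True)
```
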
###Results

```
Before Quantization
Size (MB): 0.248675
Test set: Average loss: 0.0045, Accuracy: 9898/10000 (99%)

Prepare Model
Test set: Average loss: 0.0045, Accuracy: 9898/10000 (99%)

After Quantization
Size (MB): 0.065886
Test set: Average loss: 0.0043, Accuracy: 9895/10000 (99%)
```

Accuracy barely drops, while the model shrinks to about a quarter of its original size (0.249 MB to 0.066 MB, close to the ideal 4x you would expect going from float32 to int8 weights), so I would call the quantization a success.
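
If you are curious about the parameters the observers produced, printing a layer of the converted model shows its scale and zero_point (the exact values depend on your calibration run):

```python
print(qmodel.features[0])
# e.g. QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=..., zero_point=...)
```
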

#Wrapping up
I tried out the quantization in PyTorch 1.3; fusing the layers was the tedious part. My approach feels rather forced, so if there is a cleaner way to write it, please let me know.
It is still an experimental feature, so I suppose we can look forward to how it matures.
