LoginSignup
3
3

More than 3 years have passed since last update.

ディープラーニングのモデル軽量化ライブラリDistiller

Last updated at Posted at 2020-05-02

はじめに

※このライブラリはGPUの使用を前提としているのでご注意ください※
主はディープラーニングに関してド素人ですので, 誤りがある場合は優しく教えて頂けると幸いです。

Distillerとは

Distillerとは, DeepLearningのモデルを軽量化するアルゴリズムを備えたintelがPyTorchベースで作成したライブラリです。モデルの軽量化の主な例としては, 量子化(Quantization), 枝刈り(Pruning), 蒸留(Distillation)など様々なものがあり, これらを簡単に使えるのがDistillerです。
さらに, チュートリアルではTensorBoardと連帯して学習の状況を確認できる機能までついていた(感謝感激)

モデル軽量化について詳しく書かれたサイトがこちら
https://laboro.ai/column/%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0%E3%82%92%E8%BB%BD%E9%87%8F%E5%8C%96%E3%81%99%E3%82%8B%E3%83%A2%E3%83%87%E3%83%AB%E5%9C%A7%E7%B8%AE/

環境開発

$ git clone https://github.com/NervanaSystems/distiller.git
$ cd distiller
$ pip install -r requirements.txt
$ pip install -e .
$ python
>>> import distiller
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/mnt/PytorchIntro/distiller/distiller/__init__.py", line 20, in <module>
    from .config import file_config, dict_config, config_component_from_file_by_class

...

  File "/root/local/python-3.7.1/lib/python3.7/site-packages/git/exc.py", line 9, in <module>
    from git.compat import UnicodeMixin, safe_decode, string_types
  File "/root/local/python-3.7.1/lib/python3.7/site-packages/git/compat.py", line 16, in <module>
    from gitdb.utils.compat import (
ModuleNotFoundError: No module named 'gitdb.utils.compat'

自分の場合はライブラリに追加したDistillerをインポートしようとしたらgitのライブラリ関連でエラーがでたので, 悪さをしていたgitdb2をダウングレードしたら直りました。(自分のインストールしたバージョンは4.0.2)

$ pip uninstall gitdb2
$ pip install gitdb2==2.0.6

いざ確認

$ cd distiller/examples/classifier_compression/
$ python3 compress_classifier.py --arch simplenet_cifar ../../../data.cifar10 -p 30 -j=1 --lr=0.01

--------------------------------------------------------
Logging to TensorBoard - remember to execute the server:
> tensorboard --logdir='./logs'

=> created a simplenet_cifar model with the cifar10 dataset
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../../data.cifar10/cifar-10-python.tar.gz
 99%|█████████████████████████████████████████████████████████████████████████████▌| 169582592/170498071 [00:18<00:00, 11451969.71it/s]Extracting ../../../data.cifar10/cifar-10-python.tar.gz to ../../../data.cifar10
Files already downloaded and verified
Dataset sizes:
        training=45000
        validation=5000
        test=10000


Training epoch: 45000 samples (256 per mini-batch)
170500096it [00:30, 11451969.71it/s]                                                                                                   Epoch: [0][   30/  176]    Overall Loss 2.303411    Objective Loss 2.303411    Top1 10.299479    Top5 50.104167    LR 0.010000    Time 0.038285
Epoch: [0][   60/  176]    Overall Loss 2.301507    Objective Loss 2.301507    Top1 10.774740    Top5 51.328125    LR 0.010000    Time 0.037495
Epoch: [0][   90/  176]    Overall Loss 2.299031    Objective Loss 2.299031    Top1 12.335069    Top5 54.973958    LR 0.010000    Time 0.037465
Epoch: [0][  120/  176]    Overall Loss 2.293749    Objective Loss 2.293749    Top1 13.424479    Top5 57.542318    LR 0.010000    Time 0.037429
Epoch: [0][  150/  176]    Overall Loss 2.278429    Objective Loss 2.278429    Top1 14.692708    Top5 59.864583    LR 0.010000    Time 0.037407

Parameters:
+----+---------------------+---------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
|    | Name                | Shape         |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
|----+---------------------+---------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
|  0 | module.conv1.weight | (6, 3, 5, 5)  |           450 |            450 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07800 | -0.01404 |    0.06724 |
|  1 | module.conv2.weight | (16, 6, 5, 5) |          2400 |           2400 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04952 |  0.00678 |    0.04246 |
|  2 | module.fc1.weight   | (120, 400)    |         48000 |          48000 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02906 |  0.00082 |    0.02511 |
|  3 | module.fc2.weight   | (84, 120)     |         10080 |          10080 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05328 |  0.00084 |    0.04607 |
|  4 | module.fc3.weight   | (10, 84)      |           840 |            840 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.06967 | -0.00275 |    0.06040 |
|  5 | Total sparsity:     | -             |         61770 |          61770 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00000 |  0.00000 |    0.00000 |
+----+---------------------+---------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
Total sparsity: 0.00

--- validate (epoch=0)-----------
5000 samples (256 per mini-batch)
==> Top1: 25.240    Top5: 75.520    Loss: 2.060

==> Best [Top1: 25.240   Top5: 75.520   Sparsity:0.00   NNZ-Params: 61770 on epoch: 0]
Saving checkpoint to: logs/2020.05.02-235616/checkpoint.pth.tar

...

とりあえず, 動いたので一安心 ε-(´∀`*)ホッ
何か発見があり次第追記していこうと思います。

参考サイト

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3