LoginSignup
22
13

More than 5 years have passed since last update.

Chainer向けChannel Pruningフレームワークを作った話

Posted at

ChaienrをつかったDeep Learningモデルの軽量化・高速化に関する話です。

Chainer向けChannel PruningフレームワークChainerPrunerをOSSとして公開しました。

また同OSSについて、本日2018/12/15にChainer Meetup #08にて「Chainer向けChannel Pruningフレームワークの設計と実装」と題して話させてもらいました。ありがとうございました。

この記事では、開発の背景やコンセプト、簡単な使い方をまとめます。開発段階のため、実装の詳細はもっと良いものに変化する可能性があります。また、API Referenceも中途半端なので、近日中にUpdate予定です。

Channel Pruningとは

Deep Learningモデルを軽量化するためのPruningの中の1手法です。Pruningは出力に寄与しないWeightを除去することでストレージ及びメモリの使用量の削減と計算量の削減(以降、便宜上軽量化・高速化と言い換えます)を行う手法です。

Weightの削減の仕方に応じて様々なPruningがありますが、Channel Pruningは、Convolution層の出力チャネルを減らすようにWeightを削減する方法です。Weightのshapeを小さくできるため、CPUやGPUなどの汎用的なハードウェアにおいても軽量化・高速化を実現できます。

有名な手法だと、ICCV2017のLearning Efficient Convolutional Networks through Network Slimmingがありますが、継続的に論文はでており、ECCV2018のAMC: AutoML for Model Compression and Acceleration on Mobile DevicesやICLR2019(under review)ではRethinking the Value of Network Pruningのような論文も出ており、大変興味深いです。より軽量で精度が良く、学習の制御がしやすくなってきているように思います。

Channel Pruningの学習パイプライン

Channel Pruningは通常のTraining Loopの中に、以下のように適用されます。

例えば先に紹介したLearning Efficient Convolutional Networks through Network Slimmingでは、モデルを初期化→学習→Pruning→再学習→Pruning→再学習…を繰り返します。著者実装の図がわかりやすいです。

また、AMC: AutoML for Model Compression and Acceleration on Mobile Devicesでは、学習済みモデルをどうPruningするか、つまり各層のChannelの削減率をどう設定するか、を強化学習で探索し、一番精度が保てる削減率を用いて、モデルをスクラッチで学習します。このように、これはアルゴリズムによって様々です。

ChainerPrunerのコンセプト

スライドでも説明させていただきましたが、ChainerPrunerは、Channel Pruningを実装するためのコアライブラリを提供しており、論文のアルゴリズムなどが簡単に実装できることを目指しています。

それは、私自身がChannel Pruningのアルゴリズムを追実装する際の以下の課題感からでした。

  1. Pruning前後でモデルの定義を修正する必要がある
  2. Pruningするために、レイヤー感の依存関係を管理しなければならない

それぞれ説明します。

[課題1]Pruning前後でモデルの定義を修正する必要がある

Channel Pruningでは、学習前後でConvolution層の出力チャネル数が変化します。そのため、学習後のコンパクトなモデルを読み込むために、モデル定義を修正する必要がありました。

[課題2]Pruningするために、レイヤー感の依存関係を管理しなければならない

Weightをコンパクトにする過程で生じる問題ですが、例えば以下の結合を考えます。

(1)Conv-(2)BN-(3)ReLU-(4)Conv-(5)BN-(6)FC

(1)の層をPruningして出力チャネル数を変更する場合、(2)と(4)のWeightも調整が必要です。(2)の入出力チャネル数と(4)の入力チャネル数が(1)の出力チャネル数に依存するからです。これを実現するために、モデルごとにPruningの実装をするのは面倒です。

ChainerPrunerでの解決方法

現状、以下のように解決しています。

  • PruningしたWeightのサイズに合わせてモデル定義を変える (実装の該当箇所)
  • 計算グラフを解析し、自動でPruningのためのレイヤー間の依存関係を解決 (実装の該当箇所)

特に計算グラフを解析する際には、Function HookやChainer v5で導入されたLinkHookを用いて実装しています。これらを活用すればグラフ情報を構築できるので便利。

これによってユーザーは、学習時はモデルとPruningのためのハイパーパラメーター(各層のPruning率など)だけ指定すればよくなり、推論時は元のモデル定義とPruning済のWeightのみ持っていればよく、実験から運用までが楽になります。

このアプローチは、動的にPruningするのでとりあえずDynamic Pruningと呼ぶことにします。このDynamic Pruningについて説明するnotebookがあるので参考にして頂けたら幸いです。

Mask, Rebuild

Pruningの実装は、Mask, Rebuldの2stepで行うようにしています。

  • Mask
    • Weightにたいして、しきい値などでPruning対象のチャネルをall zeroにする (実装の該当箇所)
  • Rebuild
    • Maskでall zeroにした部分を除去し、レイヤー間チャネル数の依存関係の解決しながらWeightを再構築 (実装の該当箇所)

Maskは、Channel Pruningアルゴリズム毎にchainerpruner.Maskのサブクラスとして、実装する方針としており、現在は汎用的な、WeightのL1 or L2ノルムとWeightのしきい値 or チャネルの比率でPruning率を指定できる chainerpruner.masks.NormMaskを実装しています。

Rebuildは、層ごとにPruningするためのchainerpruner.rebuild.RebuildLinkクラスを実装しており、chainerpruner.rebuild.links以下に実装しています。

このI/Fをベースにして、対応する層やMaskのアルゴリズムを拡張していく想定です。

Channel Pruningアルゴリズムの実装例

論文のアルゴリズムなどは chainerpruner.pruning以下に実装中です。これらは、ここまでに説明したコアライブラリをベースに実装しています。

例えば、Learning Efficient Convolutional Networks through Network Slimmingについて実装したchainerpruner.pruning.network_slimmingについて見ていきましょう。
これを、独自のアルゴリズムを実装する際の参考にしていただければと思います(別途ドキュメント各予定ですが)。

まず、このアルゴリズムは以下の2stepで行われます。

  1. BatchNormalizationレイヤーのgamma(scale)をL1正則化(Lasso)し、gammaを0に近づけていく
  2. このgammaをConvolutionレイヤーの出力チャネルの重要度とし、gammaが小さいチャネルをPruningする

ChainerのTrainerを2回runするイメージです。サンプルの学習コードがあります。

1では、ChainerのOptimizerHookのLassoをforkし、BatchNormalizationにのみ正則化がかかるように修正して使っています。

2は、数行なのでここに転載します。元コードはここ

from chainerpruner import Graph
from chainerpruner.rebuild import rebuild
from chainerpruner.rebuild import calc_pruning_connection
from chainerpruner.masks import NormMask


def pruning(model, args, target_layers, threshold, default=None):
    # モデルをdefine-by-runしてグラフ構造を取得
    graph = Graph(model, args) 
    # target_layer (e.g. ['/conv1', '/conv2']のようなlink名のlist) について、
    # 大きさがthreshold以下のbatchnormのgammaについて、mask(=all zero)にするクラスをインスタンス
    mask = NormMask(model, graph, target_layers, threshold=threshold, default=default,
                    mask_layer='batchnorm')
    info = {}
    info['mask'] = mask() # 実際にマスク(all zero)にする
    info['rebuild'] = rebuild(model, graph, target_layers) # rebuild(all zeroを除外して再構築)
    return info

どのようにTraining Loopに組み込んでいるかは、examples/network_slimming/train_mnist.pyを参照ください。

他の例としてProgressive Deep Neural Networks Acceleration via Soft Filter Pruningの実装もあります。こちらはTrainerExtensionsとして実装していますので、興味のある方は見てください。

あとがき

ChainerPrunerというフレームワークを公開しました。簡単にChannel Pruningを実装・実験・運用することを目指しており、今後もupdate予定です。フィードバックいただけると嬉しいです。

ChainerPrunerを使った具体的な実験結果やパフォーマンスについては、また別途記事をかけたらと思います。

Deep Learningの事業応用を目的としたモデルの軽量化が着目されるこの頃、きっとこうしたライブラリは各社つらみを抱えながらも内製していると思われます。みんなで情報共有できたらいいですね!

Chainer Meetup、初参加ですが非常に楽しかったです。皆様ありがとうございました。

22
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
13