timm
のMobileNet-V4というモデルを使用して、PyTorchでの枝刈り(プルーニング)を試してみたので、手順をまとめます。
枝刈りの対象となるモジュールの特定
枝刈りは主にweight
パラメータに対して行います。そのため、まずweight
パラメータを持つモジュール(層)を特定する必要があります。
weight
パラメータを持つモジュールの確認
モデル内の全モジュールからweight
属性を持つものを抽出し、種類と名前を表示する関数を実行します。
import torch
from torch import nn
import timm
from typing import Type, List, Tuple, Dict
import torch.nn.utils.prune as prune
def is_prunable_module(module: nn.Module) -> bool:
"""モジュールが枝刈り可能かどうかを判定"""
return (hasattr(module, 'weight') and
module.weight is not None and
isinstance(module.weight, torch.Tensor))
def get_prunable_modules(model: nn.Module) -> List[Tuple[str, nn.Module]]:
"""モデル内の枝刈り可能なモジュールのリストを取得"""
prunable_modules = []
for name, module in model.named_modules():
if is_prunable_module(module):
prunable_modules.append((name, module))
return prunable_modules
def get_prunable_module_dict(model: nn.Module) -> Dict[Type[nn.Module], List[Tuple[str, nn.Module]]]:
"""モデル内の枝刈り可能なモジュールを辞書形式で取得"""
prunable_modules = get_prunable_modules(model)
unique_modules = {}
for name, module in prunable_modules:
module_type = type(module)
if module_type not in unique_modules:
unique_modules[module_type] = []
unique_modules[module_type].append((name, module))
return unique_modules
def print_prunable_module_types(model: nn.Module):
unique_modules = get_prunable_module_dict(model)
print("Prunable modules:")
for module_type, instances in unique_modules.items():
print(f" {module_type.__name__}:")
for name, _ in instances:
print(f" - {name}")
def main():
model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
# 枝刈り可能なモジュールの型と名前を表示
print_prunable_module_types(model)
if __name__ == '__main__':
main()
# 実行結果
# Prunable modules:
# Conv2d:
# - conv_stem
# .....
# BatchNormAct2d:
# - bn1
# .....
# Linear:
# - classifier
実行結果から、このモデルではConv2d
, BatchNormAct2d
, Linear
モジュールがweight
パラメータを持つことがわかりました。
枝刈りの対象とするモジュールの選択
全てのweight
パラメータを持つモジュールが枝刈りに適しているわけではありません。以下の点を考慮して対象を絞り込みます:
- 多次元のテンソルを持つモジュール(例:
Conv2d
,Linear
)を主な対象とします - パラメータ数が少ないモジュール(例:
BatchNormAct2d
)は除外します
この選択基準は、各モジュールのweightパラメータの次元数や、一般的な枝刈りの慣行に基づいています。
これを踏まえて、今回はConv2d
とLinear
を枝刈りの対象としました。
枝刈り実行
今回はl1_unstructuredとglobal_unstructuredの2つのメソッドを試しました。
PyTorchの枝刈りメソッドであるl1_unstructured
とglobal_unstructured
の主な違いは、枝刈りの適用範囲にあります。
l1_unstructured
:
- 各モジュールごとに個別に枝刈りを適用します
- 指定した割合(
amount
)に基づいて、各モジュール内のパラメータを削減します
global_unstructured
:
- モデル全体を対象に枝刈りを適用します
- 指定した割合に基づいて、モデル全体のパラメータを削減します
- 枝刈りの方法(例:L1ノルムに基づく削減)は別途指定する必要があります
global_unstructured
では、異なる枝刈り方法を指定できますが、この例ではl1_unstructured
と同様にL1ノルムに基づいて枝刈りを行うprune.L1Unstructured
を使用しています。
prune.l1_unstructured
def has_valid_bias(module):
if not hasattr(module, 'bias'): return False
return True if module.bias is not None else False
def get_pruning_module_types(model):
"""枝刈りの対象にするモジュールの型を取得"""
types_to_remove = [
timm.layers.norm_act.BatchNormAct2d,
]
prune_types = [t for t in get_prunable_module_dict(model).keys() if t not in types_to_remove]
return prune_types
def apply_pruning_to_each_layer(model, method, amount):
"""モジュール毎に枝刈りを実行"""
prune_types = get_pruning_module_types(model)
for _, module in model.named_modules():
if not isinstance(module, tuple(prune_types)): continue
method(module, name='weight', amount=amount)
if not has_valid_bias(module): continue
method(module, name='bias', amount= amount)
def print_sparsity(model):
"""枝刈り対象にしたモジュールのパラメータに含まれる0の割合"""
prune_types = get_pruning_module_types(model)
print("=====result check=====")
for name, module in model.named_modules():
if not isinstance(module, tuple(prune_types)): continue
print(f"Sparsity in {name}.weight: "
f"{100. * float(torch.sum(module.weight == 0)) / float(module.weight.nelement()):.2f}%")
if not has_valid_bias(module): continue
print(f"Sparsity in {name}.bias: "
f"{100. * float(torch.sum(module.bias == 0)) / float(module.bias.nelement()):.2f}%")
def main():
model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
# print_prunable_module_types(model)
# モデルの各モジュールに対してl1_unstructured実行
apply_pruning_to_each_layer(model=model, method=prune.l1_unstructured, amount=0.3)
# 枝刈り結果確認
print_sparsity(model)
# =====result check=====
# Sparsity in conv_stem.weight: 29.98%
# Sparsity in blocks.0.0.conv.weight: 30.00%
# .....
実行結果から、各モジュールのweight
パラメータの0
の割合が30%
になっていることが確認できます。
prune.global_unstructured
def get_parameters_to_prune(model):
result = []
prune_types = get_pruning_module_types(model)
parameters_to_prune = [module for _, module in model.named_modules() if isinstance(module, tuple(prune_types))]
for module in parameters_to_prune:
result.append((module, "weight"))
if has_valid_bias(module):
result.append((module, "bias"))
return tuple(result)
def apply_pruning_to_whole_model(model, method, amount):
parameters_to_prune = get_parameters_to_prune(model)
prune.global_unstructured(
parameters_to_prune,
pruning_method=method,
amount=amount,
)
def main():
model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
apply_pruning_to_whole_model(model, method=prune.L1Unstructured, amount=0.3)
print_sparsity(model)
# =====result check=====
# Sparsity in conv_stem.weight: 37.73%
# Sparsity in blocks.0.0.conv.weight: 44.84%
# .....
実行結果を見ると、l1_unstructured
の場合とは異なり、各モジュールの枝刈り割合が異なっていることが確認できます。
パラメータを固定する
今回は省きますが、枝刈り後の学習によってパラメータの調整を終え、精度と速さに納得がいった後prune.remove()
を実行することで、パラメータを固定します。
def fix_pruning_result(model):
prune_types = get_pruning_module_types(model)
for _, module in model.named_modules():
if not isinstance(module, tuple(prune_types)): continue
prune.remove(module, 'weight')
if not has_valid_bias(module): continue
prune.remove(module, 'bias')
def main():
model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
apply_pruning_to_whole_model(model, method=prune.L1Unstructured, amount=0.3)
# remove実行
fix_pruning_result(model)
以上がPyTorchでの枝刈りの基本的な流れです。
あとは枝刈りの手法や割合などを状況に合わせて調整してみてください。
補足
枝刈り後の変化
PyTorchの枝刈りを適用すると、モデルに以下の変更が加えられます。
新しいプロパティの追加:
-
weight_mask
: パラメーターの使用可否を指定する0
または1
のTensor -
weight_orig
: 元のモデルパラメータを保持するTensor
新しいメソッドの追加:
prune()
: 上記のプロパティと枝刈りフックを削除するメソッド
_forward_pre_hooks
への追加:
weight_mask
をweight_orig
に適用し、枝刈り後のパラメータを計算するフックが_forward_pre_hooks
に追加されます。
_forward_pre_hooks
は、モジュールのforward()
メソッド呼び出し直前に実行されるフックが登録されているOrderedDict
で、module._forward_pre_hooks
で中身を確認することができます。
prune.remove()
の効果
prune.remove()
を実行すると以下の変更が行われます。
-
weight_orig
とweight_mask
プロパティの削除 -
_forward_pre_hooks
から枝刈りフックの削除 - モジュールのパラメータを枝刈り適用後の値に置き換え
これにより、forward()
メソッド実行時の追加計算が省略され、モデルの実行速度が向上する可能性があります。
注意点:
- 枝刈りの効果は、モデルの構造やサイズ、適用方法によって異なります
- 実行速度の向上は、枝刈りの程度や対象レイヤーによって変わります
- 過度の枝刈りはモデルの性能低下を招く可能性があるため、適切なバランスが必要です