6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pytorchで枝刈りを行う

Last updated at Posted at 2024-11-13

timmMobileNet-V4というモデルを使用して、PyTorchでの枝刈り(プルーニング)を試してみたので、手順をまとめます。

枝刈りの対象となるモジュールの特定

枝刈りは主にweightパラメータに対して行います。そのため、まずweightパラメータを持つモジュール(層)を特定する必要があります。

weightパラメータを持つモジュールの確認

モデル内の全モジュールからweight属性を持つものを抽出し、種類と名前を表示する関数を実行します。

pruning.py
import torch
from torch import nn
import timm
from typing import Type, List, Tuple, Dict
import torch.nn.utils.prune as prune


def is_prunable_module(module: nn.Module) -> bool:
    """モジュールが枝刈り可能かどうかを判定"""
    return (hasattr(module, 'weight') and 
            module.weight is not None and 
            isinstance(module.weight, torch.Tensor))


def get_prunable_modules(model: nn.Module) -> List[Tuple[str, nn.Module]]:
    """モデル内の枝刈り可能なモジュールのリストを取得"""
    prunable_modules = []
    for name, module in model.named_modules():
        if is_prunable_module(module):
            prunable_modules.append((name, module))
    return prunable_modules


def get_prunable_module_dict(model: nn.Module) -> Dict[Type[nn.Module], List[Tuple[str, nn.Module]]]:
    """モデル内の枝刈り可能なモジュールを辞書形式で取得"""
    prunable_modules = get_prunable_modules(model)
    unique_modules = {}
    
    for name, module in prunable_modules:
        module_type = type(module)
        if module_type not in unique_modules:
            unique_modules[module_type] = []
        unique_modules[module_type].append((name, module))
    return unique_modules


def print_prunable_module_types(model: nn.Module):
    unique_modules = get_prunable_module_dict(model)
    
    print("Prunable modules:")
    for module_type, instances in unique_modules.items():
        print(f"  {module_type.__name__}:")
        for name, _ in instances:
            print(f"    - {name}")


def main():
    model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
    # 枝刈り可能なモジュールの型と名前を表示
    print_prunable_module_types(model)


if __name__ == '__main__':
    main()

# 実行結果
# Prunable modules:
#   Conv2d:
#     - conv_stem
#     .....
#   BatchNormAct2d:
#     - bn1
#     .....
#   Linear:
#     - classifier

実行結果から、このモデルではConv2d, BatchNormAct2d, Linearモジュールがweightパラメータを持つことがわかりました。

枝刈りの対象とするモジュールの選択

全てのweightパラメータを持つモジュールが枝刈りに適しているわけではありません。以下の点を考慮して対象を絞り込みます:

  • 多次元のテンソルを持つモジュール(例:Conv2d, Linear)を主な対象とします
  • パラメータ数が少ないモジュール(例:BatchNormAct2d)は除外します

この選択基準は、各モジュールのweightパラメータの次元数や、一般的な枝刈りの慣行に基づいています。

これを踏まえて、今回はConv2dLinearを枝刈りの対象としました。

枝刈り実行

今回はl1_unstructuredglobal_unstructuredの2つのメソッドを試しました。

PyTorchの枝刈りメソッドであるl1_unstructuredglobal_unstructuredの主な違いは、枝刈りの適用範囲にあります。

l1_unstructured:

  • 各モジュールごとに個別に枝刈りを適用します
  • 指定した割合(amount)に基づいて、各モジュール内のパラメータを削減します

global_unstructured:

  • モデル全体を対象に枝刈りを適用します
  • 指定した割合に基づいて、モデル全体のパラメータを削減します
  • 枝刈りの方法(例:L1ノルムに基づく削減)は別途指定する必要があります

global_unstructuredでは、異なる枝刈り方法を指定できますが、この例ではl1_unstructuredと同様にL1ノルムに基づいて枝刈りを行うprune.L1Unstructuredを使用しています。

prune.l1_unstructured

pruning.py
def has_valid_bias(module):
    if not hasattr(module, 'bias'): return False
    return True if module.bias is not None else False


def get_pruning_module_types(model):
    """枝刈りの対象にするモジュールの型を取得"""
    types_to_remove = [
        timm.layers.norm_act.BatchNormAct2d,
    ]
    prune_types = [t for t in get_prunable_module_dict(model).keys() if t not in types_to_remove]
    return prune_types


def apply_pruning_to_each_layer(model, method, amount):
    """モジュール毎に枝刈りを実行"""
    prune_types = get_pruning_module_types(model)
    for _, module in model.named_modules():
        if not isinstance(module, tuple(prune_types)): continue
        method(module, name='weight', amount=amount)

        if not has_valid_bias(module): continue
        method(module, name='bias', amount= amount)


def print_sparsity(model):
    """枝刈り対象にしたモジュールのパラメータに含まれる0の割合"""
    prune_types = get_pruning_module_types(model)
    print("=====result check=====")
    for name, module in model.named_modules():
        if not isinstance(module, tuple(prune_types)): continue
        print(f"Sparsity in {name}.weight: "
                  f"{100. * float(torch.sum(module.weight == 0)) / float(module.weight.nelement()):.2f}%")
        if not has_valid_bias(module): continue
        print(f"Sparsity in {name}.bias: "
              f"{100. * float(torch.sum(module.bias == 0)) / float(module.bias.nelement()):.2f}%")


def main():
    model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
    # print_prunable_module_types(model)
    # モデルの各モジュールに対してl1_unstructured実行
    apply_pruning_to_each_layer(model=model, method=prune.l1_unstructured, amount=0.3)
    # 枝刈り結果確認
    print_sparsity(model)

# =====result check=====
# Sparsity in conv_stem.weight: 29.98%
# Sparsity in blocks.0.0.conv.weight: 30.00%
# .....

実行結果から、各モジュールのweightパラメータの0の割合が30%になっていることが確認できます。

prune.global_unstructured

pruning.py
def get_parameters_to_prune(model):
    result = []
    prune_types = get_pruning_module_types(model)
    parameters_to_prune = [module for _, module in model.named_modules() if isinstance(module, tuple(prune_types))]

    for module in parameters_to_prune:
        result.append((module, "weight"))
        if has_valid_bias(module):
            result.append((module, "bias"))
    return tuple(result)


def apply_pruning_to_whole_model(model, method, amount):
    parameters_to_prune = get_parameters_to_prune(model)
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=method,
        amount=amount,
    )

def main():
    model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
    apply_pruning_to_whole_model(model, method=prune.L1Unstructured, amount=0.3)
    print_sparsity(model)

# =====result check=====
# Sparsity in conv_stem.weight: 37.73%
# Sparsity in blocks.0.0.conv.weight: 44.84%
# .....

実行結果を見ると、l1_unstructuredの場合とは異なり、各モジュールの枝刈り割合が異なっていることが確認できます。

パラメータを固定する

今回は省きますが、枝刈り後の学習によってパラメータの調整を終え、精度と速さに納得がいった後prune.remove()を実行することで、パラメータを固定します。

pruning.py
def fix_pruning_result(model):
    prune_types = get_pruning_module_types(model)
    for _, module in model.named_modules():
        if not isinstance(module, tuple(prune_types)): continue
        prune.remove(module, 'weight')
        if not has_valid_bias(module): continue
        prune.remove(module, 'bias')

def main():
    model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True, num_classes=3)
    apply_pruning_to_whole_model(model, method=prune.L1Unstructured, amount=0.3)
    # remove実行
    fix_pruning_result(model)

以上がPyTorchでの枝刈りの基本的な流れです。
あとは枝刈りの手法や割合などを状況に合わせて調整してみてください。

補足

枝刈り後の変化

PyTorchの枝刈りを適用すると、モデルに以下の変更が加えられます。

新しいプロパティの追加:

  • weight_mask: パラメーターの使用可否を指定する0または1のTensor
  • weight_orig: 元のモデルパラメータを保持するTensor

新しいメソッドの追加:

prune(): 上記のプロパティと枝刈りフックを削除するメソッド

_forward_pre_hooksへの追加:

weight_maskweight_origに適用し、枝刈り後のパラメータを計算するフックが_forward_pre_hooksに追加されます。

_forward_pre_hooksは、モジュールのforward()メソッド呼び出し直前に実行されるフックが登録されているOrderedDictで、module._forward_pre_hooksで中身を確認することができます。

prune.remove()の効果

prune.remove()を実行すると以下の変更が行われます。

  • weight_origweight_maskプロパティの削除
  • _forward_pre_hooksから枝刈りフックの削除
  • モジュールのパラメータを枝刈り適用後の値に置き換え

これにより、forward()メソッド実行時の追加計算が省略され、モデルの実行速度が向上する可能性があります。

注意点:

  • 枝刈りの効果は、モデルの構造やサイズ、適用方法によって異なります
  • 実行速度の向上は、枝刈りの程度や対象レイヤーによって変わります
  • 過度の枝刈りはモデルの性能低下を招く可能性があるため、適切なバランスが必要です
6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?