1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

YOLOv8モデルの重みを編集して検出結果の変化を調べる

Posted at

モチベーション

YOLOv8(detection)のスコアに繋がる重みを書き換えて検出結果がどう変化するかを調べる。

YOLOv8の構造

Detectモジュール内の Cls.Loss がクラスごとのスコアに関係していると推測

サイズnの場合 w=0.25, r=2.0 のため Detectの手前の層の形状は

  • P3: 80x80x256xw = 80x80x64
  • P4: 40x40x512xw = 40x40x128
  • P5: 20x20x512xwxr = 20x20x256

image.png

モデルの構造を表示してみる

>>> import ultralytics
>>> model = ultralytics.YOLO('yolov8n.pt')
>>> model
>>> model
YOLO(
  (model): DetectionModel(

 ~~~ 中略 ~~~
 
      (22): Detect(
        (cv2): ModuleList(
        ~~~ 中略 ~~~
        )
        (cv3): ModuleList(
          (0): Sequential(
            (0): Conv(
              (conv): Conv2d(64, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (act): SiLU(inplace=True)
            )
            (1): Conv(
              (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (act): SiLU(inplace=True)
            )
            (2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): Sequential(
            (0): Conv(
              (conv): Conv2d(128, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (act): SiLU(inplace=True)
            )
            (1): Conv(
              (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (act): SiLU(inplace=True)
            )
            (2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
          )
          (2): Sequential(
            (0): Conv(
              (conv): Conv2d(256, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (act): SiLU(inplace=True)
            )
            (1): Conv(
              (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (act): SiLU(inplace=True)
            )
            (2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (dfl): DFL(
          (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
  )
)


(22):Detect: の (cv3):の Sequential のブロック3個が順にP3,P4,P5 からのそれぞれの出力を受け取り 80クラスのスコアを出力する。

image.png

パラメータ名の確認

>>> for i,x in enumerate(model.named_parameters()):
...   i,x[0],x[1].shape
(0, 'model.model.0.conv.weight', torch.Size([16, 3, 3, 3]))
(1, 'model.model.0.bn.weight', torch.Size([16]))

~~~ 中略 ~~~

(159, 'model.model.22.cv3.0.0.conv.weight', torch.Size([80, 64, 3, 3]))
(160, 'model.model.22.cv3.0.0.bn.weight', torch.Size([80]))
(161, 'model.model.22.cv3.0.0.bn.bias', torch.Size([80]))
(162, 'model.model.22.cv3.0.1.conv.weight', torch.Size([80, 80, 3, 3]))
(163, 'model.model.22.cv3.0.1.bn.weight', torch.Size([80]))
(164, 'model.model.22.cv3.0.1.bn.bias', torch.Size([80]))
(165, 'model.model.22.cv3.0.2.weight', torch.Size([80, 80, 1, 1]))
(166, 'model.model.22.cv3.0.2.bias', torch.Size([80]))
(167, 'model.model.22.cv3.1.0.conv.weight', torch.Size([80, 128, 3, 3]))
(168, 'model.model.22.cv3.1.0.bn.weight', torch.Size([80]))
(169, 'model.model.22.cv3.1.0.bn.bias', torch.Size([80]))
(170, 'model.model.22.cv3.1.1.conv.weight', torch.Size([80, 80, 3, 3]))
(171, 'model.model.22.cv3.1.1.bn.weight', torch.Size([80]))
(172, 'model.model.22.cv3.1.1.bn.bias', torch.Size([80]))
(173, 'model.model.22.cv3.1.2.weight', torch.Size([80, 80, 1, 1]))
(174, 'model.model.22.cv3.1.2.bias', torch.Size([80]))
(175, 'model.model.22.cv3.2.0.conv.weight', torch.Size([80, 256, 3, 3]))
(176, 'model.model.22.cv3.2.0.bn.weight', torch.Size([80]))
(177, 'model.model.22.cv3.2.0.bn.bias', torch.Size([80]))
(178, 'model.model.22.cv3.2.1.conv.weight', torch.Size([80, 80, 3, 3]))
(179, 'model.model.22.cv3.2.1.bn.weight', torch.Size([80]))
(180, 'model.model.22.cv3.2.1.bn.bias', torch.Size([80]))
(181, 'model.model.22.cv3.2.2.weight', torch.Size([80, 80, 1, 1]))
(182, 'model.model.22.cv3.2.2.bias', torch.Size([80]))
(183, 'model.model.22.dfl.conv.weight', torch.Size([1, 16, 1, 1]))

アーキテクチャの図と照らし合わせて以下の重みを書き換えれば良いと推測。
ただしモデルの構造で見た時とは次元の順序が逆になっており、クラス数に対応する部分は最も左側になっている。

(165, 'model.model.22.cv3.0.2.weight', torch.Size([80, 80, 1, 1]))
(166, 'model.model.22.cv3.0.2.bias', torch.Size([80]))
(173, 'model.model.22.cv3.1.2.weight', torch.Size([80, 80, 1, 1]))
(174, 'model.model.22.cv3.1.2.bias', torch.Size([80]))
(181, 'model.model.22.cv3.2.2.weight', torch.Size([80, 80, 1, 1]))
(182, 'model.model.22.cv3.2.2.bias', torch.Size([80]))

重みの書き換え

クラスに直結する層は3個のSequentialブロックの最後 (2): Conv2D と推測。
P3,P4,P5からの違いを見る為に以下のように書き換える。

  • P3 --> (0):Sequential: 書き換えない
  • P4 --> (1):Sequential: 1ずつずらす
  • P5 --> (2):Sequential: 5ずつずらす
modify_weight.py
import ultralytics

model = ultralytics.YOLO('yolov8n.pt')

for name,param in model.named_parameters():
	# P4 -> (1):Sequential : 1ずつずらす
	if name=='model.model.22.cv3.1.2.weight':
		p2 = param.clone()
		param[1:80,:] = p2[0:(80-1),:]
		param[0:1,:] = p2[(80-1):80,:]
	elif name=='model.model.22.cv3.1.2.bias':
		p2 = param.clone()
		param[1:80] = p2[0:(80-1)]
		param[0:1] = p2[(80-1):80]
  
	# P5 -> (2):Sequential : 5ずつずらす
	elif name=='model.model.22.cv3.2.2.weight':
		p2 = param.clone()
		param[5:80,:] = p2[0:(80-5),:]
		param[0:5,:] = p2[(80-5):80,:]
	elif name=='model.model.22.cv3.2.2.bias':
		p2 = param.clone()
		param[5:80] = p2[0:(80-5)]
		param[0:5] = p2[(80-5):80]

model.save('yolov8n_shift.pt')

推論してみる

predict.py
import ultralytics

model = ultralytics.YOLO('yolov8n.pt')
model.predict('bird.jpg', save=True)

model2 = ultralytics.YOLO('yolov8n_shift.pt')
model2.predict('bird.jpg', save=True)

変更前の重みによる推論結果
9127680850_3ff5c884c6_z.jpg

変更後の重みによる推論結果
9127680850_3ff5c884c6_z.jpg

クラス番号と意味は以下の通り(抜粋)。

クラス番号 意味
14 'bird'
15 'cat'
19 'cow'
33 'kite'
34 'baseball bat'

上記から検出クラスの変化によりP3,P4,P5のどこで検出されたかがわかる。

検出クラスの変化 どこで検出したか
bird -> bird (変化なし) P3
bird -> cat (1ずれる) P4
bird -> cow (5ずれる) P5

検証環境

ultralytics         8.2.87
torch               2.4.0
torchvision         0.19.0

Python 3.9.17
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?