はじめに
PyTorchの2つの機械学習モデルの重みを平均してみました。
モデルのstate_dict
を取得することで、パラメータに直接計算を加えることができます。
環境
- Python 3.9.5
- torch 1.9.0+cu111
- torchvision 0.10.0+cu111
準備
モデル同士の足し算と、モデルの定数倍を行う関数を定義。
モデル同士の足し算
sum_model.py
def sum_model_params(modelA, modelB):
""" modelA + modelB """
sdA = modelA.state_dict()
sdB = modelB.state_dict()
for key in sdA:
sdB[key] = (sdB[key] + sdA[key])
modelB.load_state_dict(sdB)
return modelB
モデルの定数倍
milti_model.py
def multi_model_params(model, a):
""" a * model """
sd = model.state_dict()
for key in sd:
sd[key] = sd[key] * a
model.load_state_dict(sd)
return model
2つのモデルの平均
average_model.py
# C = AとBの平均
C = multi_model_params(sum_model_params(A, B), 1/2)
足し算と定数倍で、平均以外にもある程度の計算が可能になると思います。