有名どころのモデルの実装と学習済みの重みを1行で取得できるtorchvision.models
モジュールは大変便利で、特に転移学習の際に重宝します。
上記チュートリアルではバックボーンネットワークとしてResNet18を使用しています。pretrained
をTrueにすることで、自動で学習済みの重みがDLされ、そのまま使うもよし、最終層を弄って転移学習に用いるもよしです。
from torchvision import models
model_ft = models.resnet18(pretrained=True)
ただし、torchvisionのバージョンが0.13以降の場合、以下のような警告が出るようになりました。
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
pretained
引数は0.13以降、非推奨になっており、0.15で削除される予定だそうです。
多くのサイトでは上記のコードで解説されていますので、そのうちアップデートが必要になるでしょう。新しいクラス仕様をチェックしたいと思います。
まずは修正
親切な警告文の通りソースを修正しましょう。
アップデートに伴い、DLできる重みも細分化されました。
pretrainedと同じ結果を得たい場合は、resnet18
の引数にweights=ResNet18_Weights.IMAGENET1K_V1
またはweights=ResNet18_Weights.DEFAULT
を与えます。
非推奨の引数であるpretrained
を用いている場合、互換用のデコレータを介し、weights=ResNet18_Weights.IMAGENET1K_V1
を引数にとったresnet18
がインスタンス化されていました。
今後、この互換用のデコレータが除去されるハズですので、古い記述を見たら、以下の新しい書き方に脳内変換しましょう。
from torchvision import models
# pretrained=Trueと同等のモデルを呼び出したい!
# 1. 明示的に重みを指定
model_ft = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# 2. DEFAULTでもIMAGENET1K_V1を指す
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
これで警告が消えました。
名称からもわかる通り、重みの中身はこれまで(~0.13)と同じくImageNetを用いたモノです。
定義済みの前処理(transforms)を利用する
学習済みモデルの利用に関する機能変更は、警告文で示されている内容以外にも存在します。
ResNet18_Weights.IMAGENET1K_V1.transforms
で推論時に用いるtransformを呼び出すことができます。学習時の実装をのぞくことなく、学習時と同様の前処理を簡単かつ正確に行うことができるようになりました。
from PIL import Image
trnsfrm = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
img = Image.new("RGB", (320, 240), (0, 128, 255))
tnsr = trnsfrm(img)
その他の学習済みモデル
全てのモデルを確認してはいませんが、他の学習済みモデルでも同様の仕様になっているようです。
例えば、転移学習でよく使用されるVGG16ではpretrained
に代わり、以下の引数が用意されています。
- VGG16_Weights.DEFAULT: デフォルトを指定した場合は、
VGG16_Weights.IMAGENET1K_V1
を指します。 - VGG16_Weights.IMAGENET1K_V1: ImageNetを用いてトレーニングされた重みパラメータ
- VGG16_Weights.IMAGENET1K_FEATURES: classifierモジュール(全結合部分)の重みパラメータが無効値(nan)の重みパラメータ。分類問題には使用できません。
VGG16_Weights.IMAGENET1K_FEATURESは全結合部分の重みパラメータが存在しません。名前からしても、特徴量マップのみ使用したい場合に使うのでしょうか。
その他のモデルついては、使用するモデルのドキュメントを参照し、正確な引数を調べてみてください。