5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

torchvision.modelsのpretrainedが非推奨になっていた

Last updated at Posted at 2022-11-20

有名どころのモデルの実装と学習済みの重みを1行で取得できるtorchvision.modelsモジュールは大変便利で、特に転移学習の際に重宝します。

上記チュートリアルではバックボーンネットワークとしてResNet18を使用しています。pretrainedをTrueにすることで、自動で学習済みの重みがDLされ、そのまま使うもよし、最終層を弄って転移学習に用いるもよしです。

from torchvision import models
model_ft = models.resnet18(pretrained=True)

ただし、torchvisionのバージョンが0.13以降の場合、以下のような警告が出るようになりました。

/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)

pretained引数は0.13以降、非推奨になっており、0.15で削除される予定だそうです。
多くのサイトでは上記のコードで解説されていますので、そのうちアップデートが必要になるでしょう。新しいクラス仕様をチェックしたいと思います。

まずは修正

親切な警告文の通りソースを修正しましょう。
アップデートに伴い、DLできる重みも細分化されました。

pretrainedと同じ結果を得たい場合は、resnet18の引数にweights=ResNet18_Weights.IMAGENET1K_V1またはweights=ResNet18_Weights.DEFAULTを与えます。

非推奨の引数であるpretrainedを用いている場合、互換用のデコレータを介し、weights=ResNet18_Weights.IMAGENET1K_V1を引数にとったresnet18がインスタンス化されていました。
今後、この互換用のデコレータが除去されるハズですので、古い記述を見たら、以下の新しい書き方に脳内変換しましょう。

from torchvision import models
# pretrained=Trueと同等のモデルを呼び出したい!
# 1. 明示的に重みを指定
model_ft = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# 2. DEFAULTでもIMAGENET1K_V1を指す
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

これで警告が消えました。
名称からもわかる通り、重みの中身はこれまで(~0.13)と同じくImageNetを用いたモノです。

定義済みの前処理(transforms)を利用する

学習済みモデルの利用に関する機能変更は、警告文で示されている内容以外にも存在します。

ResNet18_Weights.IMAGENET1K_V1.transformsで推論時に用いるtransformを呼び出すことができます。学習時の実装をのぞくことなく、学習時と同様の前処理を簡単かつ正確に行うことができるようになりました。

from PIL import Image

trnsfrm = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
img = Image.new("RGB", (320, 240), (0, 128, 255))
tnsr = trnsfrm(img)

その他の学習済みモデル

全てのモデルを確認してはいませんが、他の学習済みモデルでも同様の仕様になっているようです。

例えば、転移学習でよく使用されるVGG16ではpretrainedに代わり、以下の引数が用意されています。

  • VGG16_Weights.DEFAULT: デフォルトを指定した場合は、VGG16_Weights.IMAGENET1K_V1を指します。
  • VGG16_Weights.IMAGENET1K_V1: ImageNetを用いてトレーニングされた重みパラメータ
  • VGG16_Weights.IMAGENET1K_FEATURES: classifierモジュール(全結合部分)の重みパラメータが無効値(nan)の重みパラメータ。分類問題には使用できません。

VGG16_Weights.IMAGENET1K_FEATURESは全結合部分の重みパラメータが存在しません。名前からしても、特徴量マップのみ使用したい場合に使うのでしょうか。

その他のモデルついては、使用するモデルのドキュメントを参照し、正確な引数を調べてみてください。

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?