ResNetの出力クラス数を変更したい
有名なResNetですが、Pytorchで簡単に使うことができます。しかし、出力次元数を変えるのに少し調べる必要があったので、備忘録としてまとめておきます。
PytorchのResNetのドキュメントを確認する
多くの人がPytorchのResNetのドキュメントを確認すると思いますが、ここには、出力のクラス数を変更できそうな記述がありません。そのため、途方に暮れてしまうわけですが、GitHubのソースコードへのリンクがあるので、このソースコードをもとに、ResNetの出力数を変更していきます。
出力クラス数を変更したモデルを作成する
上のソースコードを解析した結果以下のコードで、出力クラス数を変更したモデルを作成できました。何か間違いあったら教えてください!!(研究でも使っているので...)
from torchvision.models.resnet import ResNet,Bottleneck,BasicBlock
resnet = "resnet18"
if resnet=="resnet18":
network = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
num_classes=100)
elif resnet=="resnet34":
network = ResNet(
block=BasicBlock,
layers=[3, 4, 6, 3],
num_classes=100)
elif resnet=="resnet50":
network = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
num_classes=100)
elif resnet=="resnet101":
network = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
num_classes=100)
elif resnet=="resnet152":
network = ResNet(
block=Bottleneck,
layers=[3, 8, 36, 3],
num_classes=100)
else:
raise Exception("resnetの指定が間違っています")