

【PyTorch】AttributeError: 'GeneralizedRCNNTransform' object has no attribute 'fixed_size'

Posted at 2021-11-27

TL;DR

If you train a model on a GPU machine, save it with torch.save, and then load it with torch.load on a CPU machine to run inference, this error can occur.

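Solution
Instead of calling torch.load on the whole saved model, instantiate the model from its class (or model builder function) and load the trained parameters into it with load_state_dict.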

Details

I tried object detection based on the [PyTorch] TORCHVISION OBJECT DETECTION FINETUNING TUTORIAL.

Environment

CPU side

  • anaconda3 Python 3.8.10
  • pytorch 1.10.0-py3.8_cpu_0
  • torchvision 0.11.1-py38_cpu
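The list above covers the CPU side only. When comparing the machine that saved the model with the one that loads it, it may help to record the same versions on both. A minimal sketch using only the standard library (`importlib.metadata`, available from Python 3.8); the helper name `environment_report` is hypothetical, and the distribution names `torch` and `torchvision` are assumptions about how the packages are registered in your environment.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("torch", "torchvision")):
    """Return a dict mapping 'python' and each package name to its version.

    Packages that are not installed map to None.
    """
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


if __name__ == "__main__":
    for name, ver in environment_report().items():
        print(f"{name}: {ver or 'not installed'}")
```

Running this on both machines and diffing the output is a quick way to see whether the two environments differ.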

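Cause

When a model is saved and loaded with torch.save / torch.load, the model object is pickled and unpickled as a whole. Unpickling restores the saved attributes directly without calling __init__(), so any class member that the current class's __init__() would set (here fixed_size, whose default is None) but that is missing from the saved object stays undefined. As a result, the program crashes with an AttributeError.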
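Since torch.save / torch.load pickle the whole model object, the mechanism can be reproduced without PyTorch. A minimal sketch in plain Python: `Transform` is a hypothetical stand-in for GeneralizedRCNNTransform, and redefining it after pickling simulates the class definition changing between saving and loading (an assumption made for illustration). Run it as a script, because pickle resolves the class by module and name at load time.

```python
import pickle


class Transform:
    """Hypothetical 'old' version of a library class (no fixed_size)."""

    def __init__(self, min_size):
        self.min_size = min_size


# Pickle an instance while the old class definition is in effect.
saved = pickle.dumps(Transform(800))


class Transform:  # redefinition simulates the class changing after saving
    """Hypothetical 'new' version: __init__ now also sets fixed_size."""

    def __init__(self, min_size, fixed_size=None):
        self.min_size = min_size
        self.fixed_size = fixed_size

    def describe(self):
        # The new code assumes fixed_size exists, as it would after __init__.
        return f"min_size={self.min_size}, fixed_size={self.fixed_size}"


# Unpickling creates the object without calling __init__ and restores only
# the saved attributes, so fixed_size is never set.
restored = pickle.loads(saved)
print(hasattr(restored, "fixed_size"))  # False

try:
    restored.describe()
except AttributeError as exc:
    print(exc)  # 'Transform' object has no attribute 'fixed_size'

# Constructing through __init__ defines the attribute (None by default).
print(Transform(800).describe())  # min_size=800, fixed_size=None
```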

train.py
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Define the model
device = torch.device("cuda")
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

# Training
# ... (omitted) ...

# Save the model: the whole pickled object, and separately only the parameters
torch.save(model, "path/to/model.pth")
torch.save(model.state_dict(), "path/to/model_weight.pth")
predict.py
import torch
from PIL import Image
from torchvision.transforms import functional as F

device = torch.device('cpu')
model = torch.load("path/to/model.pth", map_location=device)
model.to(device)
model.eval()

img = Image.open("path/to/image").convert("RGB")
img = F.to_tensor(img)
imgs = [img.to(device)]
with torch.no_grad():
    # With the model loaded via torch.load, the AttributeError is raised in this call
    out = model(imgs)
    print(out[0]["boxes"].numpy())

Fix

predict.py
+ import torchvision
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  device = torch.device('cpu')
- model = torch.load("path/to/model.pth", map_location=device)
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
+ model.load_state_dict(torch.load("path/to/model_weight.pth", map_location=device))
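Why this works can be illustrated with a plain-Python analogy (not the real torch API): nn.Module.state_dict() contains only parameters and buffers, so plain attributes such as fixed_size always come from the current class's __init__(). In the sketch below, `Transform` and its `state_dict` / `load_state_dict` methods are hypothetical stand-ins that mimic this idea.

```python
import pickle


class Transform:
    """Hypothetical class whose __init__ sets a non-parameter attribute."""

    def __init__(self, min_size, fixed_size=None):
        self.min_size = min_size
        self.fixed_size = fixed_size
        self.weights = [0.0, 0.0]  # stands in for learned parameters

    def state_dict(self):
        # Export only the learned values, like nn.Module.state_dict().
        return {"weights": list(self.weights)}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])


trained = Transform(800)
trained.weights = [0.25, -1.5]  # pretend training happened

# Save only the parameters: plain data, no reference to the class.
blob = pickle.dumps(trained.state_dict())

# Rebuild from the current class via __init__, then restore the parameters.
model = Transform(800)
model.load_state_dict(pickle.loads(blob))
print(model.fixed_size, model.weights)  # None [0.25, -1.5]
```

Because the saved data holds only values keyed by name, it does not depend on the class definition that existed at save time. The trade-off is that the loading side must rebuild the same architecture itself, which is why predict.py repeats the head replacement from train.py.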