手順の概要
- 環境
- 問題点の確認
- Config内のdataを書き換え
環境
mmcv v1.7.1
mmdetection v2.28.2
問題点の確認
mmdetection v2.25.2以降では,https://github.com/open-mmlab/mmdetection/pull/8575 において,tools/train.py
が書き換えられた.その影響でこれまで使っていたConfigファイルから学習させようとしたときに,以下のようなエラーがでて動かない場合がある.
Traceback (most recent call last):
File "/mmdetection/tools/train.py", line 248, in <module>
main()
File "/mmdetection/tools/train.py", line 225, in main
print(cfg.data.train.dataset)
File "/mmcv/mmcv/utils/config.py", line 50, in __getattr__
raise ex
AttributeError: 'ConfigDict' object has no attribute 'dataset'
https://github.com/open-mmlab/mmdetection/pull/8575 のコメントにある通り,エラーが発生して動かないという文句が書いてある.今回は,その解決方法を書く.
もちろん,この解決方法が気に入らない場合は,mmdetection v2.25.1のコードを使うと良いと思う.(未検証なので動作保証はできません)
Config内のdataを書き換え
https://github.com/open-mmlab/mmdetection/pull/8575 にある通り,tools/train.py
はRepeatDataset
などに対応するように変更されたらしい.そこで,Config内のdataを書き換えることで解決を試みた.以下がその例である.data変数のtrainキーにtimes=1
のRepeatDataset
を適用した.RepeatDataset
に対応させたということなので,RepeatDataset
を適用するという方法である.
# 変更前
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# 変更後
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train = dict(
type = 'RepeatDataset',
times = 1,
dataset = dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline)
),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))