2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

機械学習でなにかやるときに気をつけたほうがいいこと&やったほうがいいことまとめ

Posted at

大学院での研究の関係で機械学習を使うことが多いのですが、コード実装のミスだったりやった方がいい小技を実装し忘れたりすることが多々あったので備忘録のためにもまとめておきました。

随時更新していきます。画像認識系のタスクを想定しており、Python3, Pytorchでの実装になります。

① Dataloaderの引数にあるnum_workersの値を変更する

PytorchのDataloaderを定義する際、num_workersの値を何も設定しないと2になりますが、ここの値の数を大きくすると学習の速度を大幅に向上させることができる。batch_sizeの値も合わせて大きくすると学習速度が大幅に向上する。

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle=False,
    num_workers = num_workers, # <-ここの値を変更する
    pin_memory =True # <- ここを変えてもそこまで大きな影響は出ない...?
)

num_workersの値は学習に使用するCPUのスペックに合わせて調整するべきであるが、最低でも4か8には設定しておきたい。DGXのようなハードで学習を行う場合、筆者は40くらいに設定している。CPUの物理コア数ギリギリの数値に設定してしまうと学習が途中で不安定になったり強制終了しやすくなる。

②GPUが複数枚あるハードウェアで学習させたモデルをLoadする際には追加のコードを書く

DGXなど、GPUが複数枚あるデバイスで学習させた.pthファイルにはstate_dictのkeysに"module."が付け加えられている。このためload_state_dictでロードする際は以下のようなコードを追加で書かないといけない。

def fix_model_state_dict(self, state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]  # remove 'module.' of dataparallel
        new_state_dict[name] = v
    return new_state_dict

def getNetwork(self, net):
    print("Loading Pretrained Network")
    pretrained_state_dict = torch.load(self.pretrained_weights_path)
    net.load_state_dict(self.fix_model_state_dict(pretrained_state_dict))
    net = net.to(self.device)

③ 推論を行うときはネットワークからの出力をcpuへ転送する

以下のような処理をしておかないと100回くらい推論したあたりで開放できなかった過去の推論結果のデータが溜まり、GPUメモリが満杯になり推論が強制終了する。

def prediction(self, input):
    output = self.net(input)
    output_np = output.to('cpu').detach().numpy().copy()

    return output_np

④ 学習で使うハイパーパラメータなどはyamlで管理して、学習スクリプトを実行するたびにyamlをコピーして保存しておく

自分が学習で使うハイパーパラメータをすべてyamlで管理するタイプなので、毎回学習するときには使用したyamlファイルを学習済みモデルと同じディレクトリに保存して

if __name__ == "__main__":
    parser = argparse.ArgumentParser("train.py")

    parser.add_argument(
        '--train_cfg', '-c',
        type=str,
        required=False,
        default='../pyyaml/train_config.yaml',
        help='Training configuration file'
    )

    FLAGS, unparsed = parser.parse_known_args()

    print("Load YAML file")

    try:
        print("Opening train config file %s", FLAGS.train_cfg)
        CFG = yaml.safe_load(open(FLAGS.train_cfg, 'r'))
    except Exception as e:
        print(e)
        print("Error opening train config file %s", FLAGS.train_cfg)
        quit()

    save_top_path = CFG["save_top_path"]
    yaml_path = save_top_path + "/train_config.yaml"
    shutil.copy(FLAGS.train_cfg, yaml_path)

⑤ Adam系のoptimizerで最適化を行いつつ、L2正則化を行う場合はスクラッチで実装する必要がある

Adam、AdamW、RAdamなどのoptimizerを使って最適化を行いつつ、L2正則化を用いて特徴選択を行いたい場合、optimizer側で設定するweight_decayのパラメータをいじっても正則化の効果が(自分の経験では)得られない。そのため、以下のようにL2正則化をスクラッチで実装する必要がある。

output = self.net(input)
output_loss = self.criterion(output, output_label) #損失計算
if self.device == 'cpu':
    l2norm = torch.tensor(0., requires_grad = True).cpu()
else:
    l2norm = torch.tensor(0., requires_grad = True).cuda()
for w in self.net.parameters():
    l2norm = l2norm + torch.norm(w)**2

#self.alphaは正則化のペナルティ項、学習の成功にかかわる重要パラメータ
total_loss = output_loss + self.alpha*l2norm 

Twitterを見る限り情報が錯綜しているが、PytorchのAdam系のoptimizerとweight_decayとの相性はあまり良くないとのこと。今後のアップデートでこの仕様が修正される可能性もあるため続報を待ちたいところではある。

⑥ GPUが複数枚あるハードで学習を行う際は追加のコードを書く必要がある

DGXなどGPUが複数枚あるデバイスで学習を行う場合にはDataParallelの設定を追加で行う必要がある。

if self.multiGPU == True and self.device == "cuda":
    net = nn.DataParallel(net)
    cudnn.benchmark = True
    print("Training on multiGPU Device")
else:
    cudnn.benchmark = True
    print("Training on Single GPU Device")
2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?