0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

機械学習モデルから毎回違うサンプルが取り出されてしまうときの対処法

Last updated at Posted at 2024-10-09

きっかけ

卒業研究の一環で読唇(lip reading)の推論済みモデルをgithubからクローンして動かしていました。ミニバッチ学習時は問題なかったのですが、個別に推論をしたところ、毎回違う出力結果が出てしまっておりました。

考えるべきこと

1. Dataloaderの構造

以下のようにshuffle=Trueになっていると、Dataloaderはデータセット全体からbatch_sizeだけサンプルをランダムに取ってきます。そうすると1つのデータを推論したいとき、毎回違うサンプルで行われてしまいます。

dset_loaders = {x: torch.utils.data.DataLoader(
                        dsets[x],
                        batch_size=args.batch_size,
                        shuffle=True,
                        collate_fn=pad_packed_collate,
                        pin_memory=True,
                        num_workers=args.workers,
                        worker_init_fn=np.random.seed(1)) for x in partitions}

2. random_seedの有無

以下のようにrandom_seedの記載が無いと、推論結果が毎回異なってしまいます。

torch.manual_seed(1)
np.random.seed(1)
random.seed(1)
torch.backends.cudnn.benchmark = True

3. glob.globの使用

os.path.joinメソッドを使用する際、globを用いると、データがランダムな順序で渡されます。以下のようにsort()メソッドを使用することで解決されます。

        search_str_mp4 = os.path.join(dir_fp, '*', self._data_partition, '*.mp4')
        self._data_files.extend( glob.glob( search_str_mp4 ) )
        
        #glob.glob()によるファイル検索の際に、OSによってファイルがランダムな順序で返される
        self._data_files.sort()
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?