LoginSignup
3
3

More than 5 years have passed since last update.

【fast.ai】 Basic Data API解説

Last updated at Posted at 2019-03-04

概要

本記事はfast.aiのwikiのBasic Dataページの要約となります。

筆者の理解した範囲内で記載します。

TrainingのためのDataを準備するための簡易API。
具体的には、Learnerモジュールに用いるData Bunchオブジェクトを用いる。

Data Bunch

DataBunch(train_dl:DataLoader, valid_dl:DataLoader, fix_dl:DataLoader=None, test_dl:Optional[DataLoader]=None, device:device=None, dl_tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.', collate_fn:Callable='data_collate', no_check:bool=False)

Dataオブジェクトにtrain_dl valid_dl, test_dl(随意)を結びつける。(dlはdataloaderの省略形)

全てのdataloaderがdeviceに取り付けられていることと、tfmsでdata augmentationがなされていることを保証し、
collate_fnにてPyTorchのDataloaderにbatchごとのデータとファイル名の照合を促す。

なお、train_dl valid_dl test_dl(随意) はDeviceDataLoaderに包まれている。

create

create(train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64, val_bs:int=None, num_workers:int=4, dl_tfms:Optional[Collection[Callable]]=None, device:device=None, collate_fn:Callable='data_collate', no_check:bool=False, **dl_kwargs)  DataBunch

Data Bunchクラスをbs(batch size)でtrain_dl valid_dl test_dlより生成。

show_batch

show_batch(rows:int=5, ds_type:DatasetType=<DatasetType.Train: 1>, **kwargs)

指定したrowにてdataのbatchを表示。

dl

dl(ds_type:DatasetType=<DatasetType.Valid: 2>)  DeviceDataLoader

ds_typeにて指定されたvalidation,training,testのDatasetを返す。

one_batch

one_batch(ds_type:DatasetType=<DatasetType.Train: 1>, detach:bool=True, denorm:bool=True, cpu:bool=True)  Collection[Tensor]

1つのbatchをDataLoaderより持ってくる。

one_item

one_item(item, detach:bool=False, denorm:bool=False, cpu:bool=False)

itemをbatchへ持ってくる。

sanity_check

sanity_check()

sanity checkを行い、データを確認する。

save

save(fname:PathOrStr='data_save.pkl')

DataBunchself.path/fnameへ保存。

load_data

load_data(path:PathOrStr, fname:str='data_save.pkl', bs:int=64, val_bs:int=None, num_workers:int=4, dl_tfms:Optional[Collection[Callable]]=None, device:device=None, collate_fn:Callable='data_collate', no_check:bool=False, **kwargs)  DataBunch

DataBunchpath/fnameから読み込む。

PyTorchのDatasetとの互換性

PyTorchのDatasetとの互換性は部分的にサポートされている。詳しくはこちら

DeviceDataLoader

DeviceDataLoader(dl:DataLoader, device:device, tfms:List[Callable]=None, collate_fn:Callable='data_collate')

DataLoadertorch.deviceに結びつける。
tfmsを行った後にdlのバッチをdeviceへ結ぶ。全てのDataLoaderはこのタイプ。

create (with DeviceDataLoader)

create(dataset:Dataset, bs:int=64, shuffle:bool=False, device:device=device(type='cuda'), tfms:Collection[Callable]=None, num_workers:int=4, collate_fn:Callable='data_collate', **kwargs:Any)

shufflebsdatasetnum_workers用いてDeviceDataLoaderを生成。
collate_fnによって、1つのbatchへとサンプルを照合する。

shuffleを用いるとdataがシャッフルされ、tfmsはdata augmmentationに用いられる。

add_tfm

add_tfm(tfm:Callable)

self.tfmstfmを追加。

remove_tfm

remove_tfm(tfm:Callable)

self.tfmsよりtfmを削除。

new

new(**kwargs)

kwagsを用いてコピーを生成。

proc_batch

proc_batch(b:Tensor)  Tensor

TensorImageのbatchbを処理。

最後に

  1. basic_dataを用いて自作のdataloaderを作成してみる。
  2. PyTorchからの実際の移植作業をやってみたい。
  3. 自作のノートブックを用いて、実例の用法を紹介してみる。

間違いやご指摘などが御座いましたらご教示願います!

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3