13
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PyTorch向けのデータセットをChainerから使ってみる

Posted at

tl;dr

  • 各フレームワークにおけるデータセットの提供状態について調べた。
    • 自然言語処理はPyTorchが、化学系はChainerがそれぞれかなり優れている
    • Tensorflowは貧弱
  • PyTorchのデータセットはChainerでも使えないのかを試した。
    • PyTorch-NLPは最小限の変更でChainerでも使うことができた

はじめに

In machine learning and deep learning we can’t do anything without data.
---- Jeremy Howard (fast.ai) and Jed Sundwall (Open Data Global Lead, AWS)1

応用機械学習系の研究開発においてデータに素早くアクセスできることは大きなメリットになります。近年ではデータへのアクセスを高めるために、様々なサービスが現れてきました:

  • AWS Amazon Dataset Program: AWSでの研究を簡単にするためにS3上でデータセットをホスティング
  • fast.ai: すぐデータセットを使えるようにデータセットを代表的な整理・ホスティング
  • Kaggle/CodaLab: コンペをホスティング。データセットもすぐ扱えるよう提供。

とはいえ、データセットのデータ形式はまちまちで、データローダを書くのもひと手間です。そのときに、Chainerのchainer.datasets.get_mnistようにデータが一発ではいるようなコマンドがフレームワークがあると、実験を行うのに非常に魅力的です。実験が再現しやすくなるというメリットもあるでしょう。

そこで、本記事では、各フレームワークにおけるデータセットの提供状態について調べました。また、Chainer以外のフレームワークで活用されているデータセットをChainerで活用する例を2つ紹介します。

各フレームワークにおけるデータセットの提供状況

調査対象

フレームワーク名(Chainer, Tensorflow, PyTorch)+"dataset"で適当にググり、ライブラリを見つけました。

ライブラリ名 URL バージョン 公式
Chainer https://chainer.org/ 5.0.0 o
ChainerCV https://chainercv.readthedocs.io/en/stable/ 0.11.0 o
Chainer Chemistry https://chainer-chemistry.readthedocs.io/en/stable/index.html 0.4.0 o
torchvision https://pytorch.org/docs/stable/torchvision/index.html 0.2.0 o
torchtext https://torchtext.readthedocs.io/en/latest/ 0.4.0 o
Pytorch-NLP https://pytorchnlp.readthedocs.io/en/latest/ 0.3.0
AllenNLP https://allennlp.org/ 0.7.2
tensorflow_datasets https://github.com/tensorflow/datasets 0.0.1 o
fast.ai https://www.fast.ai/ 2018/12/9時点

なお、torchvision、torchtext、Pytorch-NLP、AllenNLPは全てPyTorch系です。fast.aiはzipでデータを配布しているので、特に何系というのはないと思われます(ライブラリ本体はPyTorch系です)。調査結果に示す"xx系"はこの基準に則っています。

調査結果

画像系

データセット名 Chainer Chainer examples ChainerCV Chainer系 torchvision PyTorch系 tensorflow_datasets fast.ai
<画像分類>
MNIST o o o o o o
Fashion-MNIST o o o o o
EMNIST o o
CIFAR-10 o o o o o o
CIFAR-100 o o o o
ILSVRC2012 (ImageNet) o o o o
SVHN o o o o
Caltech-UCSD Bird o o o
Caltech 101 o
Oxford-IIIT Pet o
Oxford 102 Flowers o
Food-101 o
Stanford cars o
STL10 o o
CMP Facade o o
LSUN (classification) o o
Diabetic Retinopathy Detection o
<物体検出>
Pascal VOC o
COCO detection o o
<セグメンテーション>
ADE20K (Image segmentation) o o
CamVid (Object detection in video) o o o
Cityscapes o o

ChainerCVのおかげかChainerが善戦しています。fast.aiは他のライブラリにはないデータを取り揃えているのが面白いですね。

言語系

データセット名 Chainer Chainer examples Chainer系 torchtext PytorchNLP AllenNLP PyTorch系 tensorflow_datasets fast.ai
<言語モデル>
PTB (単語のみ) o o o o o o
Wikitext-103 o o o
WikiText-2 o o o o
テキスト分類
SST o o o o o o
IMDb o o o o o o
DBPedia Ontology o o o
TREC Question Classification o o o o o
Customer Review o o
MPQA Optinion o o
Scale Movie o o
Subjectivity o o
AG News o
Amazon Review o
Sogou news o
Yahoo! Answers o
Yelp Reviews o
<機械翻訳>
WMT14 En-De o o o
WMT15 (En-Fr) o o o
IWSLT17 o o o
<構文・意味解析>
CoNLL 00 o o o
CoNLL 03 o o
OntoNotes 5.0 NER o o
PTB o o
CCGbank o o
Universal Dependencies o o
Universal Dependencies (POS) o o
SemEval 2015 Task 18 o o
ATIS o o
NLVR o o
text2sql o o
WikitableQuestions o o
<質疑応答>
SquAD o o
TriviaQa o o
QuAC o o
SimpleQuestions o o
<その他>
SNLI o o o o
bAbI o o o o
Quora Paraphrase o o
Event2Mind o o

NLP系はPyTorch勢が強いです。構文・意味解析系のデータセットはChainerには1つもありません。

マルチモーダル、その他

| データセット名 | Chainer | Chainer examples | Chainer Chemistry | Chainer系 | torchvision | torchtext | PytorchNLP | PyTorch系 | tensorflow_datasets | fast.ai |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| マルチモーダル、その他 | | | | | | | | | | |
| COCO captions | | o | | o | o | | | o | | |
| Multi30K | | | | | | o | o | o | | |
| Photo Tourism (2D<->3D) | | | | | o | | | o | | |
| CSTR VCTK | | | | | o | | | o | | |
| BAIR robot pushing | | | | | | | | | o | |
| Tox21 | | o | | o | | | | | | |
| QM9 | | o | | o | | | | | | |
| MoleculeNet | | o | | o | | | | | | |

ここらへんはニッチなのであまり気にすることはないかもしれません。化学系はChainerの独壇場です。

総括

自然言語処理はPyTorchが、化学系はChainerがそれぞれかなり優れています。そして、意外にも貧弱なのがTensorflowです。TensorFlowは(短期の)研究よりも、開発向けのフレームワークだからでしょうか。

PyTorchは非公式のライブラリが非常に強力なのが印象的です。Chainerは基本的にPFNだけの努力に支えられるような形になっています。

他フレームワークのデータセットをChainerから使う

さて、Chainerは自然言語処理のデータセットにおいてPyTorchに大きく間を開けられていることがわかりました。PyTorchとChainerは記法などがほとんど同じなので、PyTorchのデータセットはChainerでも使えるのではないかと考え、試してみました。

PyTorch-NLP

Chainerにはない、SNLI (Stanford Natural Language Inference)のデータセットを読み込んでみます。

とりあえずpytorch-nlpを入れます。PyTorchに依存してしまっているため、PyTorch自体も必要です。そのため、(OS環境に)gfortranなどが用意されている必要があります。

pip install pytorch-nlp==0.3.7.post1
torch==0.4.1

PyTorch-NLPのexampleの頭の部分を持ってきています。ほとんどそのままです。

import os

from torchnlp.datasets import snli_dataset
from torchnlp.utils import datasets_iterator
from torchnlp.text_encoders import TreebankEncoder, IdentityEncoder


# load dataset
train, dev, test = snli_dataset(train=True, dev=True, test=True)

# Preprocess
for row in datasets_iterator(train, dev, test):
    row['premise'] = row['premise'].lower()
    row['hypothesis'] = row['hypothesis'].lower()

# Make Encoders
sentence_corpus = [row['premise'] for row in datasets_iterator(train, dev, test)]
sentence_corpus += [row['hypothesis'] for row in datasets_iterator(train, dev, test)]
# ExampleではWhitespaceEncoder (空白で単語分割)を使っていたが、せっかくなのでちゃんとしたトークナイザを使う
sentence_encoder = TreebankEncoder(sentence_corpus)

label_corpus = [row['label'] for row in datasets_iterator(train, dev, test)]
label_encoder = IdentityEncoder(label_corpus)

# Encode
for row in datasets_iterator(train, dev, test):
    row['premise'] = sentence_encoder.encode(row['premise']).numpy()
    row['hypothesis'] = sentence_encoder.encode(row['hypothesis']).numpy()
    row['label'] = label_encoder.encode(row['label']).numpy()

print(train[0]) # -> torchnlp.datasets.dataset.Dataset

これを実行すると、自動でトークナイザに必要なデータやデータセットをダウンロードしてくれます。
実行結果から、trainの1データが以下のようなdictであることがわかります。

{'premise': array([ 5,  6,  7,  5,  8,  9, 10,  5, 11, 12, 13, 14]), 'hypothesis': array([   5,    6,   22, 3117,   36,    8,  153,    5, 1434,   14]), 'label': array([5]), 'premise_transitions': ['shift', 'shift', 'shift', 'shift', 'shift', ...

trainは__getitem____len__をサポートしているため、これはChainerのDictDatasetと同等です。
したがって、(色々余計なデータもついていますが)これをこのままChainerのIteratorにいれても動きます。

これをPyTorch用に使う場合との差異は以下のとおりです。

 # Encode
 for row in datasets_iterator(train, dev, test):
-    row['premise'] = sentence_encoder.encode(row['premise'])
+    row['premise'] = sentence_encoder.encode(row['premise']).numpy()
-    row['hypothesis'] = sentence_encoder.encode(row['hypothesis'])
+    row['hypothesis'] = sentence_encoder.encode(row['hypothesis']).numpy()
-    row['label'] = label_encoder.encode(row['label'])
+    row['label'] = label_encoder.encode(row['label']).numpy()
 
 print(train[0])

つまり、pytorch-nlpのデータセットを使いたい場合は、PyTorchの場合と同様にデータセットを作り、そのデータセットの各項目に対して .numpy()を実行してゆくだけで動くことになります。

AllenNLP

やはりChainerにはない、Quora Paraphraseデータセットを読み込んでみます。

AllenNLPを導入します。AllenNLPでも同様にPyTorchが必要です(自動的に依存関係を解決してくれます)。

pip install allennlp==0.7.2

AllenNLPは自動的にデータダウンロードをしてくれないので、データを自分で準備します。なお、今回対象としたQuora Paraphraseデータセットにおいて、AllenNLPは(なぜか)ファイルの形式変更を要求しています。

# Check lisence at https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs
wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv
awk 'BEGIN {FS="\t"; OFS="\t"} {print $6, $4, $5, $1}' quora_duplicate_questions.tsv | tail -n +2 > quora_duplicate_questions_allennlp.tsv

非公式の投稿を参考にデータロードを行います。

import allennlp
from allennlp.data import Vocabulary
from allennlp.data.fields import TextField, LabelField
from allennlp.data.dataset_readers import QuoraParaphraseDatasetReader


reader = QuoraParaphraseDatasetReader()
# train_dataset is list of Instance
train_dataset = reader.read('quora_duplicate_questions_allennlp.tsv')
vocab = Vocabulary.from_instances(train_dataset, min_count={'tokens': 3})

# index each data
for data in train_dataset:
    data.index_fields(vocab)

def extract_field(instance):
    ret = {}
    for field_key, field in instance.fields.items():
        field_value = field.as_tensor(field.get_padding_lengths())
        if isinstance(field, TextField):
            ret.update({field_key + '_' + k: v.numpy() for k, v in field_value.items()})
        elif isinstance(field, LabelField):
            ret[field_key] = field_value.numpy()
        else:
            raise NotImplementedError('Parse rule for Field "%s" is not implemented' % type(field))
    return ret

train_dataset = [extract_field(data) for data in train_dataset]
print(train_dataset[0])

これを実行することで、以下のようにDictDatasetと互換性がある出力を得ることができます。

{'premise_tokens': array([   3,    4,    2, 1413,   56, 1413, 3315,    7,  515,    9,  751,
         598,    9,  782]),
 'hypothesis_tokens': array([   3,    4,    2, 1413,   56, 1413, 3315,    7,  515,    9,  751,
         905]),
 'label': array(0)}

AllenNLPにおいては、データ (上記train_datasetに相当)は複数のInstanceから構成されており、Instanceは複数のFieldから構成されています。Fieldの挙動は種類によって違うため、他のデータセットを使うためにはもう少し変更が必要かもしれません。

AllenNLPは、ドキュメンテーションが乏しい (具体的なデータの入手先が記載されていない)、使用例がほとんどない、などAllenNLP自体に使いにくさがあります。今後AllenNLPの使用感がよくなれば、Chainerのデータセットへの変換スクリプトを書くメリットがうまれるかなと思います。

  1. https://www.fast.ai/2018/10/16/aws-datasets/

13
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?