背景
IMDBは以下のサイトで、torchtextでは、このサイトから映画のレビューのポジネガの二値分類のデータセットが得られます。
python機械学習プログラミングという本があります。
この本を読み進めていると、15章で、torchtextを利用しますが、torchtextで以下のようなエラーが出ます。
OSError Traceback (most recent call last)
<ipython-input-14-59f0fbcb2fc8> in <cell line: 1>()
----> 1 import torchtext
5 frames
/usr/lib/python3.10/ctypes/__init__.py in __init__(self, name, mode, handle, use_errno, use_last_error, winmode)
372
373 if handle is None:
--> 374 self._handle = _dlopen(self._name, mode)
375 else:
376 self._handle = handle
OSError: /usr/local/lib/python3.10/dist-packages/torchtext/lib/libtorchtext.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSs
これはpytorchとversionがあっていないことが原因のようです。versionを落としても、その他もろもろのライブラリにまで影響してしまい、結局自分でデータセットを作成しました。自作といっても、データセットをpytorchで使える形に直すだけです。
データセットの自作
まず、以下のサイトからzipでデータを入手します。
https://ai.stanford.edu/~amaas/data/sentiment/
これを、[(label, text), (label, text), ...]という形式で読み込みます。今回は, google colaboratoryで実行しているため、content直下に, aclImdb ディレクトリを作成しています。
import os
def read_imdb(neg_data_dir: str="/content/aclImdb/train/neg", pos_data_dir: str="/content/aclImdb/train/pos"):
# データを格納するリスト
dataset = []
directory_paths = [neg_data_dir, pos_data_dir]
for i, directory_path in enumerate(directory_paths):
# ディレクトリ内の *.txt ファイルを読み込む
for filename in os.listdir(directory_path):
if filename.endswith(".txt"): # .txtファイルをフィルタリング
file_path = os.path.join(directory_path, filename)
with open(file_path, 'r', encoding='utf-8') as file:
dataset.append((i, file.read()))
return dataset
train_dataset = read_imdb()
test_dataset = read_imdb("/content/aclImdb/test/neg", "/content/aclImdb/test/pos")
次に、 tokenizerを以下の様に定義して, train_datasetの中身から{単語: 出現回数}という辞書を作成します。
import re
def tokenizer(text):
text = re.sub('<[^>]*>', '', text)
emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text.lower())
text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
tokenized = text.split()
return tokenized
from collections import Counter
counter = Counter()
for (label, line) in train_dataset:
counter.update(tokenizer(line))
print(len(counter))
次に, トークンのインデックスを作成します. paddingとunknown tokenように, index 0, 1は開けておきます。また、 unknown tokenがindex 1であるので、defaultdictを使用して, keyに存在しない単語の場合は1を返すようにします.
from collections import defaultdict
def make_vocab(counter: dict):
i = 2
vocab = defaultdict(lambda: 1)
vocab['<pad>'] = 0
vocab['<unk>'] = 1
for k, v in counter.items():
if v > 5:
vocab[k] = i
i += 1
return vocab
vocab = make_vocab(counter)
あとは, pytorchのdataloaderに読み込ませるのみです。ついでにtrainとvalidもrandom_splitで分割します。以下のコードは、python機械学習プログラミングの本からほとんど引用しています。
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: float(x)
def collate_batch(batch):
text_list, label_list, lengths = [], [], []
for _label, _text in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
lengths.append(processed_text.size(0))
label_list = torch.tensor(label_list)
lengths = torch.tensor(lengths)
text_list = nn.utils.rnn.pad_sequence(text_list, batch_first=True)
return text_list, label_list, lengths
train_dataset, valid_dataset = random_split(train_dataset, [int(len(train_dataset) * 0.8), len(train_dataset) - int(len(train_dataset) * 0.8)])
train_dl = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)
valid_dl = DataLoader(valid_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)
def check_traindl(train_dl):
for x, y, z in train_dl:
print(x.shape)
print(y.shape)
print(z.shape)
print(x[0])
break
check_traindl(train_dl)
出力結果が以下の様にtokenになっていて、paddingされていたら、正しい結果です。
torch.Size([64, 690])
torch.Size([64])
torch.Size([64])
tensor([ 78, 186, 922, 4907, 19, 70, 871, 11662, 4061, 17,
557, 5798, 28, 17, 40, 6848, 26, 865, 28, 60,
192, 1789, 387, 122, 237, 4114, 3166, 471, 859, 448,
588, 26, 208, 2, 109, 141, 34, 141, 51, 3972,
208, 106, 221, 17, 7978, 4202, 430, 33, 4, 5555,
28, 319, 60, 192, 80, 170, 46, 618, 28, 25,
45, 139, 74, 1105, 16, 19, 1347, 28, 3456, 28,
17, 656, 26, 1938, 4647, 49, 846, 28, 55, 41,
1077, 13787, 41, 46, 3023, 17, 618, 4231, 48, 28,
17, 618, 4231, 45, 43, 14, 6959, 636, 296, 8484,
51, 125, 470, 126, 376, 11662, 78, 871, 78, 34,
121, 28, 636, 296, 8484, 60, 139, 30, 10831, 656,
247, 154, 1077, 30, 53, 60, 1225, 70, 511, 534,
1415, 14, 5798, 28, 557, 14, 11662, 24, 60, 263,
1225, 70, 1378, 534, 5057, 46, 846, 33, 324, 2739,
28, 1023, 45, 139, 41, 204, 53, 8, 16, 46,
346, 246, 2, 53, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0])