0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

torchtextとpytorchのversionが合わないため、IMDBを自作する。

Last updated at Posted at 2024-12-07

背景

IMDBは以下のサイトで、torchtextでは、このサイトから映画のレビューのポジネガの二値分類のデータセットが得られます。

python機械学習プログラミングという本があります。

この本を読み進めていると、15章で、torchtextを利用しますが、torchtextで以下のようなエラーが出ます。

OSError                                   Traceback (most recent call last)
<ipython-input-14-59f0fbcb2fc8> in <cell line: 1>()
----> 1 import torchtext

5 frames
/usr/lib/python3.10/ctypes/__init__.py in __init__(self, name, mode, handle, use_errno, use_last_error, winmode)
    372 
    373         if handle is None:
--> 374             self._handle = _dlopen(self._name, mode)
    375         else:
    376             self._handle = handle

OSError: /usr/local/lib/python3.10/dist-packages/torchtext/lib/libtorchtext.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSs

これはpytorchとversionがあっていないことが原因のようです。versionを落としても、その他もろもろのライブラリにまで影響してしまい、結局自分でデータセットを作成しました。自作といっても、データセットをpytorchで使える形に直すだけです。

データセットの自作

まず、以下のサイトからzipでデータを入手します。
https://ai.stanford.edu/~amaas/data/sentiment/

これを、[(label, text), (label, text), ...]という形式で読み込みます。今回は, google colaboratoryで実行しているため、content直下に, aclImdb ディレクトリを作成しています。

import os

def read_imdb(neg_data_dir: str="/content/aclImdb/train/neg", pos_data_dir: str="/content/aclImdb/train/pos"):
    # データを格納するリスト
    dataset = []
    directory_paths = [neg_data_dir, pos_data_dir]
    for i, directory_path in enumerate(directory_paths):
      # ディレクトリ内の *.txt ファイルを読み込む
      for filename in os.listdir(directory_path):
          if filename.endswith(".txt"):  # .txtファイルをフィルタリング
              file_path = os.path.join(directory_path, filename)
              with open(file_path, 'r', encoding='utf-8') as file:
                  dataset.append((i, file.read()))
    return dataset

train_dataset = read_imdb()
test_dataset = read_imdb("/content/aclImdb/test/neg", "/content/aclImdb/test/pos")

次に、 tokenizerを以下の様に定義して, train_datasetの中身から{単語: 出現回数}という辞書を作成します。

import re
def tokenizer(text):
    text = re.sub('<[^>]*>', '', text)
    emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text.lower())
    text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
    tokenized = text.split()
    return tokenized

from collections import Counter
counter = Counter()

for (label, line) in train_dataset:
    counter.update(tokenizer(line))

print(len(counter))

次に, トークンのインデックスを作成します. paddingとunknown tokenように, index 0, 1は開けておきます。また、 unknown tokenがindex 1であるので、defaultdictを使用して, keyに存在しない単語の場合は1を返すようにします.

from collections import defaultdict

def make_vocab(counter: dict):
  i = 2
  vocab = defaultdict(lambda: 1)
  vocab['<pad>'] = 0
  vocab['<unk>'] = 1
  for k, v in counter.items():
    if v > 5:
      vocab[k] = i
      i += 1
  return vocab

vocab = make_vocab(counter)

あとは, pytorchのdataloaderに読み込ませるのみです。ついでにtrainとvalidもrandom_splitで分割します。以下のコードは、python機械学習プログラミングの本からほとんど引用しています。

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: float(x)

def collate_batch(batch):
  text_list, label_list, lengths = [], [], []
  for _label, _text in batch:
    label_list.append(label_pipeline(_label))
    processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
    text_list.append(processed_text)
    lengths.append(processed_text.size(0))
  label_list = torch.tensor(label_list)
  lengths = torch.tensor(lengths)
  text_list = nn.utils.rnn.pad_sequence(text_list, batch_first=True)
  return text_list, label_list, lengths

train_dataset, valid_dataset = random_split(train_dataset, [int(len(train_dataset) * 0.8), len(train_dataset) - int(len(train_dataset) * 0.8)])

train_dl = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)
valid_dl = DataLoader(valid_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)

def check_traindl(train_dl):
  for x, y, z in train_dl:
    print(x.shape)
    print(y.shape)
    print(z.shape)
    print(x[0])
    break
check_traindl(train_dl)

出力結果が以下の様にtokenになっていて、paddingされていたら、正しい結果です。

torch.Size([64, 690])
torch.Size([64])
torch.Size([64])
tensor([   78,   186,   922,  4907,    19,    70,   871, 11662,  4061,    17,
          557,  5798,    28,    17,    40,  6848,    26,   865,    28,    60,
          192,  1789,   387,   122,   237,  4114,  3166,   471,   859,   448,
          588,    26,   208,     2,   109,   141,    34,   141,    51,  3972,
          208,   106,   221,    17,  7978,  4202,   430,    33,     4,  5555,
           28,   319,    60,   192,    80,   170,    46,   618,    28,    25,
           45,   139,    74,  1105,    16,    19,  1347,    28,  3456,    28,
           17,   656,    26,  1938,  4647,    49,   846,    28,    55,    41,
         1077, 13787,    41,    46,  3023,    17,   618,  4231,    48,    28,
           17,   618,  4231,    45,    43,    14,  6959,   636,   296,  8484,
           51,   125,   470,   126,   376, 11662,    78,   871,    78,    34,
          121,    28,   636,   296,  8484,    60,   139,    30, 10831,   656,
          247,   154,  1077,    30,    53,    60,  1225,    70,   511,   534,
         1415,    14,  5798,    28,   557,    14, 11662,    24,    60,   263,
         1225,    70,  1378,   534,  5057,    46,   846,    33,   324,  2739,
           28,  1023,    45,   139,    41,   204,    53,     8,    16,    46,
          346,   246,     2,    53,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?