LoginSignup
0
0

Truncation = Trueを忘れたらエラーになった話

Posted at

TL; DR

transformersライブラリでBERTを使って文書分類をやろうとしたらdata_collatorからValueErrorを受け取った話。
原因は文書の長さをそろえていなかったこと。

背景

transformersライブラリを使ってlivedoorニュースコーパスのカテゴリ分類をやろうとした。

困ったこと

DataCollatorWithPaddingを使ってデータをテンソル化しようとした際にValueErrorが出る。

samples = tokenized_train_dataset[:8]
[len(x) for x in samples["input_ids"]]

batch = data_collator(samples)
ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`token_type_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

原因

def tokenize_function(example):
  category = example['category']
  example = tokenizer(example['content'])
  example['labels'] = category2num[category]

  return example

ここでtokenizerのオプションにtruncation=Trueを入れてなかった。公式ドキュメントにはこれをすると各文書をmax_lengthのトークン長に切りそろえてくれるとある。

今回使ったBERTの最大トークン長は512であるため、max_length=512も必要。

対応

tokenizerのオプションにtruncation=Truemax_length~512を追加。

def tokenize_function(example):
  category = example['category']
  example = tokenizer(example['content'], truncation=True, max_length=512)
  example['labels'] = category2num[category]

  return example

これで解決。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0