概要
エンコードされた固有表現抽出データセットを作成します。
この記事で作成するデータセットを使って、言語モデルをファインチューニングします。
ポートフォリオとして、自作のデータセットでファインチューニングした言語モデルを使ったアプリを公開しました。
この記事を読むだけでも、この記事の内容をある程度理解できるとは思いますが、データセット作成の導入記事や、前段階の記事を読んでいる前提で書かれています。
固有表現抽出データセット(トークン化、エンコードはされていない)
この記事で作成されるもの
エンコードされた固有表現抽出データセット
作成には、”cl-tohoku/bert-base-japanese-v2”のトークナイザーを使用しました。
補足情報
開発はGoogle Colaboratoryで行われ、このノートブックで作成しました。
このノートブックを含むリポジトリの構造は、実際の開発環境と同一です。
from
で参照されている自作モジュールは、このディレクトリにあるものです。
その他のコード内で参照しているパスやディレクトリから/content/drive/MyDrive
を省くと、そのパスやディレクトリの中身をリポジトリから確認することができます。
方針
トークン化前のデータセットをエンコードして、エンコード済みデータセットを作成します。
大まかな手順
-
トークン化前のデータセットの作成で作成したトークン化前のデータセットの読み込み
[ { 'text': '山梨または、青森県のやつを、検索してくれませんか?', 'entities': [ {'name': '山梨', 'span': [0, 2], 'type': 'AREA'}, {'name': '青森県', 'span': [6, 9], 'type': 'AREA'} ] }, { 'text': '富山県または、静岡県のお料理が、あれば探して?', 'entities': [ {'name': '富山県', 'span': [0, 3], 'type': 'AREA'}, {'name': '静岡県', 'span': [7, 10], 'type': 'AREA'} ] }, { 'text': '群馬と北東北で食べられる、料理を知ってたら、教えて?', 'entities': [ {'name': '群馬', 'span': [0, 2], 'type': 'AREA'}, {'name': '北東北', 'span': [3, 6], 'type': 'AREA'} ] }, ..., { 'text': 'ぶりかエノキを使用した春に食べられている肉料理で、北陸地方の郷土料理があったら探して', 'entities': [ {'name': 'ぶり', 'span': [0, 2], 'type': 'INGR'}, {'name': 'エノキ', 'span': [3, 6], 'type': 'INGR'}, {'name': '春', 'span': [11, 12], 'type': 'SZN'}, {'name': '肉料理', 'span': [20, 23], 'type': 'TYPE'}, {'name': '北陸地方', 'span': [25, 29], 'type': 'AREA'} ] }, { 'text': 'フグあるいは練り辛子が使われている通年に食べられている飯料理で、北海道地方の料理をご存じでしたら、教えて', 'entities': [ {'name': 'フグ', 'span': [0, 2], 'type': 'INGR'}, {'name': '練り辛子', 'span': [6, 10], 'type': 'INGR'}, {'name': '通年', 'span': [17, 19], 'type': 'SZN'}, {'name': '飯料理', 'span': [27, 30], 'type': 'TYPE'}, {'name': '北海道地方', 'span': [32, 37], 'type': 'AREA'} ] }, { 'text': '鰌あるいはナガネギが使われている夏に食べられている野菜系で、滋賀県のものがあったら調べて下さい', 'entities': [ {'name': '鰌', 'span': [0, 1], 'type': 'INGR'}, {'name': 'ナガネギ', 'span': [5, 9], 'type': 'INGR'}, {'name': '夏', 'span': [16, 17], 'type': 'SZN'}, {'name': '野菜系', 'span': [25, 28], 'type': 'TYPE'}, {'name': '滋賀県', 'span': [30, 33], 'type': 'AREA'} ] } ]
-
データを1つずつ参照し、
text
をトークン化し、各トークンへラベル付け
トークナイザーが知らない語彙を含むデータも、データセットに採用しています。
トークナイザーが知らない語彙を含む入力文からでも、文脈から固有表現を抽出するように学習することを期待しました。正しくラベル付けできない文章だけは、データセットに採用しないことにしました。
「宮崎 に 、 お 肉 料理 、 また は 、 米 系 です が 使わ れ て いる 通年 の 料理 が あれ ば 探し て 頂 け ませ ん か ?」
という文章を実際にデータセットから省きました。
「米 系 です が」の”です”の”す”は食材の”酢”のことであり、”す”を含むトークンに食材のラベルを付けたかったのですが、”が”という文字列が余計にくっついてしまっているため、正確なラベル付けができません。
こういった文章は、データセットに採用しないことにしています。
コード
import、install
from google.colab import drive
drive.mount('/content/drive')
!pip install transformers fugashi ipadic
!pip install unidic-lite
from typing import List, Dict, Tuple
import math
import pandas as pd
from transformers import BertJapaneseTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import sys
sys.path.append('/content/drive/MyDrive/local_cuisine_search_app/modules')
from utility import load_json_obj
from pandas_utility import save_csv_df
クラスの定義
class DatasetMaker:
"""
データセット作成用のクラス
"""
@staticmethod
def create_and_save(
untokenized_dataset_path: str,
model_name: str,
labels_dic_path: str,
file_name: str,
save_dir: str
) -> pd.DataFrame:
"""
データセットの作成と保存
Parameters
----------
untokenized_dataset_path : str
トークン化されていないデータセットが保存されているパス
model_name : str
事前学習済み言語モデルの名前
トークナイザーの設定に使う
labels_dic_path : str
特殊トークンのラベルとそのidの辞書が保存されているパス
file_name : str
保存するデータセットのファイル名
save_dir : str
データセットの保存先ディレクトリ
Returns
-------
pd.DataFrame
エンコード済みのデータセット
"""
untokenized_dataset = load_json_obj(untokenized_dataset_path)
texts = [data['text'] for data in untokenized_dataset]
tokens_max_len = DatasetMaker._decide_tokens_max_len(texts)
data_maker = DataMaker(model_name, tokens_max_len, labels_dic_path)
dataset: List[BatchEncoding] = []
for untokenized_data in untokenized_dataset:
data = data_maker.create(untokenized_data)
if data:
dataset.append(data)
data_maker.show_unk_words_and_remove_texts()
dataset = pd.DataFrame(
data=dataset, columns = ['input_ids', 'attention_mask', 'labels']
)
save_csv_df(dataset, file_name, save_dir)
return dataset
@staticmethod
def _decide_tokens_max_len(texts: List[str]) -> int: # ※1
"""
tokens_max_lenの決定
各データのトークン数の決定
Parameters
----------
texts : List[str]
トークン化されていないデータセットの入力文のリスト
Returns
-------
int
最大トークン数
"""
max_len_of_text = 0
for text in texts:
len_of_text = len(text)
if len_of_text > max_len_of_text:
max_len_of_text = len_of_text
log_of_max_len = math.log2(max_len_of_text)
rounded_up_log = math.ceil(log_of_max_len)
tokens_max_len = 2 ** rounded_up_log
return tokens_max_len
class DataMaker:
"""
データ作成用のクラス
Attributes
----------
_sep_token : str
一文の終わりを示す特殊トークン
_unk_token : str
トークナイザーが知らない語彙用の特殊トークン
_tokenizer: BertJapaneseTokenizer
トークナイザー
_tokens_max_len : int
最大トークン数
_unk_words: List[str]
トークナイザーが知らなかった語彙のリスト
_labels_maker : LabelsMaker
正解ラベルのリスト作成用のオブジェクト
_remove_texts: List[str]
データセットに使わない文章のリスト
"""
_sep_token = '[SEP]'
_unk_token = '[UNK]'
def __init__(
self, model_name: str, tokens_max_len: int, labels_dic_path: str
):
"""
コンストラクタ
Parameters
----------
model_name : str
事前学習済み言語モデルの名前
トークナイザーの設定に使う
tokens_max_len : int
最大トークン数
labels_dic_path : str
特殊トークンのラベルとそのidの辞書が保存されているパス
"""
self._tokenizer: BertJapaneseTokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
self._tokens_max_len = tokens_max_len
self._unk_words: List[str] = []
self._labels_maker = LabelsMaker(labels_dic_path)
self._remove_texts: List[str] = []
def create(
self,
untokenized_data: Dict[str, str | List[Dict[str, str | List[int]]]]
) -> BatchEncoding | None:
"""
データの作成
Parameters
----------
untokenized_data : Dict[str, str | List[Dict[str, str | List[int]]]]
トークン化されていない学習データ
入力文と、抽出対象固有表現の情報を持つ辞書
Returns
-------
BatchEncoding | None
エンコード済みの学習データ
トークン化の区切り位置が良くなかった場合はNone
"""
text: str = untokenized_data['text']
unlabeled_data = self._tokenizer.encode_plus(
text,
max_length=self._tokens_max_len,
padding='max_length',
return_token_type_ids=False
)
input_ids: List[int] = unlabeled_data['input_ids']
tokens = self._decode(input_ids, text)
entity_infos: List[Dict[str, str | List[int]]] = untokenized_data['entities']
labels = self._labels_maker.create(
tokens, entity_infos, self._tokens_max_len, self._remove_texts
)
if labels:
unlabeled_data.update({'labels': labels})
data = unlabeled_data
return data
else:
return None
def _decode(self, input_ids: List[int], text: str) -> List[str]:
"""
デコード
input_idsをトークンのリストに変換する
Parameters
----------
input_ids : List[int]
入力文の各トークンのidのリスト
text : str
トークン化されていない入力文
Returns
-------
List[str]
トークンのリスト
"""
tokens = self._tokenizer.convert_ids_to_tokens(input_ids)
tokens = self._remove_extra_tokens_and_strs(tokens)
if self._unk_token in tokens:
tokens = Unknown.restore(
tokens, text, self._unk_token, self._unk_words
)
return tokens
def _remove_extra_tokens_and_strs(self, tokens: List[str]) -> List[str]:
"""
余分なトークンと文字列の削除
トークン化されていない文章の文字数に、tokensの文字数をそろえる
Parameters
----------
tokens : List[str]
トークンのリスト
Returns
-------
List[str]
余分なトークンと文字列が削除されたトークンのリスト
"""
sep_token_idx = tokens.index(self._sep_token)
tokens = tokens[1:sep_token_idx] # ※2
tokens = [token.replace('##', '') for token in tokens] # ※3
return tokens
def show_unk_words_and_remove_texts(self) -> None:
"""
トークナイザーが知らなかった語彙とデータセットに採用しない入力文の表示
"""
print('\nトークナイザーが知らない語彙')
unk_words_str = '、'.join(self._unk_words)
print(f' {unk_words_str}')
print('\n削除した文章')
for remove_text in self._remove_texts:
print(f' {remove_text}')
print(f'\n削除した文章数: {len(self._remove_texts)}')
class Unknown:
"""
[UNK]トークンに関する処理を担うヘルパークラス
正解ラベルのリストを作成するために、全てのトークンの元の文字数の情報が必要
Attributes
----------
_sep : str
分割用文字列
"""
_sep = '[sep]'
@staticmethod
def restore(
tokens: List[str], text: str, unk_token: str, unk_words: List[str]
) -> List[str]:
"""
トークンのリストの復元
Parameters
----------
tokens : List[str]
[UNK]トークンを含むトークンのリスト
text : str
tokensのトークン化前の文字列
unk_token : str
[UNK]トークン
unk_words : List[str]
トークナイザーが知らなかった語彙のリスト
Returns
-------
List[str]
[UNK]トークンが復元されたトークンのリスト
"""
include_unk_words = Unknown._restore_unk_words(tokens, text, unk_token)
for unk_word in include_unk_words:
unk_token_idx = tokens.index(unk_token)
tokens[unk_token_idx] = unk_word
if unk_word not in unk_words:
unk_words.append(unk_word)
return tokens
@staticmethod
def _restore_unk_words(
tokens: List[str], text: str, unk_token: str
) -> List[str]:
"""
[UNK]トークンの復元
Parameters
----------
tokens : List[str]
[UNK]トークンを含むトークンのリスト
text : str
tokensのトークン化前の文字列
unk_token : str
[UNK]トークン
Returns
-------
List[str]
textに含まれていて、トークナイザーが知らなかった語彙のリスト
"""
decoded_text = ''.join(tokens)
strs_without_unk = Unknown._str_to_lst(decoded_text, unk_token)
unk_words_str = text.replace('?', '?') # ※4
for string in strs_without_unk:
unk_words_str = unk_words_str.replace(string, Unknown._sep)
unk_words = Unknown._str_to_lst(unk_words_str, Unknown._sep)
return unk_words
@staticmethod
def _str_to_lst(string: str, split_str: str) -> List[str]:
"""
文字列をリストへ変換
Parameters
----------
string : str
文字列
split_str : str
区切り文字
Returns
-------
List[str]
リスト
"""
string = string.strip(split_str) # ※5
lst = string.split(split_str)
return lst
class LabelsMaker:
"""
正解ラベルのリスト作成用のクラス
Attributes
----------
_label2id_dic : Dict[str, int]
ラベルをidに変換する辞書
_other_token_id
抽出対象じゃないトークンのラベルのid
"""
def __init__(self, labels_dic_path: str):
"""
コンストラクタ
Parameters
----------
labels_dic_path : str
特殊トークンのラベルとそのidの辞書が保存されているパス
"""
id2label_dic = load_json_obj(labels_dic_path)
self._label2id_dic: Dict[str, int] = {
label: id for id, label in id2label_dic.items()
}
self._other_token_id = list(id2label_dic.keys())[0]
def create(
self,
tokens: List[str],
entity_infos: List[Dict[str, str | List[int]]],
tokens_max_len: int,
remove_texts: List[str]
) -> List[int] | None:
"""
ラベルのリストの作成
トークンの区切り位置が良くなかった場合は作成しない
Parameters
----------
tokens : List[str]
トークンのリスト
entity_infos : List[Dict[str, str | List[int]]]
tokensに含まれる固有表現の情報の辞書のリスト
tokens_max_len : int
最大トークン数
remove_texts : List[str]
データセットに使わない文章のリスト
Returns
-------
List[int] | None
正解ラベルのidのリスト
トークン化の区切り位置が良くなかった場合はNone
"""
token_start_idxs, token_end_idxs = Index.create_start_end_idxs(tokens)
ts_idxs = token_start_idxs
te_idxs = token_end_idxs
entity_spans: List[List[int]] = [
entity_info['span'] for entity_info in entity_infos
]
entity_start_idxs, entity_end_idxs = Index.create_start_end_idxs(entity_spans)
es_idxs = entity_start_idxs
ee_idxs = entity_end_idxs
if Index.is_idxs_match(ts_idxs, te_idxs, es_idxs, ee_idxs):
entity_types: List[str] = [
entity_info['type'] for entity_info in entity_infos
]
labels = self._create_labels(
ts_idxs, te_idxs, es_idxs, ee_idxs, entity_types, tokens_max_len
)
return labels
else:
remove_text = ' '.join(tokens)
remove_texts.append(remove_text)
return None
def _create_labels(
self,
ts_idxs: List[int],
te_idxs: List[int],
es_idxs: List[int],
ee_idxs: List[int],
entity_types: List[str],
tokens_max_len: int
) -> List[int]:
"""
ラベルのリストの作成
Parameters
----------
ts_idxs : List[int]
入力文に対する、全トークンの開始位置のインデックスのリスト
te_idxs : List[int]
入力文に対する、全トークンの終了位置のインデックスのリスト
es_idxs : List[int]
入力文に対する、全固有表現の開始位置のインデックスのリスト
ee_idxs : List[int]
入力文に対する、全固有表現の終了位置のインデックスのリスト
entity_types: List[str]
入力文に含まれる全固有表現の種類のリスト
tokens_max_len : int
最大トークン数
Returns
-------
List[int]
ラベルのidのリスト
"""
labels = [self._other_token_id] * tokens_max_len
for es_idx, ee_idx, entity_type in zip(es_idxs, ee_idxs, entity_types):
entity_begin_token_idx = ts_idxs.index(es_idx) + 1
entity_last_token_idx = te_idxs.index(ee_idx) + 1
begin_token_label_id = self._label2id_dic[f'B-{entity_type}']
labels[entity_begin_token_idx] = begin_token_label_id
if entity_begin_token_idx != entity_last_token_idx:
inside_token_label_id = self._label2id_dic[f'I-{entity_type}']
inside_token_idxs = slice(
entity_begin_token_idx + 1, entity_last_token_idx + 1
)
id_num = entity_last_token_idx - entity_begin_token_idx
labels[inside_token_idxs] = [inside_token_label_id] * id_num
return labels
class Index:
@staticmethod
def create_start_end_idxs(
tokens_or_entity_spans: List[str] | List[List[int]]
) -> Tuple[List[int], List[int]]:
"""
開始位置と終了位置のインデックスのリストの作成
Parameters
----------
tokens_or_entity_spans : List[str] | List[List[int]]
トークンのリストか、全固有表現の開始位置と終了位置のリスト
Returns
-------
Tuple[List[int], List[int]]
開始位置のインデックスのリストと、
終了位置のインデックスのリストのタプル
"""
if isinstance(tokens_or_entity_spans[0], str):
return Index._create_token_idxs(tokens_or_entity_spans)
else:
return Index._create_entity_idxs(tokens_or_entity_spans)
@staticmethod
def _create_token_idxs(tokens: List[str]) -> Tuple[List[int], List[int]]:
"""
全トークンの開始位置と終了位置のインデックスのリストの作成
Parameters
----------
tokens : List[str]
トークンのリスト
Returns
-------
Tuple[List[int], List[int]]
トークンの開始位置のインデックスのリストと、
終了位置のインデックスのリストのタプル
"""
start_idxs = []
end_idxs = []
current_idx = 0
for token in tokens:
start_idx = current_idx
end_idx = current_idx + len(token)
start_idxs.append(start_idx)
end_idxs.append(end_idx)
current_idx = end_idx
return start_idxs, end_idxs
@staticmethod
def _create_entity_idxs(
entity_spans: List[List[int]]
) -> Tuple[List[int], List[int]]:
"""
全固有表現の開始位置と終了位置のインデックスのリストの作成
Parameters
----------
entity_spans : List[List[int]]
全固有表現の開始位置と終了位置のインデックスのリスト
Returns
-------
Tuple[List[int], List[int]]
全固有表現の開始位置のインデックスのリストと、
終了位置のインデックスのリストのタプル
"""
start_idxs = []
end_idxs = []
for entity_span in entity_spans:
entity_start_idx = entity_span[0]
entity_end_idx = entity_span[1]
start_idxs.append(entity_start_idx)
end_idxs.append(entity_end_idx)
return start_idxs, end_idxs
@staticmethod
def is_idxs_match(
ts_idxs: List[int],
te_idxs: List[int],
es_idxs: List[int],
ee_idxs: List[int]
) -> bool:
"""
トークンと固有表現の開始位置と終了位置の確認
位置がそろっていれば、正解ラベルを付けることができる
Parameters
----------
ts_idxs : List[int]
入力文に対する、全トークンの開始位置のインデックスのリスト
te_idxs : List[int]
入力文に対する、全トークンの終了位置のインデックスのリスト
es_idxs : List[int]
入力文に対する、全固有表現の開始位置のインデックスのリスト
ee_idxs : List[int]
入力文に対する、全固有表現の終了位置のインデックスのリスト
Returns
-------
bool
そろっていればTrue、そろっていなければFalse
"""
token_idxs_lst = [ts_idxs, te_idxs]
entity_idxs_lst = [es_idxs, ee_idxs]
for token_idxs, entity_idxs in zip(token_idxs_lst, entity_idxs_lst):
if any(entity_idx not in token_idxs for entity_idx in entity_idxs):
return False
return True
実行
untokenized_dataset_path = '/content/drive/MyDrive/local_cuisine_search_app/data/processed_data/04_encoded_dataset_dataframe/encoded_dataset_dataframe_dependencies/01_untokenized_dataset_list/untokenized_dataset_list.json'
model_name = 'cl-tohoku/bert-base-japanese-v2'
labels_dic_path = '/content/drive/MyDrive/local_cuisine_search_app/data/processed_data/03_labels_dictionary/labels_dictionary.json'
file_name = 'encoded_dataset_dataframe'
save_dir = '/content/drive/MyDrive/local_cuisine_search_app/data/processed_data/04_encoded_dataset_dataframe'
dataset = DatasetMaker.create_and_save(
untokenized_dataset_path, model_name, labels_dic_path, file_name, save_dir
)
メモ
※1
深層学習のフレームワークは、2のべき乗のシーケンス長に最適化されていることが多いようなので、このような処理でデータのmax_length
を決めることにしました。
※2
抽出対象トークンに正解ラベルを付与する処理のために、”[CLS]”、”[SEP]”、”[PAD]”をtokens
から省いています。
※3
サブワードに付く##
も、省いておかないと、正解ラベルを付与するための処理で各トークンのspan
と、抽出対象の語彙のspan
にずれが生じてしまいます。
※4
今回の処理に使ったトークナイザーは、語彙に半角の”?”は持っていますが、全角の”?”は持っていません。
よって、encode_plus(...)
に渡されたtext
内の全角の”?”は、convert_ids_to_tokens(input_ids)
によって半角の”?”として出力されます。
なので、unk_words_str.replace(string, Unknown._sep)
で置き換えられるように、text
の全角の”?”は半角の”?”に変えておく必要があります。
※5
string
の先頭がunk_token
だと、string.split(split_str)
の最初の要素が空文字(''
)になってしまいます。
同様に、string
の末尾がUnknown._sep
だと、string.split(split_str)
の最後の要素も空文字(''
)になってしまいます。
参考資料
huggingfaceのライブラリで機械学習(この記事を参考に、"input_ids", "attention_mask", "labels"の3つをデータセットへ持たせることにしました。)