概要
昨年の年末に帰省する関係で自宅のデスクトップのGPUが遊んでいたので言語モデルの事前学習をしました。昨今多様な大規模言語モデルが公開されており、今更個人のGPUで事前学習できる言語モデルなんてたかが知れています。そこで、一般的な言語モデルではなくひらがなに特化した言語モデルにすることを思い立ちました。
ひらがなを出力する言語モデルが従来の日本語言語モデルの性能を上回ることは考えにくいので、他のモデルでは解きにくいタスクとして回文を生成させることにしました。通常の日本語言語モデルの場合1つのトークンが何文字に対応するか決まっておらず、漢字には複数の読み方が考えられるため回文になるように制約をつけて生成するのは困難です。一方で、語彙が全てひらがな1文字の言語モデルを作ってしまえば回文の条件を満たすように制約をつけて生成させることができます。
そんなわけで、全ての語彙がひらがな1文字の言語モデルを事前学習することにしました。
回文について
回文というのは「たけやぶやけた」や「ねつきいいきつね」など、前から読んでも後ろから読んでも同じ読み方の文のことです。
「LLMならある程度出力できるのでは?」と思う方もいるかもしれませんが、日本語の読みを理解した上で左右どちらから読んでも意味を満たすようなテキストを出力する必要があるため、未だにチャレンジングなタスクです。(実生活では何の役にも立ちませんが)
参考:https://www.softbank.jp/biz/blog/cloud-technology/articles/202412/many-shot-icl-palindrome/
具体的な作業
言語モデルの事前学習
データ
事前学習にはCC100の日本語データを利用しました。wikipediaについては中身の英語の割合が高く使いづらそうだったので利用しませんでした。
漢字からひらがなへの変換についてはsudachidictを使って変換しました。
https://github.com/WorksApplications/SudachiDict
tokenizer
回文を作る関係で1トークンが1つのひらがなと結びつくように学習する必要があります。そのため、char-baseのtokenizerを自前で作成しました。そのため、今回作る言語モデルの語彙数は[BOS]などの特殊トークンも含めて94語とめちゃくちゃ語彙が少ないモデルになっています。
モデル
せっかくなので最新モデルを事前学習しようと思ったのですが、3090x1枚で現実的に事前学習を行うのは難しかったので、gpt2-xsmallを事前学習しました。基本的なコードはrinna社が公開していた下記のリポジトリのコードを利用させていただきました。
https://github.com/rinnakk/japanese-pretrained-models
rinna社が公開していたハイパラはあくまで通常の日本語モデルを学習するためのパラメータなので、ひらがな言語モデルを学習するにあたって調整する必要があるのかなと思っていたのですが特に問題なくlossが落ちていったので安心しました。
ちなみに学習は8.5日ほどかかりました。さらに、今回は回文を作る必要があるので順方向のモデルと逆方向のモデルを両方つくったので、2倍時間がかかりました。
作った言語モデルはこちらです。
-
順方向のモデル
https://huggingface.co/hukuda222/hiragana-gpt2-xsmall -
逆方向のモデル
https://huggingface.co/hukuda222/hiragana-reverse-gpt2-xsmall
回文を生成してみる
愚直に書くとこんな感じのコードで回文が生成できます。順方向モデルと逆方向モデルの生成確率の和を取ることで前から読んでも後ろから読んでもそこそこ意味が通りそうな文を出力できます。
from collections import defaultdict
import json
import torch
from transformers import AutoModelForCausalLM
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
DEVICE = torch.device("cuda")
class CharacterTokenizer(PreTrainedTokenizer):
def __init__(
self, characters: Sequence[str] = "", model_max_length: int = 1024, **kwargs
):
self.characters = characters
self.model_max_length = model_max_length
cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
eos_token = AddedToken("[EOS]", lstrip=False, rstrip=False)
mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)
pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)
self._vocab_str_to_int = {
"[CLS]": 0,
"[SEP]": 1,
"[BOS]": 2,
"[MASK]": 3,
"[PAD]": 4,
"[EOS]": 5,
"[UNK]": 6,
**{ch: i + 7 for i, ch in enumerate(characters)},
}
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
add_prefix_space=False,
model_max_length=model_max_length,
**kwargs,
)
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
def get_vocab(self):
return self._vocab_str_to_int
def _tokenize(self, text: str) -> List[str]:
return list(text)
def _convert_token_to_id(self, token: str) -> int:
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
def _convert_id_to_token(self, index: int) -> str:
return self._vocab_int_to_str[index]
def convert_tokens_to_string(self, tokens):
return "".join(tokens)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
cls = [self.cls_token_id]
result = cls + token_ids_0 + sep
if token_ids_1 is not None:
result += token_ids_1 + sep
return result
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
result = [1] + ([0] * len(token_ids_0)) + [1]
if token_ids_1 is not None:
result += ([0] * len(token_ids_1)) + [1]
return result
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
cls = [self.cls_token_id]
result = len(cls + token_ids_0 + sep) * [0]
if token_ids_1 is not None:
result += len(token_ids_1 + sep) * [1]
return result
def get_config(self) -> Dict:
return {
"char_ords": [ord(ch) for ch in self.characters],
"model_max_length": self.model_max_length,
}
@classmethod
def from_config(cls, config: Dict) -> "HiraganaTokenizer":
cfg = {}
cfg["characters"] = [chr(i) for i in config["char_ords"]]
cfg["model_max_length"] = config["model_max_length"]
return cls(**cfg)
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
cfg_file = os.path.join(save_directory, "tokenizer_config.json")
cfg = self.get_config()
with open(cfg_file, "w") as f:
json.dump(cfg, f, indent=4)
@classmethod
def _from_pretrained(
cls,
resolved_vocab_files,
pretrained_model_name_or_path,
init_configuration,
*init_inputs,
token=None,
cache_dir=None,
local_files_only=False,
_commit_hash=None,
_is_local=False,
trust_remote_code=False,
**kwargs,
):
config_file = resolved_vocab_files["tokenizer_config_file"]
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
return cls.from_config(config)
tokenizer = CharacterTokenizer.from_pretrained("hukuda222/hiragana-gpt2-xsmall")
model_forward = AutoModelForCausalLM.from_pretrained(
"hukuda222/hiragana-gpt2-xsmall"
).to(DEVICE)
model_backward = AutoModelForCausalLM.from_pretrained(
"hukuda222/hiragana-reverse-gpt2-xsmall"
).to(DEVICE)
input_ids_forward = torch.tensor([[2]]).to(DEVICE)
input_ids_backward = torch.tensor([[5]]).to(DEVICE)
generated_token_str = ""
for step in range(20):
with torch.no_grad():
outputs_forward = model_forward(input_ids=input_ids_forward)
outputs_backward = model_backward(input_ids=input_ids_backward)
next_token_logits = (
outputs_forward.logits[:, -1, :] + outputs_backward.logits[:, -1, :]
)
next_token_id = torch.argmax(next_token_logits, dim=-1)
input_ids_forward = torch.cat(
[input_ids_forward, next_token_id.unsqueeze(-1)], dim=-1
)
input_ids_backward = torch.cat(
[input_ids_backward, next_token_id.unsqueeze(-1)], dim=-1
)
generated_token_str += tokenizer.decode(next_token_id)
print(generated_token_str[:-1] + generated_token_str[::-1])
一応PreTrainedTokenizerを継承してCharacterTokenizerを自前で定義していますが、tokenizerとは名ばかりで1文字ごとに分割しているだけです。
上記のコードを私の環境で実行すると下記のようなテキストが生成できます。
すまいてしんしんあんていかいしいおんせんせんおいしいかいてんあんしんしています
うっすら日本語として読めるものの、回文というにはクオリティが低いテキストになってしまいました。
webサービスとして公開
せっかく作ったのでmicrosoft社が公開しているONNX Runtimeを使って、ユーザーの入力に応じて回文を生成するようなサービスを公開しました。
https://hukuda222.github.io/palindrome/
ユーザーのブラウザ上で言語モデルによる推論を行うため、端末によっては推論に時間がかかったりフリーズする可能性があります。現状では筆者のデスクトップの他に、2019年のmac book proとpixel 6でデフォルトのパラメタで特に問題なく動くことを確認しています。
実装した追加の工夫
愚直な実装だと微妙な回文を生成するだけになってしまったので、追加で下記の工夫を行いました
- n-gram blocking
- beam search
- 上限まで生成するのではなく適切な部分で打ち切る
- 簡易的な実装として、上限まで出力してみてperplexityが一番低い出力を返すような実装にしました
- 確率的decoding
- temperatureパラメータが0.001以上の時はsample beam searchを行い、0.001未満の場合は普通のbeam searchを行うようにしています
- 確率値の和ではなくminを使う
- 片方のモデルの出力に引っ張られて非文になるケースが多かったので、どちらから見ても変な出力になりにくいようにminにしました
回文としてギリギリ成立しそうな出力集
上述の工夫を行った上でなんとか出力できた回文を紹介します。
- うましてしまう
「産ましてしまう」 - たいていのおともとおのいていた
「大抵の音も遠のいていた」 - すまいもおといあわせわあいとおもいます
「住まいもお問い合わせ『わあい』と思います」 - わるあがきいんふるふんいきがあるわ
「悪あがきインフル雰囲気があるわ」 - あなたいがいだんせいならしらないせんだいがいたなあ
「あなた以外男性なら知らない先代がいたなぁ」
まとめ
長い回文を作るのは全然上手くいきませんでした。gpt2-xsmallのmax lengthは1024トークンなので特殊トークンを除いて理論上は2046語の回文を生成できるわけですが、意味が通りそうなのは20語程度が限界でした。回文は特殊な言い回しが必要になることが多々あるので、一般的な日本語データセットだけでなく回文で学習するのが重要だなと強く感じました。回文データセット誰か作ってください。
また、個人の所有するGPUで事前学習するとこの規模のモデルが限界なので、実社会ではあまり役立たないようなことをしても許される組織でこういう取り組みができたら良いなぁと思いました。
今後の展望
性能が上がりそうなアイデアを列挙しておきます。
- 一般的な回文を学習させる
- これが一番効果あると思います
- アルファベットの固有表現をひらがなに変換できていないので、英語の読みに強い辞書を導入する
- 現在の実装では、アルファベットを含む文字列は学習データから除いています
- 助詞を確率的に削除したデータで学習する
- 回文は助詞を削ることで無理やり成立させるケースが度々あるためです
- 現在は「や」と「ゃ」、「た」と「だ」を区別しているが、それらを区別しないようにする
- 解釈の難しさが上がる問題とのトレードオフになりそうです
- 前後の文脈を考慮した上で生成するタスクを事前学習に追加する
- 既存の学習済み言語モデルをベースとして、ひらがなのみで出力するようにfinetuneする
- GPUを借りて実装が公開されている大規模モデルを学習する