4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

huggingfaceの翻訳タスクを触ってみた

Last updated at Posted at 2024-03-09

ようやくtransformersの翻訳タスクでlossが下がった。
思った結果が出るかはわかりませんが、今日も備忘録を書きます。

さて、huggingfaceの翻訳タスクですが、T5を利用した英語toフランス語のチュートリアルは載っているのですが、英日翻訳の手順が載っていないので悩みながら実装しました。

画像はlossが下がったので多分動いてそうな気がする。

スクリーンショット 2024-03-09 22.41.11.png

翻訳フレームワークの選定

opennmp -> fairseq -> transformers と流れつきました。

  • opennmpはハイパーパラメータ周りを触ると動かない。バグか設定ミスか判断が難しい。バージョンアップは続いているのに長いこと直っていない状態?
  • fairseqは2022年から更新がなく「バグが多いのにどうなってるんや。。」と思い諦めました。(後継のfairseq2がリポジトリにありますね。)

そんな感じでtransformersでT5の翻訳を試しています。

コーパス準備

この辺の処理を使ってコーパスを作成しました。

上のコードで幾らか綺麗になったコーパスから日本語と英語の行が対になったテキストファイルを準備しました。

$ head text.ja 
ヨーロッパ で は インター シティ が 運行 さ れ て いる 。
好き な こと ・ もの : ゲーム は 純愛 系 が いちばん 好き か も 。
弱気 が 美人 を 得 た ためし が ない 。

$ head text.en
european cities are leading this transition .
i must confess i love the theater best , though .
faint heart never won fair lady .
$ wc -l text.*          
  720000 text.en
  720000 text.ja
 1440000 total

tokenizerを学習

全てのコーパスでサブワードを作成します。
sentencepieceでトレーニングした後、T5用のモデルで保存します。
サブワードはunigramが主流みたいです。

byte_fallbackの例を探し見つけたこちらのコードを参考に作成しました。

$ cat text.ja text.en > mixed.txt
create_tokenize.py
import sentencepiece as spm
from transformers import T5Tokenizer

fname = "mixed.txt"
spm.SentencePieceTrainer.train(
    input=fname,
    model_type="unigram",
    model_prefix="spm",
    add_dummy_prefix=False,
    byte_fallback=True,
    vocab_size=32000,
    character_coverage=0.9995,
    unk_piece="[UNK]",
    pad_piece="[PAD]",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    input_sentence_size=12000000
)

tokenizer = T5Tokenizer(
    vocab_file="spm.model",
    unk_token = '[UNK]',
    bos_token = '<s>',
    eos_token = '</s>',
    pad_token = '[PAD]',
    extra_ids=0,
    model_max_length=32000,
)
tokenizer.save_pretrained("t5_models")

出力ファイル

spm.model, spm.vocab, t5_models/*

データーセットの用意

transformersの翻訳チュートリアルのデータと同じ構成になるように用意しました。

create_dataset.py
from tqdm import tqdm
from datasets import Dataset, DatasetDict

def load(name):
    with open(name) as f:
        lines = f.read().split("\n")
    return lines

en_list = load("text.en")
ja_list = load("text.ja")

buff = []

for idx in tqdm(range(len(en_list))):
    try:
        en = en_list[idx]
        ja = ja_list[idx]
        ja = ja.replace(" ", "")
        d = { "id": idx, "translation": { "en": en, "ja": ja} }
        buff.append(d)
    except Exception as e:
        print(e)
        continue

data = Dataset.from_list(buff)
train_test = data.train_test_split(test_size=0.1)
test_valid = train_test['test'].train_test_split(test_size=0.5)
ds = DatasetDict({
    'train': train_test['train'],
    'test': test_valid['test'],
    'validation': test_valid['train']})
ds.save_to_disk("output_en_ja")
print(ds)
print(ds['train'][0])

出力結果

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 648000
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 36001
    })
    validation: Dataset({
        features: ['id', 'translation'],
        num_rows: 36000
    })
})
{'id': 241687, 'translation': {'en': 'the two lived happily ever after .', 'ja': '二人は末永く幸せに暮らした。'}}

トレーニング実行前のフォルダ構成

$ tree
.
├── create_dataset.py
├── create_tokenize.py
├── mixed.txt
├── output_en_ja/*
├── spm.model
├── spm.vocab
├── t5_models
│   ├── special_tokens_map.json
│   ├── spiece.model
│   └── tokenizer_config.json
├── text.en
└── text.ja

トレーニング

huggingfaceの翻訳チュートリアルから4点程修正を加えました。

  • 読み込み方法をload_from_diskに変更。
  • model.configを見るとtask_specific_paramstranslation_en_to_jaがないので追加。
  • 自作のtokenizerを適用
    • decoder_start_token_idで初期化
    • resize_token_embeddingsで自作のvocab_sizeに拡張
  • 濁点対応にkeep_accents=Trueを設定
train.py
import numpy as np
from datasets import load_from_disk
from transformers import (
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    AutoTokenizer, T5Config, T5Tokenizer, T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
)
import evaluate

checkpoint = "google-t5/t5-base"
tokenizer = AutoTokenizer.from_pretrained("./t5_models", keep_accents=True)
tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space=False

def get_model(checkpoint: str, device: str, tokenizer: T5Tokenizer) -> T5ForConditionalGeneration:
    config = T5Config(decoder_start_token_id=tokenizer.pad_token_id)
    model = T5ForConditionalGeneration(config).from_pretrained(checkpoint)
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(device)
    return model

model = get_model(checkpoint=checkpoint, device="cuda", tokenizer=tokenizer)

translation_dict_param = {
    "early_stopping": True,
    "max_length": 300,
    "num_beams": 4,
    "prefix": "translate English to Japanese: "
}
translation_dict = { "translation_en_to_ja": translation_dict_param }
model.config.task_specific_params = translation_dict

source_lang = "en"
target_lang = "ja"
prefix = "translate English to Japanese: "

def preprocess_function(examples):
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs


books = load_from_disk("output_en_ja")
tokenized_books = books.map(preprocess_function, batched=True)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

training_args = Seq2SeqTrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
    report_to=['tensorboard'],
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_books["train"],
    eval_dataset=tokenized_books["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()

推論のお試し

そういえばhuggingfaceの翻訳チュートリアルの推論コードは動かないんですよね。

test.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
import warnings
warnings.simplefilter('ignore')

prefix = "translate English to Japanese: "
text = prefix + "Cat is cute."

model_name = get_last_checkpoint("./checkpoints")
print(model_name)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, keep_accents=True)
inputs = tokenizer(text, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

おわりに

初学者のメモなのでぬるい目で見守ってください。

その他疑問点

transformersで英日翻訳試してるブログを見かけないのだが探し方が悪いのだろうか?
それとも何か罠があって触る人が少ない?
確かに英日翻訳のチュートリアルはないので初心者には辛い…
最近はLLMが主流なので要約のタスク等に人が流れているのでしょうね。

3/9 - 3/15

lossは大分下がってきた。
高速でメモリ潤沢なGPUが欲しくなる。

スクリーンショット 2024-03-15 22.04.07.png

bleuが中々上がらない。
このままゆっくりしか上がらなければ使い物にはならないなぁ。

スクリーンショット 2024-03-15 22.03.29.png

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?