Using AiDaeng-Thai-RoPE with TokenClassificationPipeline

Posted at 2025-09-28

JonusNattapong has released AiDaeng-Thai-RoPE, a Thai text-generation model. With 62.6 million parameters (hidden_size: 384, num_hidden_layers: 6, num_attention_heads: 6) it is on the small side, but max_position_embeddings is 2048, so it looks reasonably usable. However, it does not support the transformers pipeline, so building a TokenClassificationPipeline on top of it takes some work. Let's try it on Google Colaboratory.

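# Install the required packages and clone the model repository
!pip install transformers triton accelerate datasets evaluate seqeval
!test -d AiDaeng-Thai-RoPE || git clone --depth=1 https://huggingface.co/JonusNattapong/AiDaeng-Thai-RoPE
# Write hf_model_t.py: a base model without the LM head (ThaiTransformerNoHead) plus a token-classification wrapper around it
with open("AiDaeng-Thai-RoPE/hf_model_t.py","w",encoding="utf-8") as w: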
  print('''from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_layers import GenericForTokenClassification
from .hf_model import ThaiTransformerPreTrainedModel, ThaiTransformerModel, ThaiTransformerConfig
class ThaiTransformerNoHead(ThaiTransformerModel):
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, token_type_ids=None, labels=None, reasoning_effort=0.0, **kwargs):
        tok_emb = inputs_embeds if inputs_embeds is not None else self.token_embedding_table(input_ids)
        x = self.rotary_pos_emb(tok_emb)
        x = self.dropout(x)
        if reasoning_effort > 0.0:
            import torch
            meta_adapt = self.meta_adapter(x)
            gate = torch.sigmoid(self.reasoning_gate(x))
            x = x + reasoning_effort * gate * meta_adapt
        for block in self.blocks:
            x = block(x, self.rotary_pos_emb, reasoning_effort=reasoning_effort)
        x = self.ln_f(x)
        return BaseModelOutput(last_hidden_state=x)
class ThaiTransformerForTokenClassification(GenericForTokenClassification, ThaiTransformerPreTrainedModel):
    def __init__(self, config):
        import torch.nn as nn
        ThaiTransformerPreTrainedModel.__init__(self, config)
        self.model = ThaiTransformerNoHead(config)
        self.config = config
        self.num_labels = config.num_labels
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()''',file=w)
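# Register the new classes in config.json via auto_map, and fill in hidden_size, num_hidden_layers, num_attention_heads and tokenizer_class
import json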
with open("AiDaeng-Thai-RoPE/config.json","r",encoding="utf-8") as r:
  d=json.load(r)
d["auto_map"]={
  "AutoConfig":"hf_model.ThaiTransformerConfig",
  "AutoModel":"hf_model_t.ThaiTransformerNoHead",
  "AutoModelForCausalLM":"hf_model.ThaiTransformerModel",
  "AutoModelForTokenClassification":"hf_model_t.ThaiTransformerForTokenClassification"
}
d["hidden_size"]=384
d["num_hidden_layers"]=6
d["num_attention_heads"]=6
d["tokenizer_class"]="PreTrainedTokenizerFast"
with open("AiDaeng-Thai-RoPE/config.json","w",encoding="utf-8") as w:
  json.dump(d,w,indent=2)
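# Clone the transformers source at the tag of the installed version (IPython expands {s} into the awk program) to get a matching run_ner.py
s='$1=="transformers"{printf("-b v%s",$2)}'
!test -d transformers || git clone `pip list | awk '{s}'` https://github.com/huggingface/transformers
# Fetch the UD_Thai-TUD treebank (dev branch)
!test -d UD_Thai-TUD || git clone -b dev https://github.com/UniversalDependencies/UD_Thai-TUD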
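# Convert CoNLL-U into JSON Lines for run_ner.py: one sentence per line, wrapped in [CLS]/[SEP] tagged SYM,
# with a leading space on every token that follows a space (i.e. the previous token lacks SpaceAfter=No)
def makejson(conllu_file,json_file):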
  with open(conllu_file,"r",encoding="utf-8") as r, open(json_file,"w",encoding="utf-8") as w:
    d,f={"tokens":["[CLS]"],"tags":["SYM"]},False
    for s in r:
      if s.strip()=="":
        if len(d["tokens"])>1:
          d["tokens"].append("[SEP]")
          d["tags"].append("SYM")
          print(json.dumps(d),file=w)
        d,f={"tokens":["[CLS]"],"tags":["SYM"]},False
      else:
        t=s.split("\t")
        if len(t)==10 and t[0].isdecimal():
          d["tokens"].append(" "+t[1] if f else t[1])
          d["tags"].append(t[3].upper() if t[3]!="CONJ" else "CCONJ")
          f=t[9].find("SpaceAfter=No")<0
makejson("UD_Thai-TUD/th_tud-ud-train.conllu","train.json")
makejson("UD_Thai-TUD/th_tud-ud-dev.conllu","dev.json")
makejson("UD_Thai-TUD/th_tud-ud-test.conllu","test.json")
# Fine-tune for UPOS tagging with run_ner.py (Weights & Biases logging disabled)
!env WANDB_DISABLED=true python transformers/examples/pytorch/token-classification/run_ner.py --task_name pos --model_name_or_path ./AiDaeng-Thai-RoPE --trust_remote_code --train_file train.json --validation_file dev.json --test_file test.json --output_dir ./AiDaeng-Thai-RoPE-upos --overwrite_output_dir --do_train --do_eval --do_predict
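The per-sentence logic of makejson can be isolated and checked on a toy input. The sketch below is a minimal re-statement of that logic, not the author's code: the helper name conllu_sentence_to_record is hypothetical, and the three-token sample sentence is made up for illustration (it is not taken from UD_Thai-TUD). It shows the [CLS]/[SEP] wrapping, the leading space driven by SpaceAfter=No, and the mapping of CONJ (the UD v1 name of CCONJ) to CCONJ.

```python
import json


def conllu_sentence_to_record(lines):
    """Convert the lines of one CoNLL-U sentence into a {"tokens", "tags"} record.

    Hypothetical helper re-stating the logic of makejson: [CLS]/[SEP] tagged SYM,
    a leading space on tokens that follow a space, and CONJ mapped to CCONJ.
    """
    tokens, tags = ["[CLS]"], ["SYM"]
    space_before = False  # True when the previous token had no SpaceAfter=No
    for line in lines:
        cols = line.rstrip("\n").split("\t")
        # Skip comments, multiword-token ranges ("1-2") and empty nodes ("1.1")
        if len(cols) != 10 or not cols[0].isdecimal():
            continue
        form, upos, misc = cols[1], cols[3], cols[9]
        tokens.append(" " + form if space_before else form)
        tags.append("CCONJ" if upos == "CONJ" else upos.upper())
        space_before = "SpaceAfter=No" not in misc
    tokens.append("[SEP]")
    tags.append("SYM")
    return {"tokens": tokens, "tags": tags}


# Made-up sample sentence (hypothetical, for illustration only)
sample = """# sent_id = hypothetical-1
1\tแม่\t_\tNOUN\t_\t_\t2\tnsubj\t_\tSpaceAfter=No
2\tไป\t_\tVERB\t_\t_\t0\troot\t_\t_
3\tและ\t_\tCONJ\t_\t_\t2\tcc\t_\tSpaceAfter=No
"""
record = conllu_sentence_to_record(sample.splitlines())
# ensure_ascii=False only for readability; makejson writes ASCII-escaped JSON
print(json.dumps(record, ensure_ascii=False))
```

Only the third token gets a leading space, because the second token is the only one without SpaceAfter=No.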

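I implemented ThaiTransformerForTokenClassification on top of AiDaeng-Thai-RoPE and, with the help of transformers' run_ner.py, tried UPOS part-of-speech tagging on UD_Thai-TUD. On Google Colaboratory (GPU runtime), the following metrics were output after about five minutes, and AiDaeng-Thai-RoPE-upos was created.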

***** train metrics *****
  epoch                    =        3.0
  total_flos               =   336687GF
  train_loss               =     1.5463
  train_runtime            = 0:01:04.69
  train_samples            =       2902
  train_samples_per_second =    134.579
  train_steps_per_second   =     16.834

***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.5218
  eval_f1                 =     0.3684
  eval_loss               =     1.4764
  eval_precision          =     0.4373
  eval_recall             =     0.3182
  eval_runtime            = 0:00:02.89
  eval_samples            =        362
  eval_samples_per_second =    125.023
  eval_steps_per_second   =     15.887

***** predict metrics *****
  predict_accuracy           =      0.535
  predict_f1                 =     0.3825
  predict_loss               =     1.4313
  predict_precision          =     0.4553
  predict_recall             =     0.3297
  predict_runtime            = 0:00:00.92
  predict_samples_per_second =    392.324
  predict_steps_per_second   =     49.716
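The precision, recall and F1 above are computed by run_ner.py with the seqeval metric (hence seqeval in the pip install line). As a quick consistency check on the logs, F1 is the harmonic mean 2PR/(P+R) of precision P and recall R; recomputing it from the rounded precision and recall reproduces both reported F1 values to four decimals. The helper name f1_score is hypothetical.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Values copied from the eval and predict metrics above
reported = {
    "eval": {"precision": 0.4373, "recall": 0.3182, "f1": 0.3684},
    "predict": {"precision": 0.4553, "recall": 0.3297, "f1": 0.3825},
}
for split, m in reported.items():
    recomputed = f1_score(m["precision"], m["recall"])
    print(f"{split}: recomputed F1 = {recomputed:.4f}, reported F1 = {m['f1']:.4f}")
```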

For both eval and predict the F1 score is below 0.4, so the model is not usable in practice. Still, let's run it a little.

from transformers import pipeline
nlp=pipeline("token-classification",model="AiDaeng-Thai-RoPE-upos",aggregation_strategy="simple",trust_remote_code=True)
print(nlp("แม่อย่าเก็บไว้คนเดียว"))
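With aggregation_strategy="simple", TokenClassificationPipeline merges consecutive sub-word tokens that received the same tag into one entity_group and averages their scores. Because the UPOS tags here carry no B-/I- prefixes, any run of adjacent tokens with the same tag is merged. Below is a simplified stdlib sketch of that grouping, assuming per-token predictions with character offsets; simple_aggregate and the scores are hypothetical, and unlike this sketch the real pipeline builds word by detokenizing rather than slicing the input text.

```python
from itertools import groupby


def simple_aggregate(text, token_preds):
    """Group consecutive token predictions that share a label.

    Simplified illustration of aggregation_strategy="simple" for tags without
    B-/I- prefixes. Each item in token_preds is a dict with "label", "score",
    "start" and "end" (character offsets into text).
    """
    groups = []
    for label, run in groupby(token_preds, key=lambda t: t["label"]):
        run = list(run)
        start, end = run[0]["start"], run[-1]["end"]
        groups.append({
            "entity_group": label,
            # Average of the per-token scores in the run
            "score": sum(t["score"] for t in run) / len(run),
            # Simplification: slice the input text instead of detokenizing
            "word": text[start:end],
            "start": start,
            "end": end,
        })
    return groups


# Hypothetical per-token predictions (made-up labels and scores)
text = "คนเดียว"
token_preds = [
    {"label": "NUM", "score": 0.75, "start": 0, "end": 2},
    {"label": "NUM", "score": 0.25, "start": 2, "end": 4},
    {"label": "NOUN", "score": 0.9, "start": 4, "end": 7},
]
for group in simple_aggregate(text, token_preds):
    print(group)
```

The two adjacent NUM tokens collapse into a single group spanning offsets 0 to 4, which is why a span in the output can cover more than one sub-word token.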

Running the freshly built AiDaeng-Thai-RoPE-upos on "แม่อย่าเก็บไว้คนเดียว", I (Koichi Yasuoka) got the following result in my environment.

[{'entity_group': 'CCONJ', 'score': np.float32(0.3048764), 'word': 'แ', 'start': 0, 'end': 1}, {'entity_group': 'NUM', 'score': np.float32(0.19828455), 'word': 'ม่', 'start': 1, 'end': 3}, {'entity_group': 'SCONJ', 'score': np.float32(0.15385833), 'word': 'อย', 'start': 3, 'end': 5}, {'entity_group': 'NUM', 'score': np.float32(0.2596558), 'word': '่า', 'start': 5, 'end': 7}, {'entity_group': 'NOUN', 'score': np.float32(0.49739757), 'word': 'เก็', 'start': 7, 'end': 10}, {'entity_group': 'NUM', 'score': np.float32(0.59695584), 'word': 'บ', 'start': 10, 'end': 11}, {'entity_group': 'NOUN', 'score': np.float32(0.5277243), 'word': 'ไว้', 'start': 11, 'end': 14}, {'entity_group': 'NUM', 'score': np.float32(0.5918634), 'word': 'คนเด', 'start': 14, 'end': 18}, {'entity_group': 'NOUN', 'score': np.float32(0.85041416), 'word': 'ียว', 'start': 18, 'end': 21}]
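The start/end offsets in this output can be checked against the input string. The sketch below is my own heuristic, not part of the pipeline: it confirms that the spans tile the sentence exactly, then flags spans that begin with a combining mark (Unicode category Mn), i.e. a tone mark or an above-vowel cut off from the consonant it belongs to. The heuristic misses other bad cuts, for example the leading vowel แ split off from ม่.

```python
import unicodedata

sentence = "แม่อย่าเก็บไว้คนเดียว"
# (word, start, end) triples copied from the pipeline output above
spans = [("แ", 0, 1), ("ม่", 1, 3), ("อย", 3, 5), ("่า", 5, 7), ("เก็", 7, 10),
         ("บ", 10, 11), ("ไว้", 11, 14), ("คนเด", 14, 18), ("ียว", 18, 21)]

# Each word matches its offsets, and the spans have no gaps or overlaps
assert all(sentence[s:e] == w for w, s, e in spans)
assert [s for _, s, _ in spans[1:]] == [e for _, _, e in spans[:-1]]

# Heuristic: a span that starts with a combining mark (category Mn) was cut
# in the middle of a character cluster
broken = [w for w, _, _ in spans if unicodedata.category(w[0]) == "Mn"]
print(broken)
```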

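Unfortunately, it does not read the sentence at all: the spans cut through words and the tags are meaningless. Either my implementation is wrong, or this model is simply not suited to TokenClassificationPipeline. Hmm, too bad.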
