I was supposed to be fuming over RWKV's tokenizer, but I got carried away and ended up writing RwkvForTokenClassification instead.
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.rwkv.modeling_rwkv import RwkvModel, RwkvPreTrainedModel


class RwkvForTokenClassification(RwkvPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.rwkv = RwkvModel(config)
        # Pick a dropout rate for the classification head, falling back to 0.1
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.rwkv.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.rwkv.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        rwkv_outputs = self.rwkv(
            input_ids,
            inputs_embeds=inputs_embeds,
            state=state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # rwkv_outputs[0] is the last hidden state: one vector per input token
        sequence_output = rwkv_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Flatten batch and sequence dimensions for the token-level cross-entropy
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,) + rwkv_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=rwkv_outputs.hidden_states,
            attentions=rwkv_outputs.attentions,
        )
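As a quick smoke test, something like the following should run: load an existing RWKV checkpoint, attach the randomly initialized classification head, and push one sentence through. The checkpoint name RWKV/rwkv-4-169m-pile and the three-label tag set are placeholders I chose purely for illustration, not anything taken from the Pull Request.

import torch
from transformers import AutoConfig, AutoTokenizer

model_name = "RWKV/rwkv-4-169m-pile"   # placeholder checkpoint
labels = ["B-PER", "I-PER", "O"]       # toy label set, for illustration only

config = AutoConfig.from_pretrained(model_name, num_labels=len(labels))
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = RwkvForTokenClassification.from_pretrained(model_name, config=config)

inputs = tokenizer("My name is Wolfgang and I live in Berlin", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One label per token; the head is untrained, so the predictions are meaningless
# until the model is fine-tuned on an actual NER or POS dataset.
predicted = outputs.logits.argmax(dim=-1)[0].tolist()
print([labels[i] for i in predicted])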
The basic structure follows a Pull Request from half a year ago. That said, I honestly don't have much confidence that RWKV is well suited to named-entity recognition or part-of-speech tagging. Speaking for myself (Yasuoka Koichi), I am also far from sure that the RwkvLinearAttention part can be fine-tuned properly. Well, I suppose I'll mull it over for a while.
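For what it's worth, one way to sidestep RwkvLinearAttention entirely would be to freeze the backbone and train only the classification head; whether that leaves enough accuracy for NER or POS tagging is exactly what I'm unsure about. A rough sketch, continuing from the snippet above, with the optimizer, learning rate, and dummy labels all being placeholders:

from torch.optim import AdamW

model.train()
# Freeze the whole RWKV backbone, including the RwkvLinearAttention blocks,
# so that gradients only reach the randomly initialized classification head.
for param in model.rwkv.parameters():
    param.requires_grad = False

optimizer = AdamW(model.score.parameters(), lr=1e-4)

# One illustrative training step with all-zero dummy labels (shape only).
batch = tokenizer("My name is Wolfgang and I live in Berlin", return_tensors="pt")
dummy_labels = torch.zeros_like(batch["input_ids"])
outputs = model(input_ids=batch["input_ids"], labels=dummy_labels)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()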