例えば
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-japanese-whole-word-masking')
のようなことをしていたとして、このBertForSequenceClassification
のアーキテクチャを見たいとする。
手段1:単純
単純に
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-japanese-whole-word-masking')
print(model)
とする。
手段2:ちょっと複雑だがpretrainedパラメータをダウンロードしなくて済む
from_pretrained()
を使って重みをダウンロードせずに、アーキテクチャを見たい場合。
from transformers import BertForSequenceClassification, BertConfig
config = BertConfig() # get config
model = BertForSequenceClassification(config) # get model instance
print(model)
このようにmodelを事前学習のものではなく初期化した状態で得るときは引数configが必要となる。configとはconfigurationの略で、モデルのハイパーパラメータの一覧である。
huggingfaceには何々Config
というクラスが沢山あり、そのなかでそれっぽい名前のもの(今回はBertConfig
)のインスタンスを渡せばいいと思う。今回私は Bert まで打って入力補完でゴリ押した:
おまけ:入出力テンソルの形が見たい
散々使ってきたprint(model)
だと各層の入出力テンソルのshapeが見れない。それを見たいときは https://github.com/TylerYep/torchinfo を使うといい。