LoginSignup
0
0

More than 1 year has passed since last update.

huggingface で用意されているモデルの構造を見たい

Last updated at Posted at 2022-06-27

例えば

from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-japanese-whole-word-masking')

のようなことをしていたとして、このBertForSequenceClassificationのアーキテクチャを見たいとする。

手段1:単純

単純に

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-japanese-whole-word-masking')
print(model)

とする。

手段2:ちょっと複雑だがpretrainedパラメータをダウンロードしなくて済む

from_pretrained()を使って重みをダウンロードせずに、アーキテクチャを見たい場合。

from transformers import BertForSequenceClassification, BertConfig

config = BertConfig() # get config 
model = BertForSequenceClassification(config) # get model instance

print(model)

このようにmodelを事前学習のものではなく初期化した状態で得るときは引数configが必要となる。configとはconfigurationの略で、モデルのハイパーパラメータの一覧である。

huggingfaceには何々Configというクラスが沢山あり、そのなかでそれっぽい名前のもの(今回はBertConfig)のインスタンスを渡せばいいと思う。今回私は Bert まで打って入力補完でゴリ押した:
bertConfig_demo.gif

おまけ:入出力テンソルの形が見たい

散々使ってきたprint(model)だと各層の入出力テンソルのshapeが見れない。それを見たいときは https://github.com/TylerYep/torchinfo を使うといい。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0