11
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【 BERT後継モデル? 】BERTにGANの枠組みを取り入れたElectraが、RoBERTaの1/4の学習データで、RoBERTaと同じ精度を達成!Electraの仕組みを徹底解説!

Last updated at Posted at 2020-12-22

Google発のElectraとは?

2019年9月26日に、Googleが初稿論文を公開した、複数の言語処理タスクに適用可能な学習済のモデルです。
2020年3月10日には、更新版の論文が2020年のICLRカンファレンスに採択されています。

( 論文 )

・ 論文(初稿)提出日:2019年9月26日
・ 論文(更新日):2020年3月10日

Kevin Clark et.al., ELECTRA: PRE-TRAINING TEXT ENCODERS AS DISCRIMINATORS RATHER THAN GENERATORS

この論文は、2020/12/22現在、後出の論文から199回、引用されています(Google Scholarによる集計)。

Google Scholar

スクリーンショット 2020-12-23 0.08.27.png

Electron論文を引用している論文(一例) )

スクリーンショット 2020-12-23 0.08.40.png

学習済みのモデル(pretrained model)は、Google公式のGitHubリポジトリで公開されています(大中小3つのサイズが公開)。TensorFlow HubとHuggingfaceにも、学習済みのモデルがすでに登録済みです。

さらに、日本語コーパスで学習済みのモデルも、Cinammon社から公開されています。

そんなElectraモデルですが、日本語では、アルゴリズムを解説したり、このモデルをいろいろな言語処理タスクに適用して精度検証を試みたWebサイトは、あまり見当たりませんでした。

このモデルが多くの人の目にとまり、__TransformerBERT__並に使われるようになるのは、2021年に入ってから??でしょうか。

スクリーンショット 2020-12-22 22.27.38.png

Electraの概要

  1. __RoBERTa__の__1/4の計算量で、RoBERTaと同等の精度を達成__した。
    同じ計算量でElectraを学習させると、RoBERTaを上回る精度を達成

  2. BERTが、マスキングされた[Mask]トークン部分に本来あるべき単語を、文の前後両方の文脈から予測する(Masked Language Modellingタスク)のに対して、__当該部分を「別の単語」に置き換えた上で、どこの部分のトークンが「別の単語に置き換えられた部分」なのかを当てさせる__タスクを課す。

  3. 学習コーパスのなかの一部の単語を、「別の単語」に置き換える作業は、Generator役のAI Agentが担う。どの単語が「別の単語」に置き換えられたのかを、Discriminator役のAI Agentが「見抜く」

  4. 置き換える「別の単語」が、前後の文脈に照らして「自然」で適切な単語だと、Discriminator役のAI Agentは、どの単語が置き換えられたのかを容易に見抜けなくなり、学習が進まない。なので、Generator役のAI Agentの精度を適度に抑えるために、Generator役のAI Agentのネットワークのサイズは小さくする

最後の点は、通常のGANモデルのアプローチと同様に、Generator役のAI AgentとDiscriminator役のAI Agentを、とことん互いに切磋琢磨しあわせることで、とてつもなく精度の高いマルチタスク対応可能なpretrained modelを生み出すことができるのではないのかな?? と疑問に感じました。

なぜ、1/4の学習データで、RoBERTaに並ぶ精度を出せるのか?

学習用データをすべて使って、学習をするからです。(BERTは、学習用データの一部しか、学習素材として活用しない)

・ BERT: Masked Language Modellingタスクでは、学習用のコーパスの内、[Mask]したトークン部分しか、学習段階で利用しない([Mask]されなかったトークンは、学習データをそのまま出力すると「正解」になる)。
・ Electra: 学習用コーパスデータの__全トークンをすべて(モデルに)眺めさせた上で__、どの部分(箇所)のトークンが、「別の単語」に置き換えられた(文脈上、)不自然なトークンであるのかを、発見(検出)させる。

GeneratorG)とDiscriminatorD) )

・ Generator役とDiscriminator役の2つのAIが対峙する構図。この部分は、__GANモデル__の枠組みを取り入れている。
・ Generator役のAI Agentの精度を適度に抑えるために、Generator役のAI Agentのネットワークのサイズは、Discriminator役のAI Agentよりも小さくする。

(1)論文中に掲載されている例では、3番目のトークン "cooked"が、"ate"(eatの過去形)に置き換えられています。
(2)Discriminator(D)は、この3番目のトークンが、Generator(G)によって置き換えられたことを見抜く(検出する)ことができれば、正解です。
(3)置き換え前の文「シェフは料理を作った」と比べると、置き換え後の文「シェフは料理を食べた」は、一般的に少し「変な」文だと(人間がみると)感じます。「食べた」ではなく、「作った」じゃない?一般的に。という感覚を、Discriminatorが、学習用コーパスを「学習」する過程で身に着けることができるかが、問われています。

スクリーンショット 2020-12-22 22.10.17.png

( 学習段階で用いる損失関数 )

i番目のtokenが、Generator役のAI Agentによって置き換えられたものであるか、そうでないか、2クラス分類を行います。そのため、0〜1の値を出力する「シグモイド関数」を採用している。

スクリーンショット 2020-12-22 22.10.41.png

( 参考 ) シグモイド関数(Wikipedia日本語版より)
シグモイド関数(Wikipedia日本語版より)

その他、ElectraがBERT(系列の手法)より優れていること

・ BERTは、学習段階と推論段階とで、実行するデータ処理の内容が、完全に一致しないが、Electraでは一致する。
・ BERTは、学習段階では[Mask]トークンが含まれたコーパスを扱います。ところが、推論段階で新たに入力される文章には、通常、[Mask]トークンは含まれていない。
・ Electraが学習段階で扱うコーパスは、普通の英語の単語が敷き詰められた普通の文章コーパスである。学習段階で取り組むコーパスが、「別の単語」に置き換えられはしているものの、英単語が並んでいるコーパスという点では、推論段階で入力される文章と同じ形式のデータである。

( Google AI Blog )
Google AI Blog More Efficient NLP Model Pre-training with ELECTRA
Tuesday, March 10, 2020 More Efficient NLP Model Pre-training with ELECTRA

スクリーンショット 2020-12-22 21.38.02.png

( Google公式GitHubリポジトリ )

(GitHUb) google-research/electra

学習済モデルとして、3つのサイズが公開されている

スクリーンショット 2020-12-22 22.18.53.png

( 解説している和文記事 )

AI SCHOLAR 2020年03月13日 「ELECTRA」新たな自然言語処理モデルが示したMLMの問題点とは!?
創造日記 2020年3月11日 (論文まとめ)ELECTRA : Pre-training Text Encoders As Discriminators Rather than Generators
オブジェクトの広場 2020年12月17日 はじめての自然言語処理 第12回 ELECTRA(BERT の事前学習手法の改良)による固有表現抽出の検証

( 日本語データセットで学習済のElectra pretraained model )

cinammon AI BLOG 2020年6月22日 自然言語処理の最新モデル 日本語版ELECTRAを公開しました

(GitHub) Cinnamon/electra_japanese

TensorFlow Hubには、学習済の言語モデルが多数収録されています。

TensorFlow Hub text-embedding)

スクリーンショット 2020-12-22 21.36.04.png

学習済みのElectraモデルを動かす

TensorFlow Hub electra

以下から、electgra_smallを開く

TensorFlow HUb electra

スクリーンショット 2020-12-22 21.34.18.png

TensorHubに登録済みのElectraを実行してみた

TensorFlow electra_small/TF2.0 Saved Model (v2)

Python3.6.3
ocean@AfoGuardMacBook-Pro Desktop % python
Python 3.6.3 (default, Dec 10 2020, 22:43:16) 
[GCC Apple LLVM 12.0.0 (clang-1200.0.32.27)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> import tensorflow as tf
>>> import tensorflow_hub as hub
>>> import tensorflow_text as text
>>> 
>>> sentences = [
  "Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.",
  "The album went straight to number one on the Norwegian album chart, and sold to double platinum.",
  "Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum",
  "A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.",
  "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.",
]
>>> 
>>> preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2')
2020-12-22 21:14:56.626050: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2020-12-22 21:14:56.626254: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-12-22 21:14:58.182271: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
>>> 
>>> bert = hub.load('https://tfhub.dev/google/electra_small/2')
>>> 
>>> bert_inputs = preprocess(sentences)
>>> print(type(bert_inputs))
<class 'dict'>
>>> 
>>> from pprint import pprint
>>> pprint(bert_inputs)
{'input_mask': <tf.Tensor: shape=(5, 128), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int32)>,
 'input_type_ids': <tf.Tensor: shape=(5, 128), dtype=int32, numpy=
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int32)>,
 'input_word_ids': <tf.Tensor: shape=(5, 128), dtype=int32, numpy=
array([[  101,  2182,  2057,  2175,  2059,  1010,  2017,  1998,  1045,
         2003,  1037,  2639,  2201,  2011,  5046,  3769,  3063, 22294,
         2368, 16768,  1012,  2009,  2001, 16768,  1005,  1055,  2117,
         3729,  2004,  1037,  3948,  3063,  1012,   102,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [  101,  1996,  2201,  2253,  3442,  2000,  2193,  2028,  2006,
         1996,  5046,  2201,  3673,  1010,  1998,  2853,  2000,  3313,
         8899,  1012,   102,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [  101, 16447,  6714,  6776,  2003,  1037,  2691,  2171,  2005,
         2195,  4264,  1998,  2089,  6523,  2000,  1024, 14040,  2721,
        18255, 21368,  2378,  2819, 10424, 21823, 13186,  2819,   102,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [  101,  1037,  5943, 13232,  5158,  2043,  1996,  4231,  5235,
         2090,  3011,  1998,  1996,  3103,  1010,  8558,  6135,  2030,
         6576, 27885, 28817,  4892,  1996,  3746,  1997,  1996,  3103,
         2005,  1037, 13972,  2006,  3011,  1012,   102,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0],
       [  101,  1037,  7704,  5943, 13232,  5158,  1999,  1996, 11508,
         4655,  1997,  1996,  3011,  2043,  1996,  2415,  1997,  1996,
         4231,  1005,  1055,  5192, 22182,  1996,  3011,  1012,   102,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0]], dtype=int32)>}
>>> 
>>> bert_outputs = bert(bert_inputs)
>>> print(type(bert_outputs))
<class 'dict'>
>>> 
>>> print(bert_outputs)
{'sequence_output': <tf.Tensor: shape=(5, 128, 256), dtype=float32, numpy=
array([[[-6.51304305e-01,  1.97172821e-01,  4.23162282e-01, ...,
          1.14034176e+00, -1.18491602e+00, -5.26066661e-01],
        [-5.33896029e-01, -1.33612430e+00, -6.00645900e-01, ...,
         -1.31376281e-01, -1.63535762e+00,  2.68128633e-01],
        [-4.67965692e-01,  1.00177482e-01, -7.08638966e-01, ...,
         -5.79182804e-01, -5.64018846e-01,  4.12297666e-01],
        ...,
        [ 4.49147880e-01,  1.31464213e-01, -3.28022912e-02, ...,
         -4.12482798e-01, -4.43876266e-01, -6.99156821e-01],
        [ 5.01909733e-01,  2.25501731e-01,  1.13074489e-01, ...,
         -3.90683711e-01, -5.02406836e-01, -7.42523551e-01],
        [ 4.54771638e-01,  1.97450966e-01,  1.78489164e-01, ...,
         -3.88776809e-01, -5.82028866e-01, -8.34517002e-01]],

( ・・・省略・・・ )
        ...,
        [ 5.88703930e-01,  2.08802462e-01,  1.21231206e-01, ...,
         -4.46579695e-01, -4.31066483e-01, -4.42861825e-01],
        [ 4.31374907e-01, -5.79594672e-02,  1.28517807e-01, ...,
         -4.48998451e-01, -4.31935519e-01, -5.75835943e-01],
        [ 6.19926453e-01,  1.23447776e-01,  2.12610722e-01, ...,
         -4.62378234e-01, -4.12042171e-01, -4.83302236e-01]]],
      dtype=float32)>], 'pooled_output': <tf.Tensor: shape=(5, 256), dtype=float32, numpy=
array([[-5.7254738e-01,  1.9465676e-01,  3.9959115e-01, ...,
         8.1452924e-01, -8.2899541e-01, -4.8236871e-01],
       [ 2.4967931e-01,  5.9453052e-01,  5.2411832e-02, ...,
         4.8331147e-01, -6.2177205e-01, -1.0117788e-01],
       [-4.5478921e-02,  6.4704663e-01,  1.6987489e-01, ...,
         8.6458802e-01, -7.7583915e-01, -7.3523715e-02],
       [ 3.0683041e-01,  2.9057503e-01, -6.7763887e-02, ...,
         6.5365136e-01, -6.2018847e-01, -1.3292597e-01],
       [-1.3636658e-01,  4.0657955e-01,  7.6103024e-05, ...,
         8.9919376e-01, -6.3883138e-01,  5.0814760e-01]], dtype=float32)>}
>>> 
>>> pooled_output = bert_outputs['pooled_output']
>>> print(type(pooled_output))
<class 'tensorflow.python.framework.ops.EagerTensor'>
>>> 
>>> sequence_output = bert_outputs['sequence_output']
>>> print(type(sequence_output))
<class 'tensorflow.python.framework.ops.EagerTensor'>
>>> 
>>> sequence_output = bert_outputs['sequence_output']
>>> print(type(sequence_output))
<class 'tensorflow.python.framework.ops.EagerTensor'>
>>> 
>>> print(pooled_output)
tf.Tensor(
[[-5.7254738e-01  1.9465676e-01  3.9959115e-01 ...  8.1452924e-01
  -8.2899541e-01 -4.8236871e-01]
 [ 2.4967931e-01  5.9453052e-01  5.2411832e-02 ...  4.8331147e-01
  -6.2177205e-01 -1.0117788e-01]
 [-4.5478921e-02  6.4704663e-01  1.6987489e-01 ...  8.6458802e-01
  -7.7583915e-01 -7.3523715e-02]
 [ 3.0683041e-01  2.9057503e-01 -6.7763887e-02 ...  6.5365136e-01
  -6.2018847e-01 -1.3292597e-01]
 [-1.3636658e-01  4.0657955e-01  7.6103024e-05 ...  8.9919376e-01
  -6.3883138e-01  5.0814760e-01]], shape=(5, 256), dtype=float32)
>>> 
>>> print(sequence_output)
tf.Tensor(
[[[-6.51304305e-01  1.97172821e-01  4.23162282e-01 ...  1.14034176e+00
   -1.18491602e+00 -5.26066661e-01]
  [-5.33896029e-01 -1.33612430e+00 -6.00645900e-01 ... -1.31376281e-01
   -1.63535762e+00  2.68128633e-01]
  [-4.67965692e-01  1.00177482e-01 -7.08638966e-01 ... -5.79182804e-01
   -5.64018846e-01  4.12297666e-01]
  ...
  [ 4.49147880e-01  1.31464213e-01 -3.28022912e-02 ... -4.12482798e-01
   -4.43876266e-01 -6.99156821e-01]
  [ 5.01909733e-01  2.25501731e-01  1.13074489e-01 ... -3.90683711e-01
   -5.02406836e-01 -7.42523551e-01]
  [ 4.54771638e-01  1.97450966e-01  1.78489164e-01 ... -3.88776809e-01
   -5.82028866e-01 -8.34517002e-01]]

 [[ 2.55070806e-01  6.84644520e-01  5.24599105e-02 ...  5.27296066e-01
   -7.27888882e-01 -1.01525277e-01]
  [ 2.32167661e-01  1.40547127e-01 -4.19199765e-01 ... -3.20724636e-01
   -4.32940781e-01 -5.45868337e-01]
  [-1.23737156e-01 -3.58567059e-01 -4.82293725e-01 ...  2.99984127e-01
    9.23124373e-01 -2.30475098e-01]
  ...
  [ 6.71229661e-01  3.02966118e-01  6.00305051e-02 ... -3.68209392e-01
   -2.77368844e-01 -3.86038721e-01]
  [ 5.10866106e-01  1.69248492e-01  8.19482133e-02 ... -3.35396767e-01
   -4.33593214e-01 -5.66044271e-01]
  [ 5.85968196e-01  2.73102880e-01  7.47930557e-02 ... -4.43574101e-01
   -3.76365542e-01 -6.17670596e-01]]

 [[-4.55103181e-02  7.70201504e-01  1.71537846e-01 ...  1.31123722e+00
   -1.03483212e+00 -7.36566335e-02]
  [ 3.39367300e-01 -8.28021705e-01 -6.72565579e-01 ...  3.82193267e-01
   -8.84645343e-01  2.86033750e-02]
  [ 5.35988033e-01 -9.91185680e-02 -5.02614155e-02 ...  2.43397042e-01
   -2.87167251e-01 -2.08450347e-01]
  ...
  [ 6.17190599e-01  2.65768200e-01  1.58384234e-01 ... -3.78390461e-01
   -3.30653846e-01 -4.40317780e-01]
  [ 6.29271448e-01  2.72890776e-01  2.45292068e-01 ... -3.38562310e-01
   -2.39716768e-01 -4.87368494e-01]
  [ 8.02732527e-01  9.94282141e-02  1.29609987e-01 ... -4.14599776e-01
   -1.85184866e-01 -3.86599004e-01]]

 [[ 3.17042619e-01  2.99194217e-01 -6.78678975e-02 ...  7.81647563e-01
   -7.25311279e-01 -1.33717299e-01]
  [ 2.92305313e-02 -6.05383039e-01  1.62178650e-03 ...  2.14333817e-01
   -1.07719374e+00  7.46604741e-01]
  [ 4.12054986e-01 -5.27828157e-01 -5.22005796e-01 ... -5.80940068e-01
   -2.28683174e-01  6.76999211e-01]
  ...
  [ 6.14811718e-01  9.08084288e-02  1.22900024e-01 ... -4.83617455e-01
   -5.61085939e-01 -4.56480205e-01]
  [ 6.63995862e-01  2.27823853e-02  1.21976584e-01 ... -4.74959046e-01
   -3.69891405e-01 -4.18557405e-01]
  [ 6.16299689e-01 -1.20768845e-01  3.73016447e-02 ... -4.32481796e-01
   -3.58047962e-01 -4.39199656e-01]]

 [[-1.37221441e-01  4.31506544e-01  7.61030242e-05 ...  1.46799219e+00
   -7.56196856e-01  5.60229361e-01]
  [ 8.77072662e-02 -5.99296033e-01 -3.76423076e-02 ...  3.57749104e-01
   -1.25473130e+00  9.52858925e-01]
  [-1.10358655e-01 -8.14601853e-02 -6.13647103e-01 ... -1.25089556e-01
   -5.48545003e-01  5.75951219e-01]
  ...
  [ 5.88703930e-01  2.08802462e-01  1.21231206e-01 ... -4.46579695e-01
   -4.31066483e-01 -4.42861825e-01]
  [ 4.31374907e-01 -5.79594672e-02  1.28517807e-01 ... -4.48998451e-01
   -4.31935519e-01 -5.75835943e-01]
  [ 6.19926453e-01  1.23447776e-01  2.12610722e-01 ... -4.62378234e-01
   -4.12042171e-01 -4.83302236e-01]]], shape=(5, 128, 256), dtype=float32)
>>> 

Hugggingfaceにも収録されている。

Huggingface Docs >> ELECTRA

スクリーンショット 2020-12-22 21.42.48.png

サンプルコードは、tuple型のオブジェクトへのアクセスの仕方がおかしい。実行するとエラーが起きます。

Python3.6.3
% python
Python 3.6.3 (default, Dec 10 2020, 22:43:16) 
[GCC Apple LLVM 12.0.0 (clang-1200.0.32.27)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from transformers import ElectraTokenizer, ElectraModel
>>> import torch
>>> 
>>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 331kB/s]
>>> 
>>> model = ElectraModel.from_pretrained('google/electra-small-discriminator')
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████| 466/466 [00:00<00:00, 127kB/s]
Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████| 54.2M/54.2M [00:03<00:00, 13.8MB/s]
>>> 
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> 
>>> last_hidden_states = outputs.last_hidden_state
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'tuple' object has no attribute 'last_hidden_state'
>>> 
11
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?