2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【深層学習】TabBERTモデルによるクレジットカードの不正利用検知 ~環境構築・事前学習編~

Last updated at Posted at 2022-09-18

はじめに

※こちらは、8月に投稿した記事にコード解説の項目を加筆して再投稿したものとなります。以前の記事よりマシになっているはずなので、もう一度読んでいただけると幸いです。

今まで主にスマホアプリ開発をしていたのですが、2022年8月からAI関連の案件に携わっており、BERTという自然言語処理モデルについて勉強しています。

最近は、多変量の時系列表データの学習に使用する、TabBERT(Hierarchical Tabular BERT)というBERTの応用モデルに関する論文を読みました(Qiitaにも紹介記事を書きました)。

この論文、なんとありがたいことに、著者が事前学習のコードとデータをGitHubにあげています。

環境構築やモデル学習の経験を積むことで理解をさらに深められると思い、こちらのコードをいろいろ触ってみることにしました。
とりあえず第一歩として、AWSで仮想環境をたてて事前学習を実行するところまでをやってみました。

以下について知りたい方に、特に役立つ記事になると思います。

  • GPUインスタンスの立て方(AMIとインスタンス選び、vCPUの制限解除方法)
  • Anacondaによる環境構築(condaのエラー解消、仮想環境の選択)
  • 事前学習まわりのコードの中身

※仮想環境を選択したのは、ローカルマシンのGPUがNVIDIA製ではなく、CUDAとcuDNNが動かせなかったためです。setup.ymlで環境構築を行おうとしたら、ResolvePackageNotFoundエラーでいきなり失敗しました...

(base) [9:10:58] → conda env create -f setup.yml
                                     
Collecting package metadata (repodata.json): | / done
Solving environment: failed

ResolvePackageNotFound:
  - pytorch==1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
setup.yml
name: tabformer
channels:
  - anaconda
  - pytorch
  - huggingface
  - conda-forge
dependencies:
  - python>=3.8
  - pip>=21.0
  - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
  - torchvision
  - pandas
  - scikit-learn
  - transformers
  - numpy
  - libgcc
  - pip:
      - transformers==3.2.0

環境構築

AWSでGPUインスタンスをたてて、その中でさらにAnacondaで仮想環境をつくるまでの流れとなります。

GPUインスタンスの作成

AWSで用意されていたDeep Learning用のAMIと一番安価なGPUインスタンスを使用します。

  • AMI(Amazon Machine Image): amazon/Deep Learning AMI GPU PyTorch 1.12.0 (Amazon Linux 2) 20220803
  • インスタンスのタイプ: g4dn.2xlarge(xlargeだとメモリが枯渇したため) *説明ではg4ad.xlargeとなっています。すみません。
  • EBSボリューム: 300 GB(45 GBだとストレージが枯渇したためOSError: [Errno 28] No space left on device)

スクリーンショット 2022-08-11 22.18.11.png

インスタンスを起動しようとしたら失敗しました。
スクリーンショット 2022-08-11 22.23.50.png

g4adインスタンスを作成するのに必要なvCPUが足りないとのことです。
スクリーンショット 2022-08-11 22.25.08.png

右上の「制限緩和のリクエスト」から以下のリクエストを送ったところ、4時間くらいで制限が解除され、無事にインスタンスを起動することができました。
スクリーンショット 2022-08-11 22.27.47.png

Anacondaで仮想環境構築

SSHでEC2インスタンスに接続します。
※パブリックIPはコンソール画面のEC2 > インスタンスからパブリックIPv4アドレスをコピーしてください。

(base) [10:55:26] → ssh -i ~/.ssh/MyKeyPair.pem ec2-user@<パブリックIP>

インスタンスにコードを落とします。

[ec2-user@ip-172-31-21-45 ~]$ git clone https://github.com/IBM/TabFormer.git

setup.ymlのあるディレクトリに移動して、conda env create ~を実行します。

[ec2-user@ip-172-31-21-45 ~]$ cd TabFormer
[ec2-user@ip-172-31-21-45 TabFormer]$ conda config --set channel_priority flexible
[ec2-user@ip-172-31-21-45 TabFormer]$ conda env create -f setup.yml

conda config --set channel_priority flexibleをしないと、以下の文言が出て構築に失敗します。

Collecting package metadata (repodata.json): done
Solving environment: /
Found conflicts! Looking for incompatible packages.

conda init後に、作成した仮想環境(tabformer)をアクティベートします(アクティベート前に再度リモートログインが必要)。

[ec2-user@ip-172-31-21-45 TabFormer]$ conda init bash
[ec2-user@ip-172-31-21-45 TabFormer]$ conda activate tabformer

最後に、特定バージョンのcudatoolkitpytorchがコンフリクトするというAnacondaのバグがあったので、以下を実行します。

[ec2-user@ip-172-31-21-45 TabFormer]$ pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

※以上をインストールしないと、python実行時に以下のエラーが出ます。

OSError: /home/ka37/anaconda3/envs/fail/lib/python3.8/site-packages/torch/lib/../../../../libcublas.so.11: symbol free_gemm_select version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference

追記

まだ以下のようなエラーが出る場合があります。

ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (required by /opt/conda/envs/tabformer-opt/lib/python3.8/site-packages/pandas/_libs/window/aggregations.cpython-38-x86_64-linux-gnu.so)

libstdc++.so.6GLIBCXX_3.4.29のバージョンが含まれていないとのことです。

そもそもlibstdc++.so.6のパスが間違っていそうなので、こちらをconda環境が含まれているパスに修正します。
以下のコマンドでパスを調べます。

find / -name "libstdc++.so*"

すると、conda環境を含むパスにlibstdc++を確認できます。

...
/opt/conda/envs/tabformer-opt/lib/libstdc++.so.6
/opt/conda/envs/tabformer-opt/lib/libstdc++.so.6.0.21
/opt/conda/envs/tabformer-opt/lib/libstdc++.so.6.0.29

libstdc++.so.6のままだとGLIBCXX_3.4.29は存在しないので、libstdc++.so.6.0.29のパスに変更します。

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/envs/tabformer/lib

事前学習

論文では2種類のデータが用意されているのですが、今回はクレジットカードのトランザクションデータを使用し、学習を行います。

まずは、ローカルからデータをコピーします。
ディレクトリごとコピーするために、-rオプションをつける必要があります。
*READMEに記載されていた./data/card/ではなく、./data/prsa/にデータをコピーします(多分誤記)。

(base) [14:31:15] → scp -r -i ~/.ssh/MyKeyPair.pem ~/Programs/TabFormer/data/credit_card ec2-user@18.180.227.218:/home/ec2-user/TabFormer/data/

学習を実行します。

[ec2-user@ip-172-31-21-45 TabFormer]$ python main.py --do_train --mlm --field_ce --lm_type bert --field_hs 64 --data_type card --output_dir ./output_card/

6時間でストレージ(300 GB)がいっぱいになってしまいましたが、とりあえず動きました。
以下が学習が止まった時点でのcheckpoint(35000, 35500)の中身です。

スクリーンショット 2022-08-14 22.47.06.png
作成された学習済モデルpytorch_model.binconfig.json optimizer.ptをFine-Tuningの際に使用します。
以下がconfig.jsonpytorch_model.binの読み込み例です。

prediction_card.py
from transformers import BertConfig, BertModel

config = BertConfig.from_pretrained(
    "./output_card/checkpoint-35000/config.json"
)

model = BertModel.from_pretrained("./output_card/checkpoint-35000/pytorch_model.bin", config=config)

print(config)
print(model)
config
BertConfig {
  "architectures": [
    "TabFormerHierarchicalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "field_hidden_size": 64,
  "flatten": false,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "ncols": 12,
  "nhead": 8,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_layers": 1,
  "pad_token_id": 0,
  "total_flos": 41577468710400000,
  "type_vocab_size": 2,
  "vocab_size": 143492
}
pytorch_model
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(143492, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
    ...

コード解説

コード全文の説明は難しいので、要点のみ解説します。
以下がmain.pyです。

Dataset

TransactionDatasetクラスはPytorchのDatasetクラスを継承しており、データの前処理やトークンとidの対応を辞書に登録するようなメソッドが登録されています。

  • encode_data():生データを前処理データに変換
  • init_vocab():トークンとidの対応を登録(辞書作成)
  • user_level_data():ユーザーごとのデータとラベルをすべての連結してリストで出力
    • unique_users: Userカラムのリスト(24001行までなら[0 1])
    • user_data: Userごとのデータフレーム
    • trans_data: 1Userのラベル以外のデータを全部まとめたリスト[[…], […], […]…] special tokenはまだ入っていない
    • trans_labels: 1Userのラベルを全部まとめたリスト
  • format_trans():user_level_dataで出力された連結データを入力して、11フィールドずつのレコードに分割。その後にトークンをidに変換、[SEP]トークンを付加して出力
  • prepare_samples():format_transで出力されたデータとラベルを10データ(window)ずつ連結して出力
    • user_idx: trans_dataのインデックス
    • user_row: 1Userの生データリスト
    • user_row_ids: 1Userのid変換後のリスト [SEP]付与済
    • ids: strideで5こずつずらしながらseq_lenでuser_row_idsを10連結ずつに区切ったリスト[…]
    • self.data: 10連結リストidsをまとめたリスト[[…], […]…] 全User含む
  • __getitem__():prepare_samplesで作成されたデータからindexに対応したものを取得。flatten(平坦化)フラグがFalseであれば10行(window)ずつreshape[10, 12]。最後のカラムはformat_transで付与した[SEP]

自然言語(例えば英語)の場合であれば、単語をBertTokenizerに入力すれば対応idを出力してくれますが、TabBERT用のTokenizerは存在しないため、辞書を作成するような処理が必要となります。
トークンから辞書を作成するメソッドset_id、辞書からidを取得するメソッドget_idVocabularyクラスから呼び出して使用しています。

Model

TabFormerBertLMからモデルを読み込みます。

DataCollator

DataCollatorForLanguageModelingを継承したTransDataCollatorForLanguageModelingを作成し、トークンのマスク化やinput_idsとlabels(MASKの正解)の出力を行います。

  • mask_tokens():MLMのためにトークンをマスク化
  • __call__():Datasetの__getitem__で取得したデータをバッチサイズ分まとめてinput_ids([8, 10, 12])とlabelsとしてTensorで出力。バッチサイズは固定。バッチ内のUserは混ざっている。

Trainer

Dataset、Model、DataCollatorをそれぞれ引数にとり、事前学習をやってくれます。
DataLoaderがラップされており、DataCollatorの__call__の処理はこの中で実行してくれます。

おわりに

Fine-Tuning以降のコードはなさそうだったので、現在調査しながら実装しています。
実装できたら新しく記事を書きます。

参考資料

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?