Document Classification with BERT and PyTorch Lightning on Databricks

Posted at 2024-09-24

In this post I walk through this article, making a few fixes along the way.

The source code is here.

Preparing the cluster

We use a GPU cluster.
(Screenshot: GPU cluster configuration)

Preparing the data

Upload this file to a volume and unzip it.

unzip -d /Volumes/users/takaaki_yayoi/corpus/text /Volumes/users/takaaki_yayoi/corpus/texts.zip
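To confirm the archive was extracted where the later code expects it, a minimal check like the following can list the category directories (the path matches the unzip destination above):

import os

# List the category directories extracted from texts.zip
corpus_path = "/Volumes/users/takaaki_yayoi/corpus/text/texts/"
print(sorted(d for d in os.listdir(corpus_path) if os.path.isdir(os.path.join(corpus_path, d))))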

(Screenshot: extracted files in the volume)

Running the notebook

Installing libraries

Some of the pinned library versions were no longer available at the time of writing, so I have updated them.

!pip install protobuf==3.9.2
!pip install transformers==4.20.1 fugashi==1.1.2 ipadic==1.0.0 torchtext==0.12.0 pytorch-lightning==1.6.4 numpy==1.21.6 openpyxl
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
dbutils.library.restartPython()
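After the Python restart, it is easy to confirm that the pinned versions were picked up (a quick sanity check, not part of the original notebook):

import torch, transformers, pytorch_lightning as pl
print(torch.__version__, transformers.__version__, pl.__version__)
# Expect roughly: 1.11.0+cu113 4.20.1 1.6.4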

The model path has changed to tohoku-nlp/bert-base-japanese-char-whole-word-masking, so that is what I use here.

import glob  # used to collect files
import os  # used to collect files
import pandas as pd  # for working with dataframes
from sklearn.model_selection import train_test_split  # data splitting
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BertJapaneseTokenizer, BertForSequenceClassification
import pytorch_lightning as pl
import random
import numpy as np
from tqdm import tqdm
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging

# Japanese pre-trained model
MODEL_NAME = 'tohoku-nlp/bert-base-japanese-char-whole-word-masking'
seed = 1
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

Preparing the data

path = "/Volumes/users/takaaki_yayoi/corpus/text/texts/"  # specify the folder location

# get only the directory names
dir_files = os.listdir(path=path)
dir_names = [f for f in dir_files if os.path.isdir(os.path.join(path, f))]

# number of labels
label_n = len(dir_names)
print(f'=={label_n}==')

# build lists of texts and labels to create the dataframe
text_list = []
label_list = []
label_dic = {}

for i, dir_name in enumerate(dir_names):
    label_dic[i] = dir_name
    file_names = glob.glob(path + dir_name + "/*.txt") 
    for file_name in file_names:
        if os.path.basename(file_name) == "LICENSE.txt":
            continue
        with open(file_name, "r") as f:
            text = f.readlines()[3:]
            text = "".join(text)
            text = text.translate(str.maketrans({"\n":"", "\t":"", "\r":"", "\u3000":""})) 
            text_list.append(text)
            label_list.append(i)
# create the dataframe
document_df = pd.DataFrame(data={'text': text_list,'label': label_list})
print(f'{document_df.head()}')
                                                text  label
0  もうすぐジューン・ブライドと呼ばれる6月。独女の中には自分の式はまだなのに呼ばれてばかり……...      0
1  携帯電話が普及する以前、恋人への連絡ツールは一般電話が普通だった。恋人と別れたら、手帳に書か...      0
2  「男性はやっぱり、女性の“すっぴん”が大好きなんですかね」と不満そうに話すのは、出版関係で働...      0
3  ヒップの加齢による変化は「たわむ→下がる→内に流れる」、バストは「そげる→たわむ→外に流れる...      0
4  6月から支給される子ども手当だが、当初は子ども一人当たり月額2万6000円が支給されるはずだ...      0
# compute label weights from class proportions
label_count = document_df['label'].value_counts(sort=False)
label_count_dict = label_count.to_dict()
print(f'Label counts: {label_count_dict}')

label_weight = []
for i in range(label_n):
    rate = (label_count_dict[i] / len(document_df))*100
    weight = 100 - rate
    label_weight.append(weight)
print(f'Label weights: {label_weight}')
Label counts: {0: 859, 1: 859, 2: 854, 3: 500, 4: 859, 5: 831, 6: 860, 7: 889, 8: 758}
Label weights: [88.18269363048563, 88.18269363048563, 88.2514788829275, 93.12147475581236, 88.18269363048563, 88.56789104416013, 88.16893657999725, 87.76998211583437, 89.57215572981153]
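As a quick sanity check of the weighting formula, the weight for label 3 (the smallest class, 500 documents out of 7,269 in total) can be recomputed by hand:

# weight = 100 - (count / total) * 100
total = 859 + 859 + 854 + 500 + 859 + 831 + 860 + 889 + 758  # 7269 documents
print(100 - (500 / total) * 100)  # ≈ 93.12, matching the output above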
# split into train+validation data and test data
train_val_df, test_df = train_test_split(document_df, test_size=0.2, random_state=42)
# split into training data and validation data
train_df, val_df = train_test_split(train_val_df, test_size=0.1, random_state=42)

# reset the indexes
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

Defining the classes

class DatasetGenerator(Dataset):
    def __init__(self, data, tokenizer, label_weight):
        self.data = data
        self.tokenizer = tokenizer
        self.class_weights = label_weight
        self.sample_weights = [0] * len(data)
        self.max_length = 256

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        text = data_row['text']
        labels = data_row['label']
        
        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return dict(
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            token_type_ids=encoding["token_type_ids"].flatten(),
            labels=torch.tensor(labels)
        )
    
    def get_sampler(self):
        for idx, row in self.data.iterrows():
            label = row['label']
            class_weight = self.class_weights[label]
            self.sample_weights[idx] = class_weight
        sampler = WeightedRandomSampler(self.sample_weights, num_samples=len(self.sample_weights), replacement=True)
        return sampler
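The get_sampler method assigns each sample the weight of its class, so under-represented classes are drawn more often during training. As a standalone toy illustration of how WeightedRandomSampler behaves (not part of the original notebook):

import torch
from torch.utils.data import WeightedRandomSampler

# Samples with weight 9.0 are drawn roughly nine times as often as those with weight 1.0
toy_weights = [1.0, 1.0, 9.0, 9.0]
sampler = WeightedRandomSampler(toy_weights, num_samples=1000, replacement=True)
print(torch.bincount(torch.tensor(list(sampler)), minlength=4))  # indices 2 and 3 dominate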
class DataModuleGenerator(pl.LightningDataModule):
    def __init__(self, train_df, val_df, test_df, tokenizer, batch_size, label_weight):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.label_weight = label_weight
        
    def setup(self):
        self.train_dataset = DatasetGenerator(self.train_df, self.tokenizer, self.label_weight)
        self.valid_dataset = DatasetGenerator(self.val_df, self.tokenizer, self.label_weight)
        self.test_dataset = DatasetGenerator(self.test_df, self.tokenizer, self.label_weight)
        self.train_sampler = self.train_dataset.get_sampler()
        self.valid_sampler = self.valid_dataset.get_sampler()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size["train"], num_workers=os.cpu_count(), sampler=self.train_sampler)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size["val"], num_workers=os.cpu_count(), sampler=self.valid_sampler)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size["test"], num_workers=os.cpu_count())
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
data_module = DataModuleGenerator(train_df, val_df, test_df, tokenizer, {"train": 8, "val": 8, "test": 8}, label_weight)
data_module.setup()
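Before training, it is worth pulling a single batch to confirm the shapes the model will receive (a quick check, not in the original walkthrough):

# Fetch one batch from the training dataloader
batch = next(iter(data_module.train_dataloader()))
print({k: tuple(v.shape) for k, v in batch.items()})
# Expected: input_ids, attention_mask and token_type_ids of shape (8, 256), labels of shape (8,)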
from torch import nn
from IPython.core.debugger import Pdb;

# Class for using PyTorch Lightning
class Model(pl.LightningModule):
        
    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        self.save_hyperparameters() 
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )

        
    def training_step(self, batch, batch_idx):
        output = self.bert_sc(**batch)
        weights = torch.tensor(label_weight).cuda()
        cross_entropy_loss = nn.CrossEntropyLoss(weight=weights)
        loss = cross_entropy_loss(output.logits, batch['labels'])
        self.log('train_loss', loss) # log the loss under the name 'train_loss'
        return loss
        
    def validation_step(self, batch, batch_idx):
        output = self.bert_sc(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss) # log the loss under the name 'val_loss'

    def test_step(self, batch, batch_idx):
        labels = batch.pop('labels') # get the labels from the batch
        output = self.bert_sc(**batch)
        labels_predicted = output.logits.argmax(-1)
        num_correct = ( labels_predicted == labels ).sum().item()
        accuracy = num_correct/labels.size(0) # accuracy
        self.log('accuracy', accuracy) # log the accuracy under the name 'accuracy'

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Training

During the training below, logging raised the error you tried to log -1 which is currently not supported. Try a dict or a scalar/tensor., so I added the argument logger=False to the trainer.

# specify the conditions for saving model weights during training
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='models',
) 

early_stopping = EarlyStopping(monitor='val_loss',verbose=True, mode="min")


# specify how training is performed
trainer = pl.Trainer(
    gpus=1, 
    max_epochs=20,
    callbacks = [checkpoint, early_stopping],
    logger=False
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
# Load the PyTorch Lightning model
model = Model(MODEL_NAME, num_labels=label_n, lr=1e-5)

# Run fine-tuning.
trainer.fit(model, train_dataloaders=data_module.train_dataloader(),  val_dataloaders=data_module.val_dataloader()) 

You can check the training progress from the experiment.
(Screenshot: training progress in the experiment)
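Once training finishes, the test_step defined in the Model class can also be exercised directly; a minimal sketch (this call is not part of the original walkthrough):

# Report accuracy on the test split via the test_step defined above
trainer.test(model, dataloaders=data_module.test_dataloader())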

Checking the training results

from sklearn.metrics import roc_curve, precision_recall_curve, auc, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
pd.set_option("display.max_colwidth", 10000)


# load the model
best_model_path = checkpoint.best_model_path
model = Model.load_from_checkpoint(best_model_path)
bert_sc = model.bert_sc.cuda()
# encode the data
encoding = tokenizer(
    test_df['text'].tolist(), 
    return_tensors='pt',
    max_length=256, 
    padding='max_length',
    truncation=True
)
encoding = { k: v.cuda() for k, v in encoding.items() }


# Feed the data to BERT and get the classification scores.
with torch.no_grad():
    output = bert_sc(**encoding)
preds = output.logits.argmax(-1)
# display the results
targets = test_df['label'].tolist()
preds_np = preds.to('cpu').detach().numpy().copy()

# accuracy_score
bert_a_score = accuracy_score(targets, preds_np)
# precision_score
bert_p_score = precision_score(targets, preds_np, average='macro')
# recall_score
bert_r_score = recall_score(targets, preds_np, average='macro')
# f1_score
bert_f_score = f1_score(targets, preds_np, average='macro')
# confusion_matrix
bert_confusion_matrix = confusion_matrix(targets, preds_np)
# dataframe for display
df_model = pd.DataFrame(columns=['model_str', 'model', 'accuracy_score', 'precision_score', 'recall_score', 'f1_score'])
df_model = df_model.append({
      'model_str': 'BERT: tokenizer_max_length=256', 'model' : 'BERT' , 'accuracy_score' : bert_a_score,
      'precision_score' : bert_p_score, 'recall_score' : bert_r_score, 'f1_score' : bert_f_score
    } , ignore_index=True)
df_model = df_model.reindex(columns=['model_str', 'model', 'accuracy_score', 'precision_score', 'recall_score', 'f1_score'])
display(df_model)
bert_classification_report = classification_report(targets, preds_np)
print(bert_classification_report)
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(targets, preds_np)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()


             precision    recall  f1-score   support

           0       0.97      0.78      0.87       186
           1       0.95      0.79      0.86       183
           2       0.74      0.88      0.80       166
           3       0.90      0.76      0.82       103
           4       0.89      0.95      0.92       153
           5       0.83      0.95      0.89       171
           6       0.94      0.98      0.95       163
           7       0.99      0.95      0.97       185
           8       0.87      0.96      0.91       144

    accuracy                           0.89      1454
   macro avg       0.90      0.89      0.89      1454
weighted avg       0.90      0.89      0.89      1454

The confusion matrix also shows that the documents are classified with good accuracy.

(Confusion matrix plot)
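For reference, here is a minimal sketch of classifying a single new document with the fine-tuned model, mapping the predicted index back to a category name via label_dic (the sample text is made up):

# Hypothetical single-document inference with the fine-tuned model
sample_text = "新しいスマートフォンが発表された。"
enc = tokenizer(sample_text, return_tensors='pt', max_length=256, padding='max_length', truncation=True)
enc = {k: v.cuda() for k, v in enc.items()}
with torch.no_grad():
    logits = bert_sc(**enc).logits
print(label_dic[logits.argmax(-1).item()])  # predicted category (directory name)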

Getting Started with Databricks (はじめてのDatabricks)

Databricks Free Trial (Databricks無料トライアル)
