2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

住所の表記ゆれを生成AIで何とかできないか(Fine Tuning編-最終版)

Last updated at Posted at 2023-10-15

前回の記事でFine Tuningを行ったが、とりあえずやっただけの記事になっていたので、ちゃんと最終形にして公開します。

学習データ作成

使用するモジュールは2つ。

import numpy as np
import pandas as pd

まず、数値を「〇丁目」に変換する関数を用意。

# '一丁目'みたいに変換
def get_choume(c):
    num10 = ["", "", "", "", "", "", "", "", "", ""]
    num1  = ["", "", "", "", "", "", "", "", "", ""]
    
    juu = ""
    if (c//10 >= 1):
        juu = ""
    return num10[c//10] + juu + num1[c%10] + "丁目"

一気に「anchor」「positive」「negative」の組み合わせで、「1丁目1番1号」~「19丁目19番19号」を作成。

# データ作成
list_anchor = []
list_positive = []
list_negative = []

for c in range(1,20):
    for b in range(1,20):
        for g in range(1,20):
            print(str(c) + "-" + str(b) + "-" + str(g))
                
            # 間違い
            for inv in range(1,20):
                # 丁目違い
                if inv != c:
                    # 一丁目1番1号
                    list_anchor.append(get_choume(c) + str(b) + "" + str(g) + "")
                    list_positive.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_negative.append(get_choume(inv) + str(b) + "" + str(g) + "")

                    # 1丁目1番1号
                    list_anchor.append(str(c) + "丁目" + str(b) + "" + str(g) + "")
                    list_positive.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_negative.append(str(inv) + "丁目" + str(b) + "" + str(g) + "")

                    # 1-1-1
                    list_anchor.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_positive.append(get_choume(c) + str(b) + "" + str(g) + "")
                    list_negative.append(str(inv) + "-" + str(b) + "-" + str(g))
                    list_anchor.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_positive.append(str(c) + "丁目" + str(b) + "" + str(g) + "")
                    list_negative.append(str(inv) + "-" + str(b) + "-" + str(g))

                # 番違い
                if inv != b:
                    # 一丁目1番1号
                    list_anchor.append(get_choume(c) + str(b) + "" + str(g) + "")
                    list_positive.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_negative.append(get_choume(c) + str(inv) + "" + str(g) + "")

                    # 1丁目1番1号
                    list_anchor.append(str(c) + "丁目" + str(b) + "" + str(g) + "")
                    list_positive.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_negative.append(str(c) + "丁目" + str(inv) + "" + str(g) + "")

                    # 1-1-1
                    list_anchor.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_positive.append(get_choume(c) + str(b) + "" + str(g) + "")
                    list_negative.append(str(c) + "-" + str(inv) + "-" + str(g))
                    list_anchor.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_positive.append(str(c) + "丁目" + str(b) + "" + str(g) + "")
                    list_negative.append(str(c) + "-" + str(inv) + "-" + str(g))

                # 号違い
                if inv != g:
                    # 一丁目1番1号
                    list_anchor.append(get_choume(c) + str(b) + "" + str(g) + "")
                    list_positive.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_negative.append(get_choume(c) + str(b) + "" + str(inv) + "")

                    # 1丁目1番1号
                    list_anchor.append(str(c) + "丁目" + str(b) + "" + str(g) + "")
                    list_positive.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_negative.append(str(c) + "丁目" + str(b) + "" + str(inv) + "")

                    # 1-1-1
                    list_anchor.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_positive.append(get_choume(c) + str(b) + "" + str(g) + "")
                    list_negative.append(str(c) + "-" + str(b) + "-" + str(inv))
                    list_anchor.append(str(c) + "-" + str(b) + "-" + str(g))
                    list_positive.append(str(c) + "丁目" + str(b) + "" + str(g) + "")
                    list_negative.append(str(c) + "-" + str(b) + "-" + str(inv))

# DataFrameにまとめる
out_data = pd.DataFrame(
        data={'anchor': list_anchor, 'positive': list_positive, 'negative': list_negative},
        columns=['anchor', 'positive', 'negative']
    )

作成した学習データをCSVファイルに保存。

# ファイルに保存
out_data.to_csv("./output.csv")

作成したファイルはこんな感じ。
image.png

Fine Tuning

学習済みモデルの取得。

from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, losses, models
from sentence_transformers.evaluation import TripletEvaluator

# 学習済みモデルの読み込み
bert = models.Transformer('sonoisa/sentence-bert-base-ja-mean-tokens-v2')

最後にpooling層を追加。

# pooling層の追加
pooling = models.Pooling(bert.get_word_embedding_dimension())
model = SentenceTransformer(modules=[bert, pooling])

学習データの読み込み。

# 学習データの読み込み
import pandas as pd
datasets_df = pd.read_csv('./output.csv')

# 学習用と評価用に分割(9:1)
from sklearn.model_selection import train_test_split
train, test = train_test_split(datasets_df, train_size=0.9, random_state=4)

ここで、パラメータ値の設定。(必要に応じて調整)

BATCH_SIZE = 16
NUM_EPOCHS = 1
EVAL_STEPS = 1000
WARMUP_STEPS = int(len(train) // BATCH_SIZE * 0.1) 
OUTPUT_PATH = "./sbert_1"

学習データのローダーとロスを作成。

from torch.utils.data import DataLoader

# 学習用ローダーの作成
train_dataset = SentencesDataset([InputExample(texts=[row["anchor"], row["positive"], row["negative"]]) for index,row in train.iterrows()], model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# ロスの作成
train_loss = losses.TripletLoss(model=model)

評価の作成。

# 評価の作成
evaluator = TripletEvaluator(test["anchor"].values, test["positive"].values, test["negative"].values)

※ネット上の情報で、データローダーを引数とした「TripletEvaluator(test_dataloader)」が見つかるが、少なくとも今は動かないので注意

学習の実行。

# 学習
model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=NUM_EPOCHS,
        evaluation_steps=EVAL_STEPS,
        warmup_steps=WARMUP_STEPS,
        output_path=OUTPUT_PATH,
        )

Fine Tuningした結果は、指定したフォルダに保存される。
こんな感じで。
image.png
「pytorch_model.bin」がFine Tuning結果本体で、後は設定値等々。
image.png
image.png
「triplet_evaluation_results.csv」にロス結果が入っている。

ということで、ロスのグラフを表示。
(最後の行の「steps」の値が-1になるので、強制的に修正している)

# ロスのグラフを表示
eval_df = pd.read_csv(OUTPUT_PATH+"/eval/triplet_evaluation_results.csv")
eval_df["steps"][len(train)//(BATCH_SIZE*EVAL_STEPS)*NUM_EPOCHS] = len(train) // BATCH_SIZE
eval_df.plot(x="steps", y=["accuracy_cosinus", "accuracy_manhattan", "accuracy_euclidean"])

こんなグラフが表示される。
(epochが1より大きいと重なってしまうので、工夫が必要)
image.png

評価結果

Fine Tuningしたモデルの読み込み。

from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, losses, models

# Fine Tuningしたモデルの読み込み
model = SentenceTransformer(OUTPUT_PATH)

評価したいデータ。

# 評価したいデータ
sentences = [
    # 正解
    "東京都港区海岸一丁目2番3号",
    # 表記ゆれ
    "東京都港区海岸一丁目2番3号",
    "東京都港区海岸一丁目二番三号",
    "東京都港区海岸1丁目2番3号",
    "東京都港区海岸1丁目2番3号",
    "東京都港区海岸1-2-3",
    "東京都港区海岸1-2-3",
    # 不正解
    "東京都港区海岸一丁目1番3号",
    "東京都港区海岸一丁目1番3号",
    "東京都港区海岸一丁目一番三号",
    "東京都港区海岸1丁目1番3号",
    "東京都港区海岸1丁目1番3号",
    "東京都港区海岸1-1-3",
    "東京都港区海岸1-1-3",

    "東京都港区海岸一丁目1番1号",
    "東京都港区海岸一丁目1番1号",
    "東京都港区海岸一丁目一番一号",
    "東京都港区海岸1丁目1番1号",
    "東京都港区海岸1丁目1番1号",
    "東京都港区海岸1-1-1",
    "東京都港区海岸1-1-1"
]

評価の実行。

# 評価
sentence_vectors = model.encode(sentences)

正解とのコサイン距離の計算。

# コサイン距離の計算
import scipy.spatial
distances = scipy.spatial.distance.cdist(sentence_vectors, sentence_vectors, metric="cosine")[0]

一応、距離を数値で取得しておく。

distances

こんな感じに出力される。(1つ目は正解と正解の距離なので、無視する)
image.png

距離を棒グラフで可視化。

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# 棒グラフで距離を可視化
left = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
plt.xticks(left)
height = distances[1:]
plt.bar(left, height)

こんな感じに出力される。
image.png

最後に

とりあえずこれでひと段落だが、せっかくなので、今後もいろいろとFine Tuningさせてみたい。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?