1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ローカルLLMで文章の分類(classification)をしたメモ

Last updated at Posted at 2025-01-19

やりたいこと

LLMで文章を分類したい。一括で処理したいのでPythonプログラムで。そしてChatGPTとかだと費用がかかるのでローカルLLMでやりたい

プロンプトで分類を指示するやりかたがもあるが、「分類だけを返して」と指示しても、しばしば余計なテキストがついてくる。Scikit-LLMがあるらしいがローカルLLMでやりたい

やったこと

scikit-ollamaを試した。

準備

  • ollamaを導入

  • 分類したいテキストと分類項目をならべてzero_shot_classification.csvとして保管
    ここでは学習しないので分類項目を保存するだけ。テキストと分類が対応しなくてOK。実際の分類判断はLLMがもともと持っている知識に依存している。

  • 表の例(要約とカテゴリは対応していなくてよい)
id 要約 カテゴリ
1 政府は本日、新たな経済対策を発表しました。 ニュース
2 昨日のサッカーの試合は劇的な結末を迎えた。 レビュー
3 このスマートフォンは非常に使いやすい。バッテリーの持ちも良い。 チャット
4 このレストランの料理は美味しかったが、サービスが少し遅かった。 法律文書
5 おはよう!今日の天気はどう? 技術文書
6 昨日の映画、めっちゃ面白かったよ! ニュース
7 契約第3条に基づき、支払いは30日以内に行うものとする。 レビュー
8 この合意は双方の同意のもと締結されるものとする。 チャット
9 このプログラムはPythonで実装されており、データ処理に適している。 法律文書
10 機械学習モデルの精度を向上させるためには、適切なデータ前処理が重要である。 技術文書

コード

例によって赤ペン先生GPTにがっつり指導してもらった

classification.py
import pandas as pd
from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
from typing import Optional

# ------------------------------ #
#        設定(定数)             #
# ------------------------------ #

DEFAULT_CSV_FILE = "zero_shot_classification.csv"  # 入力CSVファイル(要約とカテゴリのデータ)
DEFAULT_OUTPUT_FILE = "predicted_categories.csv"  # 予測結果を保存するCSVファイル
MODEL_NAME = "phi4:latest"  # 使用するZero-Shotモデルの名前

# 列名の定数化(データフレームのカラム名)
COLUMN_SUMMARY = "要約"  # 入力データの要約(特徴量)
COLUMN_CATEGORY = "カテゴリ"  # 分類するカテゴリ(COLUMN_SUMMARY列と対応しなくてよい)
COLUMN_PREDICTED_CATEGORY = "predicted_category"  # 予測されたカテゴリ(出力結果)


# ------------------------------ #
#        CSV読み込み関数           #
# ------------------------------ #

def load_csv(file_path: str) -> pd.DataFrame:
    """CSVファイルをUTF-8形式で読み込む

    Args:
        file_path (str): 読み込むCSVファイルのパス

    Returns:
        pd.DataFrame: 読み込んだデータフレーム(失敗時は空のデータフレーム)
    """
    try:
        df = pd.read_csv(file_path, encoding="utf-8")  # UTF-8エンコーディングで読み込む
        return df
    except FileNotFoundError:
        print(f"エラー: {file_path} が見つかりません。")  # ファイルが存在しない場合のエラーメッセージ
        return pd.DataFrame()
    except Exception as e:
        print(f"CSV読み込み時のエラー: {e}")  # その他のエラー発生時の処理
        return pd.DataFrame()


# ------------------------------ #
#        モデル関数                #
# ------------------------------ #

def train_classifier(x: pd.Series, y: pd.Series) -> Optional[ZeroShotOllamaClassifier]:
    """Zero-Shot 分類器を作成

    Args:
        x (pd.Series): 分類するテキスト(入力)
        y (pd.Series): カテゴリ(付与する分類ラベル)

    Returns:
        ZeroShotOllamaClassifier | None: 学習済み分類器(失敗時はNone)
    """
    if x.empty or y.empty:
        print("エラー: データが空です。")  # データが空の場合の処理
        return None

    # Zero-Shot モデルのインスタンスを作成
    clf = ZeroShotOllamaClassifier(model=MODEL_NAME)

    # ゼロショットで分類
    clf.fit(x, y)

    return clf


# ------------------------------ #
#      予測&結果保存関数           #
# ------------------------------ #

def predict_and_save(df: pd.DataFrame, clf: ZeroShotOllamaClassifier, output_file: str) -> None:
    """予測を行い、結果をCSVに保存

    Args:
        df (pd.DataFrame): 入力データ
        clf (ZeroShotOllamaClassifier): Zero-Shot分類器
        output_file (str): 予測結果を保存するCSVファイル名
    """
    # カテゴリを予測
    predicted_categories = clf.predict(df[COLUMN_SUMMARY])

    # 予測結果を新しい列に追加
    df[COLUMN_PREDICTED_CATEGORY] = predicted_categories

    # 予測結果をCSVに保存(utf-8-sigを指定してExcelでの文字化けを防ぐ)
    df.to_csv(output_file, encoding="utf-8-sig", index=False)

    print(f"予測結果を {output_file} に保存しました。")


# ------------------------------ #
#          メイン処理             #
# ------------------------------ #

def main(csv_file: str = DEFAULT_CSV_FILE, output_file: str = DEFAULT_OUTPUT_FILE) -> None:
    """プログラムのメイン処理

    1. CSVデータを読み込む
    2. モデルを準備する
    3. 予測を行い、結果を保存する

    Args:
        csv_file (str, optional): 入力CSVファイルのパス(デフォルト: DEFAULT_CSV_FILE)
        output_file (str, optional): 出力CSVファイルのパス(デフォルト: DEFAULT_OUTPUT_FILE)
    """
    # CSVファイルを読み込む
    df = load_csv(csv_file)

    # 読み込みに失敗した場合は処理を終了
    if df.empty:
        return

    # 分類したいテキストとカテゴリの列を取得
    x, y = df.get(COLUMN_SUMMARY), df.get(COLUMN_CATEGORY)

    # データに必要なカラムがない場合のエラーチェック
    if x is None or y is None:
        print(f"エラー: 必要なカラム({COLUMN_SUMMARY}, {COLUMN_CATEGORY})が存在しません。")
        return

    # モデルを作成
    clf = train_classifier(x, y)

    # 成功した場合のみ予測を実行
    if clf:
        predict_and_save(df, clf, output_file)


# ------------------------------ #
#      スクリプトの実行            #
# ------------------------------ #

if __name__ == "__main__":
    main()

無事分類された

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?