6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

IBM 製基盤モデル Slate で sentiment classification モデルを fine tuning してみた

Last updated at Posted at 2023-09-25

この記事の内容

watsonx.ai でリリースされた IBM 製基盤モデルである Slate を用いて sentiment classification タスクに対応するモデルの fine tuning を行う手順を記載しています。

Slate とは

IBM では自社製基盤モデルとして以下の 3 つのモデルを提供することを blog で発表しています。
・Slate(encoder model)
・Sandstone(encoder-decoder model)
・Granite(decoder model)
Slate は encoder アーキテクチャーの基盤モデルで、ラベル付きデータセットを用いて多くの NLP タスクに対応させることができます。
当記事では sentiment classification(評判分類)タスクに対応するように Slate モデルの fine tuning を行います。

実行環境

Watson Studio Jupyter Notebook を使用します。
Watson Studio を利用するためには IBM Cloud の Web Dashboard に自身のアカウントでログインし、Watson Studio のリソースやプロジェクトを作成する必要があります。

以下のようにプロジェクトの資産として Jupyter Notebook を新たに作成します。
Slate を利用するためには Runtime 23.1 を選択する必要があります。
スクリーンショット 2023-09-25 10.46.36.png

データの準備

こちらのリンクに記載されているように json 形式のファイルでデータセットを準備します。

データの例
  [
      {
      "text": "とても幸せです",
      "labels": ["positive"]
      },
      {
      "text": "非常に悲しいです",
      "labels": ["negative"]
      },
      {
      "text": "空は青いです",
      "labels": ["neutral"]
      }
  ]

学習用、検証用、テスト用のデータセットはプロジェクトの資産としてアップロードしておきます。

モデルの学習

Slate モデルの fine tuning を行い、学習済みモデルをプロジェクトの資産として保存します。
こちらのドキュメントを参考にしています。

ライブラリのインポート

import watson_nlp
from watson_nlp.toolkit.sentiment_analysis_utils.training import train_util as utils

データの読み込み

Noteboo 右側のメニューから学習用データセットを pandas DataFrame として読み込み、セルにコードを挿入します。
スクリーンショット 2023-09-25 15.20.43.png

学習用と検証用のデータを上記の手順で読み込み、DataFrame を Watson NLP モデルが利用できるようにエクスポートしてから DataStream として読み込みます。
DataStream の作成には utility メソッドの prepare_data_from_json を使用します。

training_data_file = "./dataset_sentiment_train.json"
df_data_train.to_json(training_data_file, orient='records')
dev_data_file = "./dataset_sentiment_dev.json"
df_data_dev.to_json(dev_data_file, orient='records')
train_stream = utils.prepare_data_from_json(training_data_file)
dev_stream = utils.prepare_data_from_json(dev_data_file)

事前学習済みモデルの読み込み

モデルの学習の前に事前学習済み Slate モデルを読み込みます。
さらに、入力テキストで使用される言語の構文分析モデルをロードする必要があります。

pretrained_model_resource = watson_nlp.load('pretrained-model_slate.153m.distilled_many_transformer_multilingual_uncased')
syntax_model_ja = watson_nlp.load('syntax_izumo_ja_stock')
syntax_models = [syntax_model_ja]

学習の実行

train_transformer メソッドを使用して学習を実行します。
ドキュメントのガイドに従って、後続のステップで、言語検出を有効にして、ワークフロー・モデルが前提条件情報なしで入力テキストに対して実行できるようにします。

from watson_nlp.workflows.sentiment import AggregatedSentiment
sentiment_model = AggregatedSentiment.train_transformer(
    train_data_stream = train_stream,
    dev_data_stream = test_stream,
    syntax_model=syntax_models,
    pretrained_model_resource=pretrained_model_resource,
    label_list=['negative', 'neutral', 'positive'],
    learning_rate=2e-5,
    num_train_epochs=10,
    combine_approach="NON_NEUTRAL_MEAN",
    keep_model_artifacts=True
)
lang_detect_model = watson_nlp.load('lang-detect_izumo_multi_stock')
sentiment_model.enable_lang_detect(lang_detect_model)

上記のコードを実行すると Epoch ごとに学習用、検証用データセットに対する Loss と Accuracy が表示されます。
スクリーンショット 2023-09-25 15.44.12.png

学習済みモデルの保存

ibm-watson-studio-lib を使用して学習済みモデルをプロジェクトの資産として保存します。
こちらのドキュメントを参考にしています。

from ibm_watson_studio_lib import access_project_or_space
wslib = access_project_or_space({"token": "<ProjectToken>"})
wslib.save_data("sentiment_model", data=sentiment_model.as_bytes(), overwrite=True)

<ProjectToken> の箇所には各自の環境のプロジェクトトークンを記載する必要があります。
プロジェクトトークンの取得方法はこちらのドキュメントを参考にしてください。

学習済みモデルの読み込みと実行

モデルの読み込み

同じく ibm-watson-studio-libwatson_nlp を使用してモデルを読み込みます。

sentiment_model = watson_nlp.load(wslib.load_data("sentiment_model"))

モデルの実行

run() メソッドを使用して新規のデータに対してモデルを実行します。

text = "今日の夕飯はハンバーグだ"
sentiment_predictions = sentiment_model.run(text)

language_code="ja" のように language_code をオプションとして追加することで、明示的に入力言語を指定することができます。

結果の保存

ドキュメントには記載されていませんが、プロジェクトの資産としてモデルの結果を json ファイルとして保存する手順を記載します。
pandas DataFrame として読み込み直してから保存していますが、簡単な方法があるかもしれません。

import json
import pandas
results_pd = pandas.read_json(json.dumps(results, ensure_ascii=False))
wslib.save_data("results_sentiment.json", results_pd.to_json(force_ascii=False).encode(), overwrite=True)

参考資料

6
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?