0
2

[RAG構築]sklearnを使ってAIに投げかけた質問を分類してPDFから回答を生成するのかAPIから回答を生成するのかを学習させる

Last updated at Posted at 2024-08-07

下記のコードは、AIに投げられた質問がPDF(決算短信)によるものなのかAPI(yfinance)によるものなのかをsklearnを使って学習させ適切な参照元から回答を生成するものです。
PDFから参照すべき質問には決算短信から財務情報が回答されます。APIから参照すべき質問にはAPIから株価グラフが回答として出力されます。

機械学習モデルによる分類

機械学習モデルとして、TfidfVectorizer と MultinomialNB(ナイーブベイズ分類器)を使用しています。以下の部分でモデルを学習させています。
questions というサンプル質問リストと、それに対応する labels(カテゴリラベル)を使ってモデルを訓練しています。

questions = [
    "AAPLの決算情報を教えてください",
    "AAPLの最新の株価は?",
    "MSFTのQ3 2024のレポートを見せてください",
    "GOOGLの株価動向を教えてください"
]
labels = ["PDF", "yfinance", "PDF", "yfinance"]

model = make_pipeline(TfidfVectorizer(), MultinomialNB())
model.fit(questions, labels)

質問を分類する関数

ステップ1
質問に pdf_keywords のいずれかのキーワードが含まれているかどうかを確認します。含まれていれば "PDF" を返します。
質問に yfinance_keywords のいずれかのキーワードが含まれているかどうかを確認します。含まれていれば "yfinance" を返します。

ステップ2
上記のキーワードが含まれていない場合、機械学習モデルを使用して質問を分類します。model.predict([question])[0] で予測したラベルを返します。

def classify_question_hybrid(question):
    if any(keyword in question for keyword in pdf_keywords):
        return "PDF"
    elif any(keyword in question for keyword in yfinance_keywords):
        return "yfinance"
    else:
        return model.predict([question])[0]

下記はコード全体です。

import os
import re
import fitz
import openai
import pandas as pd
from google.oauth2 import service_account
from googleapiclient.discovery import build
from bokeh.plotting import figure
from bokeh.embed import file_html
from bokeh.resources import CDN
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
import yfinance as yf
from datetime import datetime
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain.schema import Document
import json
import asyncio
import concurrent.futures
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# OpenAIのAPIキーを設定
api_key = 
os.environ["OPENAI_API_KEY"] = api_key
openai.api_key = api_key

# 質問からティッカーシンボルを抽出する関数
def extract_tickers_with_ai(question):
    # 質問から株のティッカーシンボルを抽出するためのプロンプトを作成
    prompt = f"Please extract the stock ticker symbols from the following question:\n\n{question}"
    
    # OpenAI APIを使ってティッカーシンボルを抽出
    response = openai.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant that extracts stock ticker symbols accurately."},
            {"role": "user", "content": prompt}
        ],
        model="gpt-4",
    )
    tickers = response.choices[0].message.content.strip()
    # 正規表現を使ってティッカーシンボルを抽出
    ticker_matches = re.findall(r'\b[A-Z]{1,5}\b', tickers)
    return ticker_matches

# 株価データを取得する関数
def get_stock_data(ticker):
    # yfinanceライブラリを使って株価データを取得
    stock = yf.Ticker(ticker)
    try:
        hist = stock.history(period="1y")
        if hist.empty:
            raise ValueError(f"No data found for ticker {ticker}")
        return hist
    except Exception as e:
        raise ValueError(f"Error fetching data for ticker {ticker}: {e}")

# 株価グラフを描画し、HTMLとして返す関数
def plot_stock_data(tickers, hist_dict):
    # Bokehライブラリを使ってグラフを作成
    p = figure(title='Stock Prices over the last year', x_axis_label='Date', y_axis_label='Close Price', x_axis_type='datetime')
    
    # カスタムパレット(色の設定)
    custom_palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    
    for i, ticker in enumerate(tickers):
        hist = hist_dict[ticker]
        p.line(hist.index, hist['Close'], legend_label=ticker, line_width=2, color=custom_palette[i % len(custom_palette)])
        p.scatter(hist.index, hist['Close'], size=5, color=custom_palette[i % len(custom_palette)])
    
    # グラフをHTMLとしてエクスポート
    graph_html = file_html(p, CDN, "Stock Prices")
    return graph_html

# PDFファイルからデータを取得する非同期関数
async def load_pdf_data_async(pdf_path):
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF file does not exist at the specified path: {pdf_path}")
    
    try:
        # PyMuPDFを使ってPDFを開く
        document = fitz.open(pdf_path)
    except Exception as e:
        raise FileNotFoundError(f"An error occurred while trying to open the PDF file: {e}")

    data = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # 各ページを非同期に処理
        futures = [executor.submit(load_page_data, document, page_num) for page_num in range(document.page_count)]
        for future in concurrent.futures.as_completed(futures):
            data.append(future.result())
    return data

# ページデータを読み込む関数
def load_page_data(document, page_num):
    # PDFの特定のページからテキストを取得
    page = document.load_page(page_num)
    text = page.get_text("text")
    return Document(page_content=text, metadata={"page": page_num + 1})

# PDFファイルからデータを取得し回答する関数
async def answer_pdf_question_async(question):
    pdf_path = "/app/lib/FY24_Q3_Financial_Statements.pdf"
    data = await load_pdf_data_async(pdf_path)

    # ベクトルデータベースを作成
    embeddings_model = OpenAIEmbeddings(model="text-embedding-ada-002")
    db = Chroma.from_documents(data, embeddings_model)

    # 抽出器を作成
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 1})

    # クエリを実行
    results = retriever.invoke(question)

    # 抽出した結果を元にChatOpenAIモデルを使用
    gpt4o = ChatOpenAI(model="gpt-4", temperature=1.0)
    context = " ".join([result.page_content for result in results])
    response = gpt4o.invoke(context + "\n" + question)

    # 結果を表示
    print(response.content)

# 質問の分類 (キーワードベース + 機械学習モデル)
pdf_keywords = ["決算", "売上", "レポート", "有価証券報告書"]
yfinance_keywords = ["株価", "ティッカー", "yfinance"]

# 学習用の質問とラベル
questions = [
    "AAPLの決算情報を教えてください",
    "AAPLの最新の株価は?",
    "MSFTのQ3 2024のレポートを見せてください",
    "GOOGLの株価動向を教えてください"
]
labels = ["PDF", "yfinance", "PDF", "yfinance"]

# 機械学習モデルの作成
model = make_pipeline(TfidfVectorizer(), MultinomialNB())
model.fit(questions, labels)

# 質問を分類するハイブリッド関数
def classify_question_hybrid(question):
    # キーワードに基づく分類
    if any(keyword in question for keyword in pdf_keywords):
        return "PDF"
    elif any(keyword in question for keyword in yfinance_keywords):
        return "yfinance"
    else:
        # 機械学習モデルによる分類
        return model.predict([question])[0]

# コマンドラインで質問を投げかける
if __name__ == "__main__":
    import sys
    question = " ".join(sys.argv[1:])
    category = classify_question_hybrid(question)
    
    if category == "PDF":
        # PDFに関連する質問を処理
        asyncio.run(answer_pdf_question_async(question))
    elif category == "yfinance":
        # 株価に関連する質問を処理
        tickers = extract_tickers_with_ai(question)
        hist_dict = {ticker: get_stock_data(ticker) for ticker in tickers}
        graph_html = plot_stock_data(tickers, hist_dict)
        print(graph_html)  # グラフのHTMLを出力
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2