【LangChain】Extracting the Information You Want from Papers with an LLM

Introduction

  • This is a quick summary for my own reference, so there is little explanation.
  • I will flesh out the content when I have time.

Overview

  • The overall flow: download PDFs of relevant papers via the arXiv API, then run QA over those PDFs with LangChain (see the sketch after this list).
  • It did not work well with small local LLMs, but a large model such as OpenAI's should extract the desired information reasonably well.
  • The code is on GitHub (link); please refer there for the latest version.
    • Many parts (e.g., the search keywords) are not generalized, so advice, comments, and pull requests on the processing are very welcome.
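
The whole pipeline boils down to a few calls; a bird's-eye sketch equivalent to main.py in the implementation section below (each imported module is defined there):

import os
from langchain.embeddings import HuggingFaceEmbeddings
from download_papers import download_papers_from_arxiv
from utils import setup_llm
from main import extract_paper_details

temp_dir = download_papers_from_arxiv(max_results=3)  # 1. fetch PDFs and metadata from arXiv
llm = setup_llm()                                     # 2. set up the LLM
embeddings = HuggingFaceEmbeddings()
for fname in os.listdir(temp_dir):                    # 3.-5. score by title, then retrieve and QA
    if fname.endswith(".pdf"):
        extract_paper_details(os.path.join(temp_dir, fname), llm, embeddings)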

Implementation

1. Downloading the Papers

download_papers.py
# -*- coding: utf-8 -*-
import arxiv
import datetime
import json
import logging
import os
import tempfile
import time
from typing import Any, Dict, Optional


def download_papers_from_arxiv(max_results: Optional[int] = 10) -> str:
    # Directory to save the PDFs
    temp_dir = tempfile.mkdtemp()
    logging.info(f"Saving directory: {temp_dir}")

    # Search conditions (papers submitted between yesterday and today)
    today = datetime.datetime.now().strftime(format="%Y%m%d")
    yesterday = (datetime.datetime.now() - datetime.timedelta(days=1)).strftime(format="%Y%m%d")

    query = f'all:"Retrieval Augmented Generation" AND submittedDate:[{yesterday} TO {today}]'
    logging.info(f"Query: {query}")

    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )

    # Run the search
    client = arxiv.Client()
    results = client.results(search)

    # Save each PDF along with its metadata
    cnt = 0
    for i, paper in enumerate(results):
        logging.info(f"{i}: {paper.title}")

        _id = paper.entry_id.split("/")[-1]  # like "2409.06450v1"
        paper.download_pdf(dirpath=temp_dir, filename=f"{_id}.pdf")

        info = {
            "id": paper.entry_id,
            "title": paper.title,
            "published": paper.published.isoformat(),  # ISOフォーマットで日付を保存
            "summary": paper.summary,
            "categories": paper.categories,
            "journal_ref": paper.journal_ref,
        }
        with open(os.path.join(temp_dir, f"{_id}.json"), "w") as f:
            json.dump(info, f, indent=4)

        cnt += 1
        time.sleep(3)  # be gentle with the arXiv API

    logging.info(f"Number of downloaded papers: {cnt}")
    return temp_dir


def get_paper_info(json_path: str) -> Dict[str, Any]:
    with open(json_path, "r") as file:
        info = json.load(file)
    return info


# Usage
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    temp_dir = download_papers_from_arxiv(max_results=1)
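
The JSON written next to each PDF can be read back with get_paper_info; a quick check (run after the download above):

import glob
import os

for json_path in glob.glob(os.path.join(temp_dir, "*.json")):
    info = get_paper_info(json_path)
    print(info["title"], info["published"])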

2. Setting Up the LLM

utils.py
# -*- coding: utf-8 -*-
import logging
import torch
from langchain.llms import OpenAI, HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from typing import Optional


def setup_llm(
    hf_model_id: Optional[str] = "elyza/Llama-3-ELYZA-JP-8B", quantization: bool = True, max_tokens: int = 500
):
    """
    動作確認済みの hf_model_id:
        elyza/Llama-3-ELYZA-JP-8B
    """

    # If hf_model_id is None, fall back to OpenAI's LLM
    if hf_model_id is None:
        return OpenAI(temperature=0, max_tokens=max_tokens)

    tokenizer = AutoTokenizer.from_pretrained(hf_model_id)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # enable 4-bit quantization
        bnb_4bit_quant_type="nf4",  # quantization type (fp4 or nf4)
        bnb_4bit_compute_dtype=torch.float16,  # compute dtype (float16 or bfloat16)
        bnb_4bit_use_double_quant=True,  # enable double quantization
    )

    model = AutoModelForCausalLM.from_pretrained(
        hf_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config if quantization else None,
        device_map="auto",
    )

    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_tokens,
        device_map="auto",
    )

    return HuggingFacePipeline(pipeline=hf_pipeline)


# Usage
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    llm = setup_llm()
    query = "月見を題材に詩を詠んでください."  # "Please compose a poem about moon viewing."

    output = llm.invoke(query)
    logging.info(output)
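
To use OpenAI's model instead of a local one, pass hf_model_id=None; a minimal sketch (assumes the OPENAI_API_KEY environment variable is set):

llm = setup_llm(hf_model_id=None)  # returns OpenAI(temperature=0, max_tokens=500)
output = llm.invoke("Compose a short poem about moon viewing.")
logging.info(output)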

3. Scoring Relevance from the Paper Title

estimate_applicability.py
# -*- coding: utf-8 -*-
import logging
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from langchain.output_parsers import RegexParser
from typing import Dict

# Local modules
from utils import setup_llm


def estimate_applicability_from_title(llm, title: str) -> Dict[str, str]:
    # create the list of few shot examples
    examples = [
        {
            "title": "Multimodal Retrieval-Augmented Generation for Healthcare",
            "answer": "Yes. This paper explores the use of RAG techniques with both text and medical imaging data, aligning well with multimodal RAG.",
            "score": "100",
        },
        {
            "title": "Enhancing Document Retrieval with Text-Based Models",
            "answer": "No. This paper focuses on text-based retrieval models without incorporating multiple modalities.",
            "score": "20",
        },
        {
            "title": "Combining Text and Visual Data for Improved Question Answering",
            "answer": "Yes. The paper addresses the integration of text and visual data for question answering, relevant to multimodal RAG approaches.",
            "score": "80",
        },
        {
            "title": "Analyzing Speech and Text Data for Sentiment Analysis",
            "answer": "Not clear. While it involves text and speech data, the focus is on sentiment analysis rather than RAG.",
            "score": "10",
        },
        {
            "title": "Developing Multimodal Systems for Autonomous Vehicles",
            "answer": "Yes. This paper involves the integration of various data types (e.g., video, sensor data) for autonomous vehicles, which is relevant to multimodal RAG.",
            "score": "85",
        },
    ]

    # specify the template to format the examples
    example_formatter_template = "Paper Title: {title}\n" "Answer: {answer}\n" "Score: {score}"

    example_prompt = PromptTemplate(
        template=example_formatter_template,
        input_variables=["title", "answer", "score"],
        validate_template=False,
    )

    prefix = (
        "Based on the following title of the paper, please estimate its applicability to multimodal retrieval-augmented generation (RAG).\n"
        "If you do not know, answer that you do not know.\n\n"
        "Paper Title: [title of paper]\n"
        "Answer: [applicability to multimodal RAG]\n"
        "Score: [estimated applicability from 0 to 100]\n\n"
        "How to determine the score\n"
        "- If the paper is presumed to have high potential for application in multimodal RAG, a high score will be given.\n"
        "- Be careful not to be overconfident!\n"
    )
    suffix = "Paper Title: {title}\n" "Answer: "

    few_shot_prompt = FewShotPromptTemplate(
        examples=examples,
        example_prompt=example_prompt,
        prefix=prefix,  # some text that goes before the examples in the prompt
        suffix=suffix,  # some text that goes after the examples in the prompt
        input_variables=["title"],
        example_separator="\n\n",
    )

    prompt_text = few_shot_prompt.format(title=title)
    logging.info(f"{prompt_text = }")

    output = llm.invoke(prompt_text)
    logging.info(f"{output = }")

    output_parser = RegexParser(
        regex=r"\s(.*?)\nScore: (.*)",
        output_keys=["answer", "score"],
    )
    format_output = output_parser.parse(output.replace(prompt_text, ""))

    return format_output


# Usage
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    llm = setup_llm()
    title = "Multimodal Large Language Model Driven Scenario Testing for Autonomous Vehicles"

    format_output = estimate_applicability_from_title(llm, title)
    logging.info(f"{format_output = }")

4. Defining the Information to Extract

queston_list.py
# -*- coding: utf-8 -*-
from pydantic import BaseModel, Field
from typing import List


class QuestionList(BaseModel):
    multimodal_support: bool = Field(
        description="Does the method support multi-modal documents that include text, images, or tables?",
    )
    modalities: str = Field(
        description="What modalities are addressed? (e.g., text, image, table, etc.)",
    )
    llm: str = Field(
        description="What type(s) of LLM(s) were used to validate the method?",
    )
    data: str = Field(
        description="What type(s) of data were used to validate the method?",
    )
    integration: str = Field(
        description="How does the method handle multiple modalities?",
    )
    characteristic: List[str] = Field(
        description="What were the top three characteristics of the study?",
    )
    retrieval_challenge: str = Field(
        description=(
            "Did the paper provide insights or solutions for addressing challenges "
            "in retrieving specific information from large-scale, multi-modal documents (including text, images, and/or tables)? "
            "If so, how?"
        ),
    )
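
The field descriptions double as the questions passed to the QA chain in main.py; a quick way to list them:

from queston_list import QuestionList

for name, field in QuestionList.model_fields.items():
    print(f"{name}: {field.description}")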

5. Extracting the Information from the Papers

main.py
# -*- coding: utf-8 -*-
import json
import logging
import os
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Local modules
from queston_list import QuestionList
from utils import setup_llm
from estimate_applicability import estimate_applicability_from_title
from download_papers import download_papers_from_arxiv, get_paper_info


def extract_paper_details(pdf_path: str, llm, embeddings):
    # Read the paper's metadata
    json_path = pdf_path.replace(".pdf", ".json")
    info = get_paper_info(json_path)

    # Rough relevance estimate based on the paper title
    rough_appl = estimate_applicability_from_title(llm, info["title"])  # like {"answer": "...", "score": "100"}
    info.update(rough_appl)

    # If the score is at or above the threshold
    if int(rough_appl["score"]) >= 20:

        qa_chain = load_qa_chain(
            llm=llm,
            chain_type="refine",  # "stuff", "map_reduce", "refine", "map_rerank"
            return_intermediate_steps=True,
            verbose=True,
        )

        output_parser = CommaSeparatedListOutputParser()

        # Load the PDF
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()

        # Split into chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=254)
        chunks = text_splitter.split_documents(pages)

        # Build an in-memory vector store
        temp_db = Chroma.from_documents(chunks, embedding=embeddings)

        # Wrap it as a retriever
        retriever = temp_db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

        result = dict()
        for field in QuestionList.model_fields:
            question = QuestionList.model_fields[field].description  # the question to ask
            logging.info(f"{question = }")

            contexts = retriever.get_relevant_documents(query=question)  # retrieved contexts

            # qa chain
            answer = qa_chain({"input_documents": contexts, "question": question}, return_only_outputs=True)
            answer = answer["output_text"]
            if field in ["characteristic"]:
                answer = output_parser.parse(answer)

            result[field] = dict(question=question, answer=answer)

        info.update({"qa": result})

    # Save the results
    with open(json_path, "w") as f:
        json.dump(info, f, indent=4)

    return None


# Usage
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    temp_dir = download_papers_from_arxiv(max_results=10)

    llm = setup_llm()
    embeddings = HuggingFaceEmbeddings()

    for fname in os.listdir(temp_dir):
        if not fname.endswith(".pdf"):  # skip the metadata JSON files
            continue
        pdf_path = os.path.join(temp_dir, fname)
        extract_paper_details(pdf_path, llm, embeddings)
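
After a run, each paper's JSON file contains the arXiv metadata, the rough score, and, for papers that cleared the threshold, the QA results. A minimal sketch for reviewing them:

import glob
import json
import os

for json_path in sorted(glob.glob(os.path.join(temp_dir, "*.json"))):
    with open(json_path) as f:
        info = json.load(f)
    print(info["title"], info.get("score"), "qa" in info)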
