14
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Papers with Code トレンド論文bot作って雑食的に情報収集してみた

Last updated at Posted at 2023-12-11

本記事は「Kaggle Advent Calendar 2023 」の 11 日目の記事です。

概要

X(旧:Twitter)のタイムラインを眺めていると、日々新しい論文・手法の情報が流れてきますよね。
SNSによる情報収集を私も活用しておりますが、一方でフォローしているユーザーさんの情報がメインとなるため、もともと私が関心ある・認知済みの分野に偏った情報が入ってきやすかったりします(Kaggle、CVなど)。

そこで、認知外の情報をよりキャッチアップするために、トレンド論文情報を通知するbotを作ってみました。
あえてキーワード検索などせずトレンドに身を任せることで、機械学習関連の情報を雑食的にキャッチアップし、分野の流行りやKaggleに使えそうな手法を見つけようって狙いです。

  • botのSlack通知例

トレンド論文の情報は、Papers with CodeTrending Researchから取得してます。

  • Papers with Code
    • 機械学習系の論文とソースコードがセットで公開されているため、「論文の手法に興味あり」->「試してみる」のサイクルが回しやすい(はず)です。
    • Trending Researchのトレンド定義は公開されていませんが、Xで話題になった論文などを上位で見かけることも多いため、閲覧数やGithubのスター数などをもとにランキングされていると推察してます。

Botの動作イメージ

Papers with CodeTrending Research上位10件を毎朝チェックし、未読の論文(まだ通知していない論文)があればSlack通知・NotionDBへ登録します(Slackが無料版だと90日で消えちゃうので、Notion DBにも飛ばすようにしてます)。
基本的にはOpenAIのAPI料金だけで実現出来ます。タイトル翻訳もOpenAIに任せても良いですが、節約のためDeepL APIで実施してます。

検証条件

環境、ライブラリ

  • python 3.10
  • slack-sdk==3.26.1
  • deepl==1.16.1
  • openai==0.28.1
  • schedule==1.2.1
  • arxiv==2.0.0
  • notion-client==2.1.0
  • pandas==2.1.3
  • numpy==1.26.2

botを作ってみる

事前準備

Notion Databaseの用意

Slackのフリープランでは90日間でメッセージが見れなくなるため、Notion上に論文情報を保管するDatabaseを用意します。
手間ですが手動で以下DataBaseを作ります(APIでも可能だと思いますが、やり方が分からず...)。

  1. /databaseコマンドで Database作成
  2. カラムを編集
    • Name -> title

    • Tags -> task

    • 以下 追加

      カラム名 Type
      arxiv URL
      github URL
      title_jp Text
      llm_summary Text
      subjects Multi-select
      published Date

各API情報

各APIを使うにあたり、必要なトークン情報などを取得しておきます。

  • OpenAI

  • DeepL

    • 認証キー
      • account/summary - アカウント - DeepL APIで使用する認証キー で取得
  • Slack

    • Incoming WebHook の Webhook URL
      • 通知を飛ばしたいSlackチャンネルにIncoming WebHookを設定し、Webhook URLを取得 (参考記事)
  • Notion

    • token
      • インテグレーションを作成し、トークンを取得 (参考記事)
    • Database ID

API情報を環境変数に設定

  • API情報を環境変数に設定するため、setenv.shを作成
    setenv.sh
    #!/bin/bash
    
    export OPENAI_API_KEY="xxxxx"       # OpenAI API keys
    export OPENAI_ORGANIZATION="xxxxx"  # OpenAI Organization ID
    export DEEPL_API_KEY="xxxxx"        # DeepL 認証キー
    export SLACK_URL="xxxxx"            # Slack Webhook URL
    export NOTION_API_TOKEN="xxxxx"     # Notion token
    export NOTION_DB_ID="xxxxx"         # Notion Database ID
    
  • API情報を環境変数に反映
    source setenv.sh
    

実装

いざ実装です。ファイル構成は以下のとおりです。

  • ファイル構成
    .
    ├── config.py                   <- pathなど設定
    ├── get_paper_info.py           <- 論文情報取得
    ├── post_process.py             <- 翻訳、要約
    ├── send_paper_info.py          <- 論文情報発信
    └── result                      <- 出力ディレクトリ(スクリプト実行後に生成)
        ├── paper_info.csv          <- 取得した論文情報一覧
        ├── deepl_api_usage_log.csv <- deep APIの使用状況を保存(フリープランの翻訳上限 500,000文字/月までの目安を把握)
        └── bot.log                 <- bot動作ログ
    

  • config.py
    取得した論文情報を出力するディレクトリ、ファイル名など設定します。
    config.py
    output_dir = "./result"                         # 出力ディレクトリ
    paper_info_filename = "paper_info.csv"          # 取得した論文情報一覧
    deepl_log_filename = "deepl_api_usage_log.csv"  # deep APIの使用状況
    bot_log_filename = "bot.log"                    # bot動作ログ
    

  • get_paper_info.py (論文情報取得)
    Paper with Codeからトレンド論文情報(タイトル、Abstract、Githubリポジトリ、タスクなど)を取得し、arxivから対象論文のsubjectsを取得します。
    この際、過去に取得した論文情報をconfig.paper_info_filenameファイルと比較し、新規論文のみ追加するようにしてます。

    コード
    get_paper_info.py
    import os
    import re
    import time
    import requests
    from datetime import datetime
    from typing import List, Dict, Union
    
    import arxiv
    import numpy as np
    import pandas as pd
    
    import config
    
    
    PAPERS_WITH_CODE_API_URL = "https://paperswithcode.com/api/v1/"
    QUERY_PARAMS = {
        "page": 1,
        "items_per_page": 10,
    }
    
    
    def query_papers_with_code(url: str) -> List[Dict[str, Union[str, None, List]]]:
        """PapersWithCode API実行
    
        Args:
            url (str): 対象APIのURL
    
        Returns:
            List[Dict[str, Union[str, None, List]]]: PapersWithCode API 戻り値
        """
        try:
            response = requests.get(url, params=QUERY_PARAMS)
            response.raise_for_status()
            data = response.json()
            return data["results"]
        except requests.exceptions.RequestException as e:
            print(f"Error making API request: {e}")
            return None
    
    
    def search_trend_papers() -> str:
        """トレンド論文を上位10件取得するURL生成
    
        Returns:
            str: PapersWithCode API URL
        """
        url = PAPERS_WITH_CODE_API_URL + "search"
        return query_papers_with_code(url)
    
    
    def get_paper_repositorie(paper_id: str) -> str:
        """対象論文のリポジトリを取得するURL生成
    
        Args:
            paper_id (str): 対象論文のid
    
        Returns:
            str: PapersWithCode API URL
        """
        url = f"{PAPERS_WITH_CODE_API_URL}papers/{paper_id}/repositories"
        return query_papers_with_code(url)
    
    
    def get_paper_tasks(paper_id: str) -> str:
        """対象論文のタスクを取得するURL生成
    
        Args:
            paper_id (str): 対象論文のid
    
        Returns:
            str: PapersWithCode API URL
        """
        url = f"{PAPERS_WITH_CODE_API_URL}papers/{paper_id}/tasks"
        return query_papers_with_code(url)
    
    
    def search() -> pd.DataFrame:
        """Papers With CodeのTrending Research論文を 10件取得
    
        Returns:
            pd.DataFrame: 論文情報のDataFrame
        """
        paper_info_dic = {
            k: []
            for k in [
                "title",
                "abstract",
                "published",
                "conference",
                "arxiv",
                "repository",
                "stars",
                "task",
                "subjects",
            ]
        }
        papers = search_trend_papers()
    
        for i, paper in enumerate(papers):
            print(f"\t paper {i+1}...")
            # 基本情報の取得
            paper_info = paper["paper"]
            paper_info_dic["title"].append(paper_info["title"])
            paper_info_dic["abstract"].append(paper_info["abstract"])
            paper_info_dic["arxiv"].append(paper_info["url_abs"])
            paper_info_dic["published"].append(paper_info["published"])
            paper_info_dic["conference"].append(paper_info["proceeding"])
            time.sleep(0.5)
            # リポジトリ情報の取得
            repositories = get_paper_repositorie(paper_info["id"])
            if repositories:
                use_repositories = [
                    repo for repo in repositories if repo["is_official"]
                ]  # officialを優先して取得
                if use_repositories:
                    use_repository = use_repositories[0]
                else:
                    _use_index = np.argmax([repo["stars"] for repo in repositories])
                    use_repository = repositories[_use_index]
                paper_info_dic["repository"].append(use_repository["url"])
                paper_info_dic["stars"].append(use_repository["stars"])
            else:
                paper_info_dic["repository"].append("-")
                paper_info_dic["stars"].append("-")
            time.sleep(0.5)
            # タスク情報取得
            tasks = get_paper_tasks(paper_info["id"])
            tasks = [t["name"] for t in tasks]
            paper_info_dic["task"].append(tasks)
            time.sleep(0.5)
            # 論文タイトルをもとに、arxivからsubjects情報を取得
            query = " OR ti:".join(re.findall("[a-z]+-*[a-z]+", paper_info["title"].lower()))
            search = arxiv.Search(
                query=f"ti:{query}",
                max_results=1,
            )
            arxiv_result = next(arxiv.Client().results(search))
            paper_info_dic["subjects"].append(arxiv_result.categories)
    
        paper_info_df = pd.DataFrame(paper_info_dic)
        today = datetime.now()
        paper_info_df["get_date"] = today.strftime("%Y-%m-%d")
        paper_info_df["send_flag"] = False
        paper_info_df["title_jp"] = "-"
        paper_info_df["llm_summary"] = "-"
        return paper_info_df
    
    
    def save(paper_info_df: pd.DataFrame) -> None:
        """取得した論文情報をcsv保存
    
        Args:
            paper_info_df (pd.DataFrame): 論文情報
        """
        paper_info_path = os.path.join(config.output_dir, config.paper_info_filename)
        if os.path.exists(paper_info_path):
            old_paper_info_df = pd.read_csv(paper_info_path)
            paper_info_df = pd.concat(
                [old_paper_info_df, paper_info_df], axis=0
            ).drop_duplicates("title", keep="first")
            paper_info_df.reset_index(drop=True, inplace=True)
        else:
            os.makedirs(config.output_dir, exist_ok=True)
        paper_info_df.to_csv(paper_info_path, index=False)
    
    
    def run() -> None:
        """論文情報の取得とcsv保存
        """
        print("get paper info...")
        paper_info_df = search()
        save(paper_info_df)
    

  • post_process.py (翻訳、要約)
    DeepLでタイトルの翻訳、OpenAI(GPT-3.5)でAbstractの要約をします。
    Abstractは以下観点で要約させてます。

    1. 既存研究では何ができなかったのか
    2. どのようなアプローチでそれを解決しようとしたか
    3. 結果、何が達成できたのか

    要約の観点、promptはこちらの記事を参考にさせていただきました。

    コード
    post_process.py
    import os
    import re
    import csv
    import time
    from datetime import datetime
    
    import deepl
    import openai
    
    import config
    
    
    # For DeepL API
    AUTH_KEY = os.environ["DEEPL_API_KEY"]
    TARGET_LANGAGE = "JA"
    DEEPL_MONTHLY_LIMIT = 500_000
    
    # For OpenAI API
    MAX_RETRIES = 3
    openai.organization = os.environ.get("OPENAI_ORGANIZATION")
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    
    
    def translate(target_text: str) -> str:
        """日本語翻訳
    
        Args:
            target_text (str): 翻訳対象
    
        Returns:
            str: 翻訳結果
        """
        # 今日の日付を取得
        today_date = datetime.now()
        current_month = today_date.strftime("%Y-%m")
        # 過去の使用ログを読み取り、同じ月の文字数を合計
        total_chars_used = 0
        deepl_log_path = os.path.join(config.output_dir, config.deepl_log_filename)
        if os.path.exists(deepl_log_path):
            with open(deepl_log_path, "r") as file:
                reader = csv.DictReader(file)
                for row in reader:
                    log_date = datetime.strptime(row["Date"], "%Y-%m-%d")
                    log_month = log_date.strftime("%Y-%m")
                    if log_month == current_month:
                        total_chars_used += int(row["Characters"])
    
        translator = deepl.Translator(AUTH_KEY)
        result = translator.translate_text(target_text, target_lang=TARGET_LANGAGE)
        # 使用文字数を取得
        total_chars_used += len(target_text)
        # 今日の日付と使用文字数をログに記録
        with open(deepl_log_path, "a", newline="") as file:
            headers = ["Date", "Characters", "Monthly_characters"]
            writer = csv.DictWriter(file, fieldnames=headers)
            # ファイルが空の場合、ヘッダーを書き込む
            if file.tell() == 0:
                writer.writeheader()
            writer.writerow(
                {
                    "Date": today_date.strftime("%Y-%m-%d"),
                    "Characters": len(target_text),
                    "Monthly_characters": total_chars_used,
                }
            )
        return result.text
    
    
    def llm_summarize(abstract: str) -> str:
        """論文のAbstractをGPT-3.5で要約
    
        Args:
            abstract (str): Abstract文
    
        Raises:
            RetryLimitError: APIのリトライが上限を超えた場合
    
        Returns:
            str: 要約結果
        """
        messages = [
            {"role": "system", "content": "あなたは親切なアシスタントです。"},
            {
                "role": "user",
                "content": f"以下論文のAbstractをもとに、\n\
        1. 既存研究では何ができなかったのか、\n\
        2. どのようなアプローチでそれを解決しようとしたか、\n\
        3. 結果、何が達成できたのか\nについて日本語で5行くらいで教えて。\n\n\
        Abstract: {abstract}",
            },
        ]
    
        for retry_count in range(MAX_RETRIES):
            try:
                response = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",  # GPTのエンジン名を指定します
                    messages=messages,
                    max_tokens=500,  # 生成するトークンの最大数
                    n=1,  # 生成するレスポンスの数
                    stop=None,  # 停止トークンの設定
                    temperature=0,  # 生成時のランダム性の制御
                    top_p=1,  # トークン選択時の確率閾値
                )
                break
            except Exception as RetryLimitError:
                if retry_count < MAX_RETRIES:
                    print(f"Retry {retry_count}: {RetryLimitError}")
                    time.sleep(1)
                else:
                    print(
                        f"Max retries reached. Unable to complete operation. Last error: {RetryLimitError}"
                    )
                    raise RetryLimitError
        response_text = response.choices[0].message.content.strip()
        return re.sub("\n+", "\n\n", response_text)  # 改行を整形
    

  • send_paper_info.py (論文情報発信)
    論文情報をSlack、Notionへ発信します。send_flagで未通知/通知済みを判断し、未読の論文(まだ通知していない論文)だけ発信します。
    scheduleライブラリを使用し、6:30に論文情報取得・7:00に発信するようスケジュール設定してます。

    コード
    send_paper_info.py
    import os
    import time
    import logging
    import argparse
    from datetime import datetime
    
    import numpy as np
    import pandas as pd
    import schedule
    from notion_client import Client
    from slack_sdk.webhook import WebhookClient
    
    import config
    import get_paper_info
    import post_process
    
    
    def message_to_slack(message: str) -> None:
        """Slackにメッセージ送信
    
        Args:
            message (str): 送信メッセージ
        """
        slack = WebhookClient(os.environ["SLACK_URL"])
        _ = slack.send(text=message)
    
    
    def make_slack_message(row: pd.Series) -> str:
        """論文情報をSlack用メッセージに変換
    
        Args:
            row (pd.Series): 1論文の情報
    
        Returns:
            str: Slack用メッセージ
        """
        m = f"<{row['arxiv']} | {row['title']}>\n"
        m += f"{row['title_jp']}\n"
        task = ", ".join([f"`{t[1:-1]}`" for t in row["task"][1:-1].split(", ")])
        m += f"・Task : {task}\n"
        m += f"{row['published']}  (conference : {row['conference'] if not pd.isna(row['conference']) else '-'})\n"
        m += f"・<{row['repository']} | Github> : ☆{row['stars']}\n"
        m += f"{row['llm_summary']}\n"
        m += "-" * 100
        return m
    
    
    def paper_to_notion(client: Client, row: pd.Series) -> None:
        """論文情報をNotionへ送信
    
        Args:
            client (Client): notion_client.Client
            row (pd.Series): 1論文の情報
        """
        _ = client.pages.create(
            **{
                "parent": {"database_id": os.environ["NOTION_DB_ID"]},
                "properties": {
                    "title": {
                        "title": [
                            {
                                "text": {
                                    "content": row["title"],
                                }
                            }
                        ]
                    },
                    "task": {
                        "type": "multi_select",
                        "multi_select": [
                            {"name": t[1:-1]}
                            for t in row["task"][1:-1].split(", ")
                            if len(t)
                        ],
                    },
                    "published": {"date": {"start": row["published"]}},
                    "arxiv": {"url": row["arxiv"]},
                    "github": {"url": row["repository"]},
                    "llm_summary": {
                        "rich_text": [
                            {
                                "text": {
                                    "content": row["llm_summary"],
                                },
                            }
                        ],
                    },
                    "title_jp": {
                        "rich_text": [
                            {
                                "text": {
                                    "content": row["title_jp"],
                                },
                            }
                        ],
                    },
                    "subjects": {
                        "type": "multi_select",
                        "multi_select": [
                            {"name": t[1:-1]}
                            for t in row["subjects"][1:-1].split(", ")
                            if len(t)
                        ],
                    },
                },
                # coverに画像urlを設定すると、任意のサムネイルも設定可能
                # "cover": {"type": "external", "external": {"url": xxx}},
            }
        )
    
    
    def prepare_paper_info() -> None:
        """論文情報の取得、要約など"""
        logging.info(f"{datetime.now()} - prepare_paper_info start")
    
        # 論文情報取得
        get_paper_info.run()
        paper_info_path = os.path.join(config.output_dir, config.paper_info_filename)
        paper_info_df = pd.read_csv(paper_info_path)
        # 要約処理を未実施の論文を抽出
        process_index = paper_info_df.query("llm_summary == '-'").index
        if len(process_index):
            print("post-process...")
            for idx in process_index:
                print(f"\t index {idx}...")
                row = paper_info_df.loc[idx]
                print("\t\t title...")
                title_tranlate = post_process.translate(row["title"])
                print("\t\t summary...")
                llm_summary = post_process.llm_summarize(row["abstract"])
                paper_info_df.at[idx, "title_jp"] = title_tranlate
                paper_info_df.at[idx, "llm_summary"] = llm_summary
            paper_info_df.to_csv(paper_info_path, index=False)
    
        logging.info(f"{datetime.now()} - prepare_paper_info done")
    
    
    def send():
        logging.info(f"{datetime.now()} - send start")
    
        paper_info_path = os.path.join(config.output_dir, config.paper_info_filename)
        paper_info_df = pd.read_csv(paper_info_path)
        # 未送信の論文を抽出
        send_index = paper_info_df.query("send_flag == False").index
        client = Client(auth=os.environ["NOTION_API_TOKEN"])
    
        if len(send_index):
            print("send...")
            for idx in send_index:
                message = make_slack_message(paper_info_df.loc[idx])
                message_to_slack(message)
                paper_to_notion(client, paper_info_df.loc[idx])
                paper_info_df.at[idx, "send_flag"] = True
            paper_info_df.to_csv(paper_info_path, index=False)
        else:
            message_to_slack("No updated papers.")
    
        print("done!")
        logging.info(f"{datetime.now()} - send done")
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="")
        parser.add_argument("--debug", action="store_true", help="Enable debug mode")
        args = parser.parse_args()
    
        os.makedirs(config.output_dir, exist_ok=True)
        bot_log_path = os.path.join(config.output_dir, config.bot_log_filename)
        logging.basicConfig(filename=bot_log_path, level=logging.INFO)
    
        if args.debug:
            prepare_paper_info()
            send()
        else:
            try:
                print("timer start.")
                schedule.every().days.at("06:30").do(prepare_paper_info)
                schedule.every().days.at("07:00").do(send)
                while True:
                    schedule.run_pending()
            except Exception as e:
                logging.error(f"{datetime.now()}: {str(e)}")
                message_to_slack(f"An error occurred: {str(e)}")
    

  • 実行
    # スケジューリング実行
    python send_paper_info.py
    
    # スケジューリング無しで1回実行(デバッグ実行)
    python send_paper_info.py --debug
    

  • 終了
    Ctrl+C
    
    # `python send_paper_info.py &`などでバックグラウンド実行の場合
    kill $(pgrep -a python | grep send_paper_info.py | awk '{print $1}')
    

運用してみた感想

約3週間運用してみて、合計58本の論文が届きました。

  • 最近のトレンドは?
    • タイトルに多い単語(一般的なストップワード, model, learning 除去)

      単語 論文数
      language 12
      large 11
      diffusion 6
      generation 5
      image 4
      3d 4

      -> LLM関連(language, large)やdiffusion, generationが論文数多く、SNS同様に生成系がトレンドになりやすいようです


  • Kaggleに活かせそうなネタあった?

    • 筆者の理解力の問題もあるが、「これKaggleに使えるんじゃね?」ってすぐ刺さる論文はぶっちゃけ少ない (コンペで勝つにはタスクやデータ固有のアプローチが重要なので、刺さるかどうかは参加中のコンペ次第かも)

    • たまに新しいアーキテクチャを提案した論文がある -> アンサンブル ネタに使えるかも

      • 例:UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition
        1. 既存の大規模カーネル畳み込みニューラルネットワーク(ConvNet)のアーキテクチャは、従来のConvNetやトランスフォーマーの設計原則に従っており、大規模カーネルConvNetのアーキテクチャ設計は未解決のままでした。
        
        2. トランスフォーマーが複数のモダリティで主導的な存在となっている中、ConvNetも視覚以外の領域で強力な普遍的な知覚能力を持っているかどうかは調査されていませんでした。
        
        3. 本研究では、大規模カーネルConvNetの設計に関する4つのアーキテクチャガイドラインを提案しました。これらのガイドラインに従って設計されたモデルは、画像認識において優れた性能を示しました。また、大規模カーネルはConvNetの優れた性能を元々得意ではなかった領域でも引き出す鍵であることを発見しました。モダリティに関連する前処理手法を組み合わせることで、提案モデルはモダリティ固有のカスタマイズなしでも時系列予測や音声認識のタスクで最先端の性能を達成しました。
        
    • データ改善が性能に刺さる論文多い -> コンペに通ずる考えを再認識


  • 改善点は?
    • Papers with Code APIはデータセットやタスクで絞り込み出来る。Sota更新されたら通知も出来るっぽい(参考)。
    • 後から検索しやすいように、キーワード抽出をGPTにお願いするのもアリかと
    • 気になった論文にフラグづけなど、後から振り返りやすい管理の仕組み

最後まで読んでいただき、ありがとうございました!
本記事の内容に誤りなどあれば、コメントにてご教授お願いいたします。

免責事項

本記事の掲載にあたり、記事内容について精査・確認をしておりますが、ご利用により生じた損害等については、筆者および筆者の所属組織は一切の責任を負いません。
ご利用に際しては、自己責任でお願いいたします。

Reference

14
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?