Summary
- ss-toolsを機能拡張
- テキストから高精度で論文IDを引けるようにレーベンシュタイン距離を実装した
- APIでエラーが発生したときの回避処理を実装した
Crates & Repositories
crate | GitHub |
---|---|
rsrpp | rsrpp |
rsrpp-cli | rsrpp |
arxiv-tools | rs-arxiv-tools |
ss-tools | rs-ss-tools |
前回までのあらすじ
arXivとSemantic ScholarのAPIラッパーを実装して,arXivから日次で論文リストをスクレイピングしてきて論文のメタ情報を収集する機能が大体完成しました.
今回はSemantic Scholar APIにて積み残していた,
- テキストの類似度算出
- APIのエラー処理
を実装します.
テキストの類似度の問題は,下記クラス図の中で,SemanticScholar
に実装されている query_paper_id()
にて発生します.
APIのエラー処理は,query
関連のすべての関数で実装します.
テキストの類似度
query_paper_id()
では,論文のタイトルの文字列から論文IDを検索します.このツールはタイトルがほぼ完全にわかっている前提で情報の収集を行うことを想定しており,あまり曖昧な検索には使用しません.
その分,与えられたタイトルがSemanticScholarのデータベースに存在するのであれば,ほぼ確実に論文を引き当てたいです.
ここで論文を引き当てる処理の精度が低いと,後続の処理に大きな影響を与えてしまいます.
さて,精度を確保するために具体的に何をしなければいけないのか確認するために,まずはSemantic ScholarのAPIから返ってくるJSONを眺めてみます.
次のようなクエリを投げると
https://api.semanticscholar.org/graph/v1/paper/search?query=attention%20is%20all%20you%20need&limit=10
以下のようなレスポンスが返ってきます.
{
"total": 6915,
"offset": 0,
"next": 10,
"data": [
{
"paperId": "204e3073870fae3d05bcbc2f6a8e263d9b72e776",
"title": "Attention is All you Need"
},
{
"paperId": "51c9d4d2f50ac5707c1f889aa97f08350d549132",
"title": "Attention Is All You Need In Speech Separation"
},
{
"paperId": "bb4a9650ca3946c70a7e92007cc61dc0dfd75522",
"title": "Channel Attention Is All You Need for Video Frame Interpolation"
},
{
"paperId": "56e5f5810441f0ce72641cc8db2217510fd5f48d",
"title": "Attention is all you need: utilizing attention in AI-enabled drug discovery"
},
{
"paperId": "e8ed8883c20e8b1f16b20d6fd941f7abccf01199",
"title": "Attention is all you need: An interpretable transformer-based asset allocation approach"
},
{
"paperId": "3e32139deb17761a25075f8839daa61ad5992fc9",
"title": "Cross-Attention is All You Need: Adapting Pretrained Transformers for Machine Translation"
},
{
"paperId": "fa08b41ccdfc5d8771adfbc34c176fa237d4646c",
"title": "Is Space-Time Attention All You Need for Video Understanding?"
},
{
"paperId": "1dfcf8401952771a35ca9d465bd2361729ad01ca",
"title": "Attention Is All You Need For Blind Room Volume Estimation"
},
{
"paperId": "38272f2c4cb5ec843db0e721f8d2589e0a9d82d4",
"title": "Graph Structure from Point Clouds: Geometric Attention is All You Need"
},
{
"paperId": "a972d28273fe7e9c7608a562449719ffa08f769f",
"title": "Master GAN: Multiple Attention is all you Need: A Multiple Attention Guided Super Resolution Network for Dems"
}
]
}
おそらく関連度順に並んでいるので,一番上のものを持ってくれば良さそうですが,前述の通りここでの誤検出は後の処理への影響が大きいので,こちらでも欲しいタイトルがマッチしているかどうか確認したいです.
そこで,今回はテキスト処理でよく使われるレーベンシュタイン距離を用いて類似度を判定します.
レーベンシュタイン距離とは
Hello Perplexity!
レーベンシュタイン距離は、2つの文字列がどれだけ異なっているかを数値化する方法です。
この距離は、一方の文字列を他方の文字列に変換するために必要な最小の編集操作回数を表します。
主な特徴
- 定義: 文字列間の最小編集距離
- 操作: 文字の挿入、削除、置換の3種類
- 距離の意味: 小さいほど文字列が類似、大きいほど相違
計算方法
レーベンシュタイン距離の計算は、以下の手順で行われます:
- 2つの文字列を比較
- 一方を他方に変換するために必要な最小の編集操作を特定
- 編集操作の回数をカウント
具体例
「kitten」を「sitting」に変換する場合:
k → s (置換)
e → i (置換)
→ g (挿入)
この例では、3回の編集操作が必要なので、レーベンシュタイン距離は3となります
特徴と注意点
- 文字列の長さが異なる場合でも計算可能
- 編集操作ごとに異なるコストを設定することも可能
- 計算には通常、動的計画法が用いられる
レーベンシュタイン距離は、文字列の類似度を客観的に評価する上で非常に有用なツールです。しかし、具体的な応用場面に応じて、適切な閾値や解釈方法を設定する必要があります
以上,レーベンシュタイン距離についての説明でした.
なぜ今回Perplexityさんに説明をお願いしたかというと,いざレーベンシュタイン距離を実装しようとして,fn levenshtein_dist() {
と打ち込んだ瞬間にGitHub Copilotさんが完成されたプログラムを出力してくれたので,こちらでやることがなくなってしまったため,いっそ説明もAIにお願いしようと思った次第であります.
なお,プログラムが正しいかどうかは,きちんとテストを書いて確認しています.
レーベンシュタイン距離は類似度とは逆で,2つのテキストが似ているほど値が小さくなるので,使用する場合はスケールが逆になるように変換してから使用します.
今回は以下のように変換して使用しました.$s_1$,$s_2$はそれぞれ入力のテキストです.
$$
\text{LevenshteinSimilarity}(s_1, s_2)=\frac{1}{1+\text{NormalizedLevenshteinDistance}(s_1, s_2))}
$$
$$
\text{NormalizedLevenshteinDistance}(s_1, s_2) = \frac{\text{LevenshteinDistance}(s_1, s_2)}{\max(\text{CharCount}(s_1), CharCount(s_2))}
$$
さて,この定義に基づいて,先ほどの出力からレーベンシュタイン類似度を計算してみます.
Text | Levenshtein Similarity |
---|---|
Attention Is All You Need | 1.000 |
Attention Is All You Need In Speech Separation | 0.687 |
Channel Attention Is All You Need for Video Frame Interpolation | 0.624 |
Attention is all you need: utilizing attention in AI-enabled drug discovery | 0.600 |
Attention is all you need: An interpretable transformer-based asset allocation approach | 0.584 |
Cross-Attention is All You Need: Adapting Pretrained Transformers for Machine Translation | 0.582 |
Is Space-Time Attention All You Need for Video Understanding? | 0.592 |
Attention Is All You Need For Blind Room Volume Estimation | 0.637 |
Graph Structure from Point Clouds: Geometric Attention is All You Need | 0.609 |
Master GAN: Multiple Attention is all you Need: A Multiple Attention Guided Super Resolution Network for Dems | 0.565 |
Attention Is All You Needは色んなパロディが存在するのでちょうど良いテスト対象でした.
無事に,対象のタイトルを抽出できそうです.
一見すると完全一致で比較するだけでも良さそうですが,それだと,こちらから与える論文のタイトルが完全でなければならないということと,SemanticScholarが提供する論文のタイトルも表記揺れが存在するケースがあるので,少々柔軟性に欠けるシステムになってしまいます.
類似度を計算しておけば,最もスコアが高いテキストを取得してくることができるので,多少誤字ってもダイジョウブ!
ちなみに,今回の変換の値域は $0<\text{score}<=1$ となっています.
最後に,レーベンシュタイン距離〜類似度を計算するプログラムです.
pub fn levenshtein_dist(s1: &str, s2: &str) -> usize {
let len1 = s1.chars().count();
let len2 = s2.chars().count();
let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
for i in 0..=len1 {
matrix[i][0] = i;
}
for j in 0..=len2 {
matrix[0][j] = j;
}
s1.chars().enumerate().for_each(|(i, c1)| {
s2.chars().enumerate().for_each(|(j, c2)| {
let cost = if c1 == c2 { 0 } else { 1 };
matrix[i + 1][j + 1] = std::cmp::min(
matrix[i][j + 1] + 1,
std::cmp::min(matrix[i + 1][j] + 1, matrix[i][j] + cost),
);
});
});
return matrix[len1][len2];
}
pub fn levenshtein_dist_normalized(s1: &str, s2: &str) -> f64 {
let len1 = s1.chars().count();
let len2 = s2.chars().count();
let dist = levenshtein_dist(s1, s2) as f64;
let max_len = std::cmp::max(len1, len2) as f64;
return dist / max_len;
}
pub fn levenshtein_similarity(s1: &str, s2: &str) -> f64 {
return 1.0 / (1.0 + levenshtein_dist_normalized(s1, s2));
}
APIのエラー処理
SemanticScholarのAPIはまれによくエラーで失敗します.
現状のコードではAPIのエラーを想定しておらず,エラーが発生したらそのまま終了してしまいます.
システム全体としても,頻繁にエラーが発生して止まってしまうのは良くないので,なんとかします.
何度かリトライしているとうまくいくことが多いので,方針としては,クエリが失敗した場合には少しスリープして時間をおいたのち,指定された最大回数までリトライするという仕組みを導入します.
最大試行回数とスリープする時間は引数として受け取るようにします.
関数の中段のloop
で試行回数までクエリを繰り返すようにしています.指定された最大回数を超えた場合は普通にエラーとして処理します.
改良後のquery_paper_id()
はこちら.
pub async fn query_paper_id(
&mut self,
query_text: String,
max_retry_count: &mut u64,
wait_time: u64,
) -> Result<(String, String)> {
self.query_text = query_text;
self.endpoint = SsEndpoint::GetPaperTitle;
let mut headers = header::HeaderMap::new();
if !self.api_key.is_empty() {
headers.insert("x-api-key", self.api_key.parse().unwrap());
}
let client = request::Client::builder()
.default_headers(headers)
.build()
.unwrap();
let url = self.build();
println!("URL: {}", url);
loop {
if *max_retry_count == 0 {
return Err(Error::msg(format!(
"Failed to get paper id for: {}",
self.query_text
)));
}
let body = client
.get(url.clone())
.send()
.await
.unwrap()
.text()
.await
.unwrap();
match serde_json::from_str::<SsResponsePpaerIds>(&body) {
Ok(response) => {
if response.data.is_empty() {
*max_retry_count -= 1;
self.sleep(wait_time);
continue;
}
let mut scores: Vec<(SsScore, LevSimilarityScore, (String, String))> =
Vec::new();
response.data.iter().for_each(|paper| {
let title = paper.title.clone().unwrap_or("".to_string());
let score = paper.match_score.unwrap_or(0.0);
let lev_score = utils::levenshtein_similarity(&self.query_text, &title);
scores.push((
score,
lev_score,
(
paper.paper_id.clone().unwrap(),
paper.title.clone().unwrap(),
),
));
});
let total_score = |ss_s, lev_s| 0.5 * ss_s + 0.5 * lev_s;
let (paper_id, paper_title) = scores
.iter()
.max_by(|a, b| {
total_score(a.0, a.1)
.partial_cmp(&total_score(b.0, b.1))
.unwrap()
})
.unwrap()
.2
.clone();
return Ok((paper_id, paper_title));
}
Err(_) => {
*max_retry_count -= 1;
self.sleep(wait_time);
continue;
}
}
}
}
なお,sleep
の実装はこちら.
fn sleep(&self, seconds: u64) {
let pb = ProgressBar::new(seconds);
pb.set_style(
indicatif::ProgressStyle::default_bar()
.template(
"{spinner:.green} [{elapsed_precise}] [{bar:40.green/cyan}] {pos}s/{len}s {msg}",
)
.unwrap()
.progress_chars("█▓▒░"),
);
pb.set_message("Waiting for the next request...");
for _ in 0..seconds {
pb.inc(1);
std::thread::sleep(std::time::Duration::from_secs(1));
}
pb.finish_and_clear();
}
指定した秒数分プログレスバーを進捗させてスリープします.
rustではindicatif
という使いやすいプログレスバーがあるので,それでスリープ時間をわかるように表示しています.
上記の処理をAPIを叩く関数全てに実装しました.
以上でss-tools
の積み残しの実装は完了です.
次回
次回は抽出してきた論文からキーワードを抽出する処理を実装します.
キーワード抽出にはいろいろな手法がありますが,今回は目的がはっきりしているので一番原始的な方法で実装していきます.