はじめに
この記事は完全に自分用の備忘録ですが、同じことやってる人は一定数いそうなので、シェアします。
おそらく、もっと良い方法があると思うので、コメント等で教えてくれると嬉しいです!
前提
自分は自然言語処理で文書間の類似度を測っているときに利用しました。全ペアの中から類似度の高い上位n件とその文書idを保持する必要がありました。
実装
以下のようなclassで実装しました。
class TopData:
def __init__(self, top_n=3):
self.top_n = top_n
self.data = []
def add(self, i, j, sim):
self.data.append((i, j, sim))
self.data.sort(key=lambda x: x[2], reverse=True)
self.data = self.data[:self.top_n]
def get_top_data(self):
return self.data
使用例はこんな感じです。
top3 = TopData()
top3.add(1, 2, 0.5)
top3.add(3, 4, 0.8)
top3.add(5, 6, 0.2)
top3.add(7, 8, 0.9)
print(top_sim_data.get_top_data()) # [(7, 8, 0.9), (3, 4, 0.8), (1, 2, 0.5)]
使い勝手としては良くも悪くもない感じですが、これで最低限自分がやりたいことはできました。
最後に
もっといい書き方、実装方法教えてください🥺