beam searchしたくてギッハブからスターがたくさんついているやつを動かしているとたまに TypeError: '<' not supported between instances of 'BeamSearchNode' and 'BeamSearchNode' がでちゃう問題が発生しました。
参考にした実装↓
不定期に出るエラー(抜粋)
File "/~~~~~~~~~~~~/beam.py", line 119, in beam_decode
nodes.put((score, nn))
TypeError: '<' not supported between instances of 'BeamSearchNode' and 'BeamSearchNode'
nodesはqueue.PriorityQueueのインスタンスです(組み込み関数)。
ちょっと調べてみると、PriorityQueueにtupleを追加していくときに、まずtupleの一つ目の要素で比較して中身をソートしてくれますが、その要素が一致してしまってpriorityが付けられない場合は2つ目の要素で比較演算子を実行し、ソートするみたいです。
今回の場合、(scores, nn)を追加していて、たまたまscoreが一致したときにnn (BeamSearchNodeのインスタンス) 同士で比較演算子が実行されてしまい、このようなエラーが出ちゃったようです。
これを解決するために、scoreの後ろに他のノードと被らないスコアを入れました。
自分で分かりやすいようにイジっているので上記のギッハブの実装とはちょっと違いますが、new nodeを追加する所は以下のように変更しました。
node = BeamSearchNode(encoder_out=encoder_output,
previousNode=n,
decoder_input=next_decoder_input,
logProb=n.logp + log_p,
length=n.leng + 1)
score = -node.eval()
count += 1
nextnodes.append((score, count, node))
for i in range(len(nextnodes)):
score, count, nn = nextnodes[i]
nodes.put((score, count, nn))
BeamSearchNodeの中身↓
class BeamSearchNode(object):
def __init__(self, previousNode, decoder_input, logProb, length):
self.prevNode = previousNode
self.dec_in = decoder_input
self.logp = logProb
self.leng = length
def eval(self, alpha=0.6):
return self.logp / (((5 + self.leng) / (5 + 1)) ** alpha)
Transformerなのでノードに対してtokenではなく推論中のsentence全体を渡している点がギッハブの実装と異なっています。
encoderのアウトプット(memory)は外から渡してます。
あとはevalで計算するスコアにGNMTと同じペナルティ項を追加しています。