LoginSignup
130
110

More than 5 years have passed since last update.

編集距離(レーベンシュタイン距離)を理解し、実装する

Posted at

編集距離(レーベンシュタイン距離)を理解し、実装する

とある実験を行うため、編集距離を使う必要があるので、勉強したものをアウトプットします。

<本記事のゴール>

編集距離について、理解した上で、Pythonで実装できている状態

<進め方>

  • 編集距離の概念を理解する
  • 編集距離の実装(正解)を見る
  • デコレータについて理解する
  • メモ化について理解する
  • 編集距離の実装(正解)を理解する
  • 補足:編集距離の発展形を理解する

編集距離の概念を理解する

編集距離、または、レーベンシュタイン距離については、Wikipediaに以下の記載があります。

レーベンシュタイン距離(レーベンシュタインきょり、英: Levenshtein distance)は、二つの文字列がどの程度異なっているかを示す距離の一種である。編集距離(へんしゅうきょり、英: edit distance)とも呼ばれる。具体的には、1文字の挿入・削除・置換によって、一方の文字列をもう一方の文字列に変形するのに必要な手順の最小回数として定義される。

文字列Aと文字列Bを同時に先頭からスキャンしながら、以下操作を行います。

  • 文字列Aの現在位置に、文字列Bの現在位置の文字を挿入(I:Insert)
  • 文字列Aの現在位置の文字を削除(D:Delete)
  • 文字列Aの現在位置の文字を、文字列Bの現在位置の文字で置換(R:Replace)
  • 何もしない(M:Match)

以上をワンパス行うことで、文字列Aは文字列Bと等しくなります。何もしない場合を除き、コスト1が発生するとします。挿入、削除、置換の選び方は自由なので、コストの合計はいろんな値をとり得ますが、最小のものを編集距離とします。

Algorithms on strings, trees, and sequences』pp.215-216では、文字列「vintner」と「writers」が操作「RIMDMDMMI」によって一致する例が示されています。

  • 初期状態
0 1 2 3 4 5 6
v i n t n e r
w r i t e r s
  • 0番目の文字をvからwに置換(コスト==1)
0 1 2 3 4 5 6
w i n t n e r
w r i t e r s
  • 1番目に文字rを挿入(コスト==2)
0 1 2 3 4 5 6 7
w r i n t n e r
w r i t e r s
  • 2番目の文字iは一致(コスト==2)
0 1 2 3 4 5 6 7
w r i n t n e r
w r i t e r s
  • 3番目の文字nを削除→tが一致(コスト==3)
0 1 2 3 4 5 6
w r i t n e r
w r i t e r s
  • 4番目の文字nを削除→eが一致(コスト==4)
0 1 2 3 4 5 6
w r i t e r
w r i t e r s
  • 5番目の文字rは一致(コスト==4)
0 1 2 3 4 5 6
w r i t e r
w r i t e r s
  • 6番目に文字sを挿入(コスト==5)
0 1 2 3 4 5 6
w r i t e r s
w r i t e r s

文字列AとBは一致し、コスト5を得ました。あとは、最小のコストを得るアルゴリズムを考え、実装すれば良いですね。

編集距離の実装(正解)を見る

Wikipediaには、(n+1)(m+1)の二次元配列を作成し、二重ループを回すアルゴリズムが擬似コードで示されています。しかし、本質を理解/説明するためには、ループのような低次元な書き方は避けたいところです。

こういうときは、Rosetta CodeのHaskellでの実装が参考になる…と思いきや、こちらは高次元すぎてワケわかんにゃい。Haskellを理解できない自分の頭がうらめしい。

幸いにも、Rosetta CodeのPythonでの実装に、ちょうど良いレベル感のコードがあったので、ほんの一部変更して、以下に引用します。このコードを理解/説明するのが、本記事のゴールです。

from functools import lru_cache

@lru_cache(maxsize=4096)
def ld(s, t):
    if not s: return len(t)
    if not t: return len(s)
    if s[0] == t[0]: return ld(s[1:], t[1:])
    l1 = ld(s, t[1:])
    l2 = ld(s[1:], t)
    l3 = ld(s[1:], t[1:])
    return 1 + min(l1, l2, l3)

print(ld('vintner', 'writers'))
5

デコレータについて理解する

のっけから「@lru_cache(maxsize=4096)」と来たもんだ。これは、デコレータと呼ばれます。クラスメソッドを定義する際に、定義の直前に@classmethodと書くと思いますが、あれもデコレータです。

デコレータは、関数に何らかの機能を追加するものです。その実体は、引数と戻り値が関数であるような関数です。構文糖衣のため、簡潔に書くことができるようになっています。

Python用語集に、以下の記載があります。

decorator

(デコレータ) 別の関数を返す関数で、通常、 @wrapper 構文で関数変換として適用されます。デコレータの一般的な利用例は、 classmethod() と staticmethod() です。

デコレータの文法はシンタックスシュガーです。次の2つの関数定義は意味的に同じものです:

def f(...):
   ...
f = staticmethod(f)

@staticmethod
def f(...):
   ...

デコレータは重要機能ですので、Qiitaでも記事に取り上げられています。

以下の例は、関数に「お腹すいた」アピール機能を追加します。

def hungry(f):
    def wrapper(*args, **kwargs):
        print('お腹すいた')
        return f(*args, **kwargs)
    return wrapper

@hungry
def msum(iterable):
    return sum(iterable)

@hungry
def mmax(iterable):
    return max(iterable)

print(msum(range(10)))
print(mmax(range(10)))
お腹すいた
45
お腹すいた
9

編集距離の実装を理解するため、まずは、デコレータについて理解しました。

次なる疑問は、先ほどの「@lru_cache(maxsize=4096)」が、具体的にどんな機能を追加してくれるのか、という点です。

メモ化について理解する

結論を先に言うと、「@lru_cache(maxsize=4096)」は、関数にメモ化の機能を追加してくれます。Python Documentationでは、ここに説明があります。

メモ化とは、計算結果をメモしておき、再利用する手法のことです。Wikipediaの説明は以下です。

メモ化(英: Memoization)とは、プログラムの高速化のための最適化技法の一種であり、サブルーチン呼び出しの結果を後で再利用するために保持し、そのサブルーチン(関数)の呼び出し毎の再計算を防ぐ手法である。

例として、階乗からなる数列を計算するとします。すなわち、リスト $[1!, 2!, 3!, 4!, 5!, ...]$ を得たいとします。このとき、$5!$ を定義通りに $5\cdot4\cdot3\cdot2\cdot1$ と計算するのでなく、直前の $4!$ の計算結果をメモしておいて、$5\cdot4!$ と計算したほうがコストが少なくて済みますよね。これがメモ化の恩恵です。フィボナッチ数列などでも同様のことがいえます。

Pythonではデコレータ「@lru_cache」によりメモ化の機能を追加することができます。maxsizeは、メモのサイズを意味します。maxsizeに0を指定した場合、メモを使わないのと同義です。以下に階乗の例を示します。

※余談:0の階乗は1である

from functools import lru_cache

for i in range(3):
    @lru_cache(maxsize=i)
    def pow(n):
        return 1 if n == 0 else n * pow(n - 1)
    print('maxsize =', i)
    %time [pow(n) for n in range(1, 101)]
    print(pow.cache_info(), '\n')
maxsize = 0
CPU times: user 21.4 ms, sys: 1.04 ms, total: 22.5 ms
Wall time: 22.4 ms
CacheInfo(hits=0, misses=5150, maxsize=0, currsize=0) 

maxsize = 1
CPU times: user 80 µs, sys: 0 ns, total: 80 µs
Wall time: 86.1 µs
CacheInfo(hits=99, misses=101, maxsize=1, currsize=1) 

maxsize = 2
CPU times: user 82 µs, sys: 1 µs, total: 83 µs
Wall time: 87 µs
CacheInfo(hits=99, misses=101, maxsize=2, currsize=2) 

メモを使う場合と使わない場合とで、20倍強の差が出ました。メモ化の恩恵は大きいですね。

メモ化は重要な手法ですので、Qiitaでも記事に取り上げられています。

一番目の記事でご指摘のとおり、メモ化を行うためには参照透過(引数が同じであれば戻り値が同じ)が必須条件となります。現在時刻を返すような関数をメモ化しても役に立ちません。

二番目の記事はPHPの例ですが、原因分析がわかりやすいなと思いました。再帰で書くとシュッと書ける(可読性は上がる)が、呼び出し回数が飛躍的に増加してしまう。メモ化はその弱点を補う。これは本質をついていると思います。

というわけで、編集距離の実装を理解するため、メモ化についても理解しました。重要なのは、メモ化することにより、再帰で書けているため、きわめて可読性の高いコードになっている、という点です。

編集距離の実装(正解)を理解する

可読性の高いコードなので、絵を描いたりしなくても、コードにコメントを追記すれば、理解/説明に十分なはずです。

from functools import lru_cache

@lru_cache(maxsize=4096)
def ld(s, t):
    '''文字列のレーベンシュタイン距離を計算する'''

    # 一方が空文字列なら、他方の長さが求める距離
    if not s: return len(t)
    if not t: return len(s)

    # 一文字目が一致なら、二文字目以降の距離が求める距離
    if s[0] == t[0]: return ld(s[1:], t[1:])

    # 一文字目が不一致なら、追加/削除/置換のそれぞれを実施し、
    # 残りの文字列についてのコストを計算する

    # Sの先頭に追加
    l1 = ld(s, t[1:])

    # Sの先頭を削除
    l2 = ld(s[1:], t)

    # Sの先頭を置換
    l3 = ld(s[1:], t[1:])

    # 追加/削除/置換を実施した分コスト(距離)1の消費は確定
    # 残りの文字列についてのコストの最小値を足せば距離となる
    return 1 + min(l1, l2, l3)

print(ld('vintner', 'writers'))
5

補足:編集距離の発展形を理解する

文字列Aと文字列Bの編集距離のレンジは、0 〜 max(len(A), len(B)) です。これを 0 〜 1 に標準化したいというのは、誰もが思うこと。以下記事にならい、「長い方の文字列の長さで編集距離を割る」ことにしましょう。

from functools import lru_cache

@lru_cache(maxsize=4096)
def ld(s, t):
    if not s: return len(t)
    if not t: return len(s)
    if s[0] == t[0]: return ld(s[1:], t[1:])
    l1 = ld(s, t[1:])
    l2 = ld(s[1:], t)
    l3 = ld(s[1:], t[1:])
    return 1 + min(l1, l2, l3)

def lds(s, t):
    return ld(s, t) / max(len(s), len(t))

print(lds('xx', 'xx'))
print(lds('xx', 'xy'))
print(lds('xx', 'yy'))
0.0
0.5
1.0

いい感じですね。

最後に、距離ではなく、類似度を表すようにしてみましょうか。$y = -x + 1$ となるよう線形変換してやれば良いはずです。

from functools import lru_cache

@lru_cache(maxsize=4096)
def ld(s, t):
    if not s: return len(t)
    if not t: return len(s)
    if s[0] == t[0]: return ld(s[1:], t[1:])
    l1 = ld(s, t[1:])
    l2 = ld(s[1:], t)
    l3 = ld(s[1:], t[1:])
    return 1 + min(l1, l2, l3)

def lds(s, t):
    return ld(s, t) / max(len(s), len(t))

def lss(s, t):
    return -lds(s, t) + 1

print(lss('xx', 'xx'))
print(lss('xx', 'xy'))
print(lss('xx', 'yy'))
1.0
0.5
0.0

できました。

発展形まで含め、編集距離について、理解した上で、Pythonで実装できている状態となりました。ゴールを達成できたので、本記事を締めくくりたいと思います。

ご購読ありがとうございました!

130
110
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
130
110