したいこと
長さ N の2つの数列 A = (A0, …, AN-1), B = (B0, …, BN-1) から、同じく長さ N の数列 C = (C0, …, CN-1) を以下のように計算する。(添字がはみ出す際は N の剰余をとる)
C_k = \sum_{i+j=k} A_i B_j
ただし、演算が少し変わっても対処できるよう、畳み込み定理などは使わず愚直に O(N2) で計算する。
Ruby で普通に書くと以下のようになる。これを NArray を使って高速化したい。
n = 5000
a = Array.new(n) { rand(100) }
b = Array.new(n) { rand(100) }
c = Array.new(n) do |k|
(0...n).sum { |i| a[i] * b[k-i] }
end
コード
require "numo/narray"
def conv(n, a, b)
na = Numo::Int64.cast(a)
nb = Numo::Int64.cast(b)
nc = (na.append(0) * nb.expand_dims(-1))
.reshape!(n + 1, n)
.sum(-2)
nc.to_a
end
仕組み
畳み込みを計算する際は、掛け算の筆算のように要素の全ペアの積を斜めに配置して和をとる。(右にはみ出した分は左に移す)
1 2 3
x 1 2 3
-----------------
1 2 3
2 4 6
3 6 9
-----------------
1 2 3
6 2 4
6 9 3
-----------
13 13 10
しかし NArray で全ペアの積を計算しても、九九のように並ぶため斜めに配置されない。斜めに和をとる方法もおそらく無い。
n = 3
na = Numo::Int64[1, 2, 3]
nb = Numo::Int64[1, 2, 3]
na_nb = na * nb.expand_dims(-1)
#=>
# Numo::Int64#shape=[3,3]
# [[1, 2, 3],
# [2, 4, 6],
# [3, 6, 9]]
ここで、もし各行の末尾に 0
を付けたうえで元の列数に整形すれば、要素が次の行へ押し出されて斜めに配置される。また 0
が入っても和は変わらないので、このまま列毎に和を計算できる。
na_nb_ = na_nb.concatenate(Numo::Int64.zeros(n, 1), axis: -1)
#=>
# Numo::Int64#shape=[3,4]
# [[1, 2, 3, 0],
# [2, 4, 6, 0],
# [3, 6, 9, 0]]
na_nb_.reshape(nil, n)
#=>
# Numo::Int64#shape=[4,3]
# [[1, 2, 3],
# [0, 2, 4],
# [6, 0, 3],
# [6, 9, 0]]
並べ替える前の状態は、積をとる前に要素を追加しておくと速く作れる。
na_nb_ = na.append(0) * nb.expand_dims(-1)
#=>
# Numo::Int64#shape=[3,4]
# [[1, 2, 3, 0],
# [2, 4, 6, 0],
# [3, 6, 9, 0]]