はじめに
TensorFlow 2.1以降で導入された、可変長のデータを表す RaggedTensor ですが、普通のTensorのノリで書こうとすると色々ハマります。
今回は信号処理などで使いそうな窓処理編。ある時間幅を持ったフレームを少しずつずらして、フレームの範囲に入った波形を抽出します。
検証環境
- Ubuntu 18.04
- Python 3.6.9
- TensorFlow 2.2.0 (CPU)
やりたいこと
x
をバッチだと思って、各行のデータに対して短い区間の波形を切り出していきます。データの長さはバラバラです。
ここでは、フレーム幅を2として、切り出し位置を1ずつシフトしていきます。
[3, 1, 4, 1]
なら [[3, 1], [1, 4], [4, 1]]
といった感じですね。
普通の Tensor
に対しては tf.signal.frame という便利な関数がありますが、残念ながら RaggedTensor
には使えません。
x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
print(tf.signal.frame(x, 2, 1)) # NG
# ValueError: TypeError: object of type 'RaggedTensor' has no len()
print(tf.signal.frame(x.to_tensor(), 2, 1)) # 動くけど余分な0がいっぱい出てくる
# tf.Tensor(
# [[[3 1]
# [1 4]
# [4 1]]
#
# [[0 0]
# [0 0]
# [0 0]]
#
# [[5 9]
# [9 2]
# [2 0]]
#
# [[6 0]
# [0 0]
# [0 0]]
#
# [[0 0]
# [0 0]
# [0 0]]], shape=(5, 3, 2), dtype=int32)
解決策
バッチ次元を外した (flattenした) 値を表す x.values
と、RaggedTensor
から得られる各行の長さやオフセットをベースに考えます。
print(x.values) # 値を並べたTensor
# tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
print(x.row_starts()) # valuesにおける各行の開始インデックス(オフセット)
# tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)
print(x.row_lengths()) # 各行の長さ
# tf.Tensor([4 0 3 1 0], shape=(5,), dtype=int64)
x
の各行ごとに、x.values
のどのインデックスから値を取ってくればよいかを考えます(*)。
- 0行目は
[0, 1], [1, 2], [2, 3]
- 1行目は空
- 2行目は
[4, 5], [5, 6]
- 3行目は空
- 4行目は空
まずは(*)の先頭インデックスを持つようなRaggedTensor
を作ってやると
s = x.row_starts()
e = s + x.row_lengths() - 1
r = tf.ragged.range(s, e)
print(r)
# <tf.RaggedTensor [[0, 1, 2], [], [4, 5], [], []]>
さらに、インデックスを1進めたものを結合すると、窓処理後の期待する結果について、x.values
のどこから値を持ってくればよいかが分かります。先の箇条書き(*)と対応する結果が得られています。
ind = tf.stack([r, r+1], axis=2)
print(ind)
# <tf.RaggedTensor [[[0, 1], [1, 2], [2, 3]], [], [[4, 5], [5, 6]], [], []]>
あとは、tf.gather()
を使って、ind
に入ったインデックスに基づいて x.values
から値を取ってくればOK。
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[3, 1], [1, 4], [4, 1]], [], [[5, 9], [9, 2]], [], []]>
フレーム長が3以上の場合
e
と ind
の作り方が少し変わりますが、大筋は同じです。
ind
の作成にはブロードキャストを利用しています。そのために r[:, :, tf.newaxis]
として末尾に長さ1の次元を追加しています。
len_frame = 3
s = x.row_starts()
e = s + x.row_lengths() + 1 - len_frame
r = tf.ragged.range(s, e)
ind = r[:, :, tf.newaxis] + tf.range(0, len_frame, dtype="int64")
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[3, 1, 4], [1, 4, 1]], [], [[5, 9, 2]], [], []]>
もちろん len_frame = 2
の場合でも使えます。
フレームシフトが2以上の場合
r
の刻み幅を変えればOKです。
len_frame = 2
shift_frame = 2
s = x.row_starts()
e = s + x.row_lengths() + 1 - len_frame
r = tf.ragged.range(s, e, shift_frame)
ind = r[:, :, tf.newaxis] + tf.range(0, len_frame, dtype="int64")
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[3, 1], [4, 1]], [], [[5, 9]], [], []]>
shift_frame = 1
でも大丈夫です。
サンプルが多次元の場合
例えばステレオ音声でLとRの値がペアで格納されているような場合が該当します。
x = tf.ragged.constant([[[3, 2], [1, 7], [4, 1], [1, 8]], [], [[5, 2], [9, 8], [2, 1]], [[6, 8]], []])
実は今までと全く同じ方法で動きます。
len_frame = 2
shift_frame = 1
s = x.row_starts()
e = s + x.row_lengths() + 1 - len_frame
r = tf.ragged.range(s, e, shift_frame)
ind = r[:, :, tf.newaxis] + tf.range(0, len_frame, dtype="int64")
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[[3, 2], [1, 7]], [[1, 7], [4, 1]], [[4, 1], [1, 8]]], [], [[[5, 2], [9, 8]], [[9, 8], [2, 1]]], [], []]>
次元数が増えすぎて、パッと見ただけでは合っているかどうか分からなくなってきましたが、大丈夫なはず…。