0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[TensorFlow] RaggedTensorで窓処理したい

Last updated at Posted at 2020-07-18

はじめに

TensorFlow 2.1以降で導入された、可変長のデータを表す RaggedTensor ですが、普通のTensorのノリで書こうとすると色々ハマります。
今回は信号処理などで使いそうな窓処理編。ある時間幅を持ったフレームを少しずつずらして、フレームの範囲に入った波形を抽出します。

検証環境

  • Ubuntu 18.04
  • Python 3.6.9
  • TensorFlow 2.2.0 (CPU)

やりたいこと

x をバッチだと思って、各行のデータに対して短い区間の波形を切り出していきます。データの長さはバラバラです。
ここでは、フレーム幅を2として、切り出し位置を1ずつシフトしていきます。
[3, 1, 4, 1] なら [[3, 1], [1, 4], [4, 1]] といった感じですね。

普通の Tensor に対しては tf.signal.frame という便利な関数がありますが、残念ながら RaggedTensor には使えません。

x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
print(tf.signal.frame(x, 2, 1)) # NG
# ValueError: TypeError: object of type 'RaggedTensor' has no len()
print(tf.signal.frame(x.to_tensor(), 2, 1)) # 動くけど余分な0がいっぱい出てくる
# tf.Tensor(
# [[[3 1]
#   [1 4]
#   [4 1]]
# 
#  [[0 0]
#   [0 0]
#   [0 0]]
# 
#  [[5 9]
#   [9 2]
#   [2 0]]
# 
#  [[6 0]
#   [0 0]
#   [0 0]]
# 
#  [[0 0]
#   [0 0]
#   [0 0]]], shape=(5, 3, 2), dtype=int32)

解決策

バッチ次元を外した (flattenした) 値を表す x.values と、RaggedTensor から得られる各行の長さやオフセットをベースに考えます。

print(x.values)        # 値を並べたTensor
# tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
print(x.row_starts())  # valuesにおける各行の開始インデックス(オフセット)
# tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)
print(x.row_lengths()) # 各行の長さ
# tf.Tensor([4 0 3 1 0], shape=(5,), dtype=int64)

x の各行ごとに、x.values のどのインデックスから値を取ってくればよいかを考えます(*)。

  • 0行目は [0, 1], [1, 2], [2, 3]
  • 1行目は空
  • 2行目は [4, 5], [5, 6]
  • 3行目は空
  • 4行目は空

まずは(*)の先頭インデックスを持つようなRaggedTensor を作ってやると

s = x.row_starts()
e = s + x.row_lengths() - 1
r = tf.ragged.range(s, e)
print(r)
# <tf.RaggedTensor [[0, 1, 2], [], [4, 5], [], []]>

さらに、インデックスを1進めたものを結合すると、窓処理後の期待する結果について、x.values のどこから値を持ってくればよいかが分かります。先の箇条書き(*)と対応する結果が得られています。

ind = tf.stack([r, r+1], axis=2)
print(ind)
# <tf.RaggedTensor [[[0, 1], [1, 2], [2, 3]], [], [[4, 5], [5, 6]], [], []]>

あとは、tf.gather() を使って、ind に入ったインデックスに基づいて x.values から値を取ってくればOK。

ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[3, 1], [1, 4], [4, 1]], [], [[5, 9], [9, 2]], [], []]>

フレーム長が3以上の場合

eind の作り方が少し変わりますが、大筋は同じです。
ind の作成にはブロードキャストを利用しています。そのために r[:, :, tf.newaxis] として末尾に長さ1の次元を追加しています。

len_frame = 3
s = x.row_starts()
e = s + x.row_lengths() + 1 - len_frame
r = tf.ragged.range(s, e)
ind = r[:, :, tf.newaxis] + tf.range(0, len_frame, dtype="int64")
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[3, 1, 4], [1, 4, 1]], [], [[5, 9, 2]], [], []]>

もちろん len_frame = 2 の場合でも使えます。

フレームシフトが2以上の場合

r の刻み幅を変えればOKです。

len_frame = 2
shift_frame = 2
s = x.row_starts()
e = s + x.row_lengths() + 1 - len_frame
r = tf.ragged.range(s, e, shift_frame)
ind = r[:, :, tf.newaxis] + tf.range(0, len_frame, dtype="int64")
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[3, 1], [4, 1]], [], [[5, 9]], [], []]>

shift_frame = 1 でも大丈夫です。

サンプルが多次元の場合

例えばステレオ音声でLとRの値がペアで格納されているような場合が該当します。

x = tf.ragged.constant([[[3, 2], [1, 7], [4, 1], [1, 8]], [], [[5, 2], [9, 8], [2, 1]], [[6, 8]], []])

実は今までと全く同じ方法で動きます。

len_frame = 2
shift_frame = 1
s = x.row_starts()
e = s + x.row_lengths() + 1 - len_frame
r = tf.ragged.range(s, e, shift_frame)
ind = r[:, :, tf.newaxis] + tf.range(0, len_frame, dtype="int64")
ret = tf.gather(x.values, ind)
print(ret)
# <tf.RaggedTensor [[[[3, 2], [1, 7]], [[1, 7], [4, 1]], [[4, 1], [1, 8]]], [], [[[5, 2], [9, 8]], [[9, 8], [2, 1]]], [], []]>

次元数が増えすぎて、パッと見ただけでは合っているかどうか分からなくなってきましたが、大丈夫なはず…。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?