1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[TensorFlow] RaggedTensorに対するIndexingを極めたい

Last updated at Posted at 2020-07-19

はじめに

TensorFlow 2.1以降で導入された、可変長のデータを表す RaggedTensor ですが、普通のTensorのノリで書こうとすると色々ハマります。
今回はIndexing編。RaggedTensor から特定のインデックスを指定して値を取り出したりしてみます。慣れてくると複雑な操作もできるようになる…はず。

検証環境

  • Ubuntu 18.04
  • Python 3.6.9
  • TensorFlow 2.2.0 (CPU)

Indexingの例

Indexing対象の RaggedTensor として、x が以下のように作られているとします。

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
print(x)
# <tf.RaggedTensor [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13, 14]]>
列番号 0 1 2 3 4
0行目 0
1行目 1 2
2行目 3 4 5
3行目 6 7 8 9
4行目 10 11 12 13 14

特定の行でスライシング

まずはある行を取り出す操作ですが、これは普通の Tensor と同じです。numpy.ndarray の感覚で考えてもよいでしょう。範囲指定した場合、**最初のインデックスを含み、最後のインデックスを含みません。**Python使いの方なら問題ないと思いますが。

print(x[2])
# tf.Tensor([3 4 5], shape=(3,), dtype=int32)
print(x[1:4])
# <tf.RaggedTensor [[1, 2], [3, 4, 5], [6, 7, 8, 9]]>

ただし numpy.ndarray と異なり、飛び飛びの行を指定するスライシングは使えないようです。

# ndarrayに対してはこれでできる
print(x.numpy()[[1, 3]])                                                                                                                    
# [array([1, 2], dtype=int32) array([6, 7, 8, 9], dtype=int32)]

# Tensor/RaggedTensorに対しては使えない
print(x[[1, 3]])
# InvalidArgumentError: slice index 3 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/

代わりにこちらでどうぞ。

# Tensor/RaggedTensorでFancy Indexing
print(tf.gather(x, [1, 3], axis=0))
# <tf.RaggedTensor [[1, 2], [6, 7, 8, 9]]>

固定の列インデックスでスライシング

次に固定の列インデックスでスライシングする例です。
普通の Tensor と違って、行によってそのインデックスの要素があるかないかが変わるので、単純に

print(x[:, 2])
# ValueError: Cannot index into an inner ragged dimension.

のようにはできないようになっています。範囲指定をすれば

print(x[:, 2:3])
# <tf.RaggedTensor [[], [], [5], [8], [12]]>

のように動きます。指定したインデックスが存在しない行に対しては [] となっています。

列番号 0 1 2 3 4
0行目 0
1行目 1 2
2行目 3 4 5
3行目 6 7 8 9
4行目 10 11 12 13 14

行ごとに異なる列インデックスを指定してスライシング

集めたい2次元のインデックスを列挙した Tensor がある場合は、tf.gather_nd() が使えます。

ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])
# x の (0, 0), (1, 1), (2, 0), (4, 3) 要素を集めたい
print(tf.gather_nd(x, ind))
# tf.Tensor([ 0  2  3 13], shape=(4,), dtype=int32)
列番号 0 1 2 3 4
0行目 0
1行目 1 2
2行目 3 4 5
3行目 6 7 8 9
4行目 10 11 12 13 14

一方、行ごとに要素を1つずつ取ってくるのですが、それぞれ異なる列から取ってきたい、という場合もあると思います。

col = tf.constant([0, 0, 2, 1, 2])
# x の (0, 0), (1, 0), (2, 2), (3, 1), (4, 2) 要素を集めたい
# インデックスに行番号をつけてから、先程と同じ方法で
ind = tf.transpose(tf.stack([tf.range(tf.shape(col)[0]), col]))
print(tf.gather_nd(x, ind))
# tf.Tensor([ 0  1  5  7 12], shape=(5,), dtype=int32)
列番号 0 1 2 3 4
0行目 0
1行目 1 2
2行目 3 4 5
3行目 6 7 8 9
4行目 10 11 12 13 14

でもなんとなく遅そうな気がするので、もう少し賢い方法を考えてみると

print(tf.gather(x.values, x.row_starts() + col))
# tf.Tensor([ 0  1  5  7 12], shape=(5,), dtype=int32)

これでOKです。
x の値の実体は各行をつなげた(1次元少ない)TensorRaggedTensor ではありません)に入っていて、x.values にアクセスすると取得できます。また、x の形状を表すために各行の開始インデックス (x.row_starts()) の情報が保持されています。よって、このインデックスに指定したオフセットを加え、x.values に対してスライシングすればよいというわけです。

%timeit tf.gather_nd(x, tf.transpose(tf.stack([tf.range(tf.shape(col)[0]), col])))
# 739 µs ± 75.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit tf.gather(x.values, x.row_starts() + col)                                                                                           
# 124 µs ± 6.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

こっちのほうが速いですね(^_^)

このあたりの操作を極めようと思えば、公式ドキュメントを見るのが吉です。

列インデックスがRaggedTensorに入っている場合

先程の「値の実体が1次元の Tensor に入っている」ということを応用します。

col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])
# x の (0, 0), (2, 0), (2, 2), (3, 1), (3, 3), (4, 2) 要素を集めたい

# xの各行の開始インデックスを得る
row_starts = tf.cast(x.row_starts(), "int32")
# colの各成分が属する行番号を得て、xにおける開始インデックスに変換し、オフセットを加える
ind_flat = tf.gather(row_starts, col.value_rowids()) + col.values
ret = tf.gather(x.values, ind_flat)
print(ret)
# tf.Tensor([ 0  3  5  7  9 12], shape=(6,), dtype=int32)
列番号 0 1 2 3 4
0行目 0
1行目 1 2
2行目 3 4 5
3行目 6 7 8 9
4行目 10 11 12 13 14

元の行の情報を保存したい場合

先程の結果は値を並べた普通の Tensor になっていて、元の行の情報は失われていますが、行の情報を保存したい場合はどうすればよいでしょう。
Tensor に対して行の開始インデックスの情報を与えることで RaggedTensor を作ることができます。各行の長さは col と一致するはずなので、この開始インデックスを col.value_rowids() から持ってくればよいですね。

print(tf.RaggedTensor.from_value_rowids(ret, col.value_rowids()))
# <tf.RaggedTensor [[0], [], [3, 5], [7, 9], [12]]>

対象のRaggedTensorが3次元以上の場合

2次元以上のデータが時系列に並んでいる(バッチ次元を含めて、RaggedTensor としては3次元以上)場合でも、今までの方法がそのまま使えます。

x = tf.RaggedTensor.from_row_lengths(tf.reshape(tf.range(30), (15, 2)), tf.range(1, 6))
print(x)
# <tf.RaggedTensor [[[0, 1]], [[2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17], [18, 19]], [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]]]>

この x の構造は以下のように解釈できます。

列番号 0 1 2 3 4
0行目 [0, 1]
1行目 [2, 3] [4, 5]
2行目 [6, 7] [8, 9] [10, 11]
3行目 [12, 13] [14, 15] [16, 17] [18, 19]
4行目 [20, 21] [22, 23] [24, 25] [26, 27] [28, 29]

あとはこれまでと全く同じです。ただし、返ってくる Tensor が2次元になっていることにご注意ください。

ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])
# x の (0, 0), (1, 1), (2, 0), (4, 3) 要素を集めたい
print(tf.gather_nd(x, ind))
# tf.Tensor(
# [[ 0  1]
#  [ 4  5]
#  [ 6  7]
#  [26 27]], shape=(4, 2), dtype=int32)
col = tf.constant([0, 0, 2, 1, 2])
# x の (0, 0), (1, 0), (2, 2), (3, 1), (4, 2) 要素を集めたい
print(tf.gather(x.values, x.row_starts() + col))
# tf.Tensor(
# [[ 0  1]
#  [ 2  3]
#  [10 11]
#  [14 15]
#  [24 25]], shape=(5, 2), dtype=int32)
col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])
# x の (0, 0), (2, 0), (2, 2), (3, 1), (3, 3), (4, 2) 要素を集めたい

# xの各行の開始インデックスを得る
row_starts = tf.cast(x.row_starts(), "int32")
# colの各成分が属する行番号を得て、xにおける開始インデックスに変換し、オフセットを加える
ind_flat = tf.gather(row_starts, col.value_rowids()) + col.values
ret = tf.gather(x.values, ind_flat)
print(ret)
# tf.Tensor(
# [[ 0  1]
#  [ 6  7]
#  [10 11]
#  [14 15]
#  [18 19]
#  [24 25]], shape=(6, 2), dtype=int32)

# 元の行の情報を保存したい場合
print(tf.RaggedTensor.from_value_rowids(ret, col.value_rowids()))
# <tf.RaggedTensor [[[0, 1]], [], [[6, 7], [10, 11]], [[14, 15], [18, 19]], [[24, 25]]]>
1
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?