はじめに
TensorFlow 2.1以降で導入された、可変長のデータを表す RaggedTensor ですが、普通のTensorのノリで書こうとすると色々ハマります。
今回はIndexing編。RaggedTensor
から特定のインデックスを指定して値を取り出したりしてみます。慣れてくると複雑な操作もできるようになる…はず。
検証環境
- Ubuntu 18.04
- Python 3.6.9
- TensorFlow 2.2.0 (CPU)
Indexingの例
Indexing対象の RaggedTensor
として、x
が以下のように作られているとします。
x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
print(x)
# <tf.RaggedTensor [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13, 14]]>
列番号 | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0行目 | 0 | ||||
1行目 | 1 | 2 | |||
2行目 | 3 | 4 | 5 | ||
3行目 | 6 | 7 | 8 | 9 | |
4行目 | 10 | 11 | 12 | 13 | 14 |
特定の行でスライシング
まずはある行を取り出す操作ですが、これは普通の Tensor
と同じです。numpy.ndarray
の感覚で考えてもよいでしょう。範囲指定した場合、**最初のインデックスを含み、最後のインデックスを含みません。**Python使いの方なら問題ないと思いますが。
print(x[2])
# tf.Tensor([3 4 5], shape=(3,), dtype=int32)
print(x[1:4])
# <tf.RaggedTensor [[1, 2], [3, 4, 5], [6, 7, 8, 9]]>
ただし numpy.ndarray
と異なり、飛び飛びの行を指定するスライシングは使えないようです。
# ndarrayに対してはこれでできる
print(x.numpy()[[1, 3]])
# [array([1, 2], dtype=int32) array([6, 7, 8, 9], dtype=int32)]
# Tensor/RaggedTensorに対しては使えない
print(x[[1, 3]])
# InvalidArgumentError: slice index 3 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/
代わりにこちらでどうぞ。
# Tensor/RaggedTensorでFancy Indexing
print(tf.gather(x, [1, 3], axis=0))
# <tf.RaggedTensor [[1, 2], [6, 7, 8, 9]]>
固定の列インデックスでスライシング
次に固定の列インデックスでスライシングする例です。
普通の Tensor
と違って、行によってそのインデックスの要素があるかないかが変わるので、単純に
print(x[:, 2])
# ValueError: Cannot index into an inner ragged dimension.
のようにはできないようになっています。範囲指定をすれば
print(x[:, 2:3])
# <tf.RaggedTensor [[], [], [5], [8], [12]]>
のように動きます。指定したインデックスが存在しない行に対しては []
となっています。
列番号 | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0行目 | 0 | ||||
1行目 | 1 | 2 | |||
2行目 | 3 | 4 | 5 | ||
3行目 | 6 | 7 | 8 | 9 | |
4行目 | 10 | 11 | 12 | 13 | 14 |
行ごとに異なる列インデックスを指定してスライシング
集めたい2次元のインデックスを列挙した Tensor
がある場合は、tf.gather_nd()
が使えます。
ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])
# x の (0, 0), (1, 1), (2, 0), (4, 3) 要素を集めたい
print(tf.gather_nd(x, ind))
# tf.Tensor([ 0 2 3 13], shape=(4,), dtype=int32)
列番号 | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0行目 | 0 | ||||
1行目 | 1 | 2 | |||
2行目 | 3 | 4 | 5 | ||
3行目 | 6 | 7 | 8 | 9 | |
4行目 | 10 | 11 | 12 | 13 | 14 |
一方、行ごとに要素を1つずつ取ってくるのですが、それぞれ異なる列から取ってきたい、という場合もあると思います。
col = tf.constant([0, 0, 2, 1, 2])
# x の (0, 0), (1, 0), (2, 2), (3, 1), (4, 2) 要素を集めたい
# インデックスに行番号をつけてから、先程と同じ方法で
ind = tf.transpose(tf.stack([tf.range(tf.shape(col)[0]), col]))
print(tf.gather_nd(x, ind))
# tf.Tensor([ 0 1 5 7 12], shape=(5,), dtype=int32)
列番号 | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0行目 | 0 | ||||
1行目 | 1 | 2 | |||
2行目 | 3 | 4 | 5 | ||
3行目 | 6 | 7 | 8 | 9 | |
4行目 | 10 | 11 | 12 | 13 | 14 |
でもなんとなく遅そうな気がするので、もう少し賢い方法を考えてみると
print(tf.gather(x.values, x.row_starts() + col))
# tf.Tensor([ 0 1 5 7 12], shape=(5,), dtype=int32)
これでOKです。
x
の値の実体は各行をつなげた(1次元少ない)Tensor
(RaggedTensor
ではありません)に入っていて、x.values
にアクセスすると取得できます。また、x
の形状を表すために各行の開始インデックス (x.row_starts()
) の情報が保持されています。よって、このインデックスに指定したオフセットを加え、x.values
に対してスライシングすればよいというわけです。
%timeit tf.gather_nd(x, tf.transpose(tf.stack([tf.range(tf.shape(col)[0]), col])))
# 739 µs ± 75.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit tf.gather(x.values, x.row_starts() + col)
# 124 µs ± 6.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
こっちのほうが速いですね(^_^)
このあたりの操作を極めようと思えば、公式ドキュメントを見るのが吉です。
列インデックスがRaggedTensorに入っている場合
先程の「値の実体が1次元の Tensor
に入っている」ということを応用します。
col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])
# x の (0, 0), (2, 0), (2, 2), (3, 1), (3, 3), (4, 2) 要素を集めたい
# xの各行の開始インデックスを得る
row_starts = tf.cast(x.row_starts(), "int32")
# colの各成分が属する行番号を得て、xにおける開始インデックスに変換し、オフセットを加える
ind_flat = tf.gather(row_starts, col.value_rowids()) + col.values
ret = tf.gather(x.values, ind_flat)
print(ret)
# tf.Tensor([ 0 3 5 7 9 12], shape=(6,), dtype=int32)
列番号 | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0行目 | 0 | ||||
1行目 | 1 | 2 | |||
2行目 | 3 | 4 | 5 | ||
3行目 | 6 | 7 | 8 | 9 | |
4行目 | 10 | 11 | 12 | 13 | 14 |
元の行の情報を保存したい場合
先程の結果は値を並べた普通の Tensor
になっていて、元の行の情報は失われていますが、行の情報を保存したい場合はどうすればよいでしょう。
Tensor
に対して行の開始インデックスの情報を与えることで RaggedTensor
を作ることができます。各行の長さは col
と一致するはずなので、この開始インデックスを col.value_rowids()
から持ってくればよいですね。
print(tf.RaggedTensor.from_value_rowids(ret, col.value_rowids()))
# <tf.RaggedTensor [[0], [], [3, 5], [7, 9], [12]]>
対象のRaggedTensorが3次元以上の場合
2次元以上のデータが時系列に並んでいる(バッチ次元を含めて、RaggedTensor
としては3次元以上)場合でも、今までの方法がそのまま使えます。
x = tf.RaggedTensor.from_row_lengths(tf.reshape(tf.range(30), (15, 2)), tf.range(1, 6))
print(x)
# <tf.RaggedTensor [[[0, 1]], [[2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17], [18, 19]], [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]]]>
この x
の構造は以下のように解釈できます。
列番号 | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0行目 | [0, 1] | ||||
1行目 | [2, 3] | [4, 5] | |||
2行目 | [6, 7] | [8, 9] | [10, 11] | ||
3行目 | [12, 13] | [14, 15] | [16, 17] | [18, 19] | |
4行目 | [20, 21] | [22, 23] | [24, 25] | [26, 27] | [28, 29] |
あとはこれまでと全く同じです。ただし、返ってくる Tensor
が2次元になっていることにご注意ください。
ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])
# x の (0, 0), (1, 1), (2, 0), (4, 3) 要素を集めたい
print(tf.gather_nd(x, ind))
# tf.Tensor(
# [[ 0 1]
# [ 4 5]
# [ 6 7]
# [26 27]], shape=(4, 2), dtype=int32)
col = tf.constant([0, 0, 2, 1, 2])
# x の (0, 0), (1, 0), (2, 2), (3, 1), (4, 2) 要素を集めたい
print(tf.gather(x.values, x.row_starts() + col))
# tf.Tensor(
# [[ 0 1]
# [ 2 3]
# [10 11]
# [14 15]
# [24 25]], shape=(5, 2), dtype=int32)
col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])
# x の (0, 0), (2, 0), (2, 2), (3, 1), (3, 3), (4, 2) 要素を集めたい
# xの各行の開始インデックスを得る
row_starts = tf.cast(x.row_starts(), "int32")
# colの各成分が属する行番号を得て、xにおける開始インデックスに変換し、オフセットを加える
ind_flat = tf.gather(row_starts, col.value_rowids()) + col.values
ret = tf.gather(x.values, ind_flat)
print(ret)
# tf.Tensor(
# [[ 0 1]
# [ 6 7]
# [10 11]
# [14 15]
# [18 19]
# [24 25]], shape=(6, 2), dtype=int32)
# 元の行の情報を保存したい場合
print(tf.RaggedTensor.from_value_rowids(ret, col.value_rowids()))
# <tf.RaggedTensor [[[0, 1]], [], [[6, 7], [10, 11]], [[14, 15], [18, 19]], [[24, 25]]]>