Help us understand the problem. What is going on with this article?

TensorFlowにおけるクリッピングと正規化

概要

TensorFlowで扱うtensorが保持している各値のクリッピングについての記事です。
具体的にはtf.clip_by...系メソッドなどの説明と実装例になります。

そもそも何でこんな事するの?

用途は色々あるかと思いますが、TensorFlowの主たる使用目的である機械学習関連の計算(特に勾配計算)では、扱う変数のスケール感に違いが出てうまく計算できないということが多々あります。
そんな時に用いられるのがクリッピングや正規化になります。

TensorFlowの用語

TensorFlowでは従来の変数や定数、更に各種演算を表す処理をOpノードと呼んでいます(参考記事)
セッション内で実行(run)する事でOpノードは演算結果の値の集合(テンソル)を保持し次のOpノードに伝えていく仕組みになっています。
本記事ではこれに合わせて何らかの処理が終わった後の数値集合のことをノードと表記します。

clip_by系

何か決め打ちした基準に対して条件を満たす場合にノードの値を抑えるイメージ

tf.clip_by_value

tf.clip_by_value(
    t,
    clip_value_min,
    clip_value_max,
    name=None
)

ノードが保持する各値に対して最大値clip_value_maxより大きい値をclip_value_maxに、最小値clip_value_minより小さい値をclip_value_minに直します。

example1.py
p1 = tf.placeholder(tf.int32, 6, name='p1')
p2 = tf.placeholder(tf.float32, 6, name='p2')

clip_value1 = tf.clip_by_value(p1, clip_value_max=2, clip_value_min=-2, name='clip_value1')
clip_value2 = tf.clip_by_value(p2, clip_value_max=2., clip_value_min=-2., name='clip_value2')

num1 = np.linspace(-4, 6, 6)

with tf.Session() as sess:
    print(p1.eval(feed_dict={p1: num1}, session=sess))
    print(p2.eval(feed_dict={p2: num1}, session=sess))

    print(clip_value1.eval(feed_dict={p1: num1}, session=sess))
    print(clip_value2.eval(feed_dict={p2: num1}, session=sess))
console
[-4 -2  0  2  4  6]
[-4. -2.  0.  2.  4.  6.]

[-2 -2  0  2  2  2]
[-2. -2.  0.  2.  2.  2.]

ノードとclip_valueの型が一致していないとエラーを吐かれます。

example1.py
    print(clip_error1.eval(feed_dict={p1: num1}, session=sess))
console
TypeError: Expected int32 passed to parameter 'y' of op 'Minimum', got 2.0 of type 'float' instead.

tf.clip_by_norm

tf.clip_by_norm(
    t,
    clip_norm,
    axes=None,
    name=None
)

ノードのL2ノルムがclip_normよりも大きい場合このノルムに直すように各値を変更します。clip_normよりも小さい場合は変更されません。

example2.py
p3 = tf.placeholder(tf.float32, [2, 3], name='p3')

clip_norm1 = tf.clip_by_norm(p3, clip_norm=4, name='clip_norm1')
clip_norm2 = tf.clip_by_norm(p3, clip_norm=5, name='clip_norm2')

num2 = np.linspace(-2, 3, 6).reshape((2, 3))

with tf.Session() as sess:
    print(p3.eval(feed_dict={p3: num2}, session=sess))
    print(clip_norm1.eval(feed_dict={p3: num2}, session=sess))
    print(clip_norm2.eval(feed_dict={p3: num2}, session=sess))
console
[[-2. -1.  0.]
 [ 1.  2.  3.]]   # 全体のL2ノルムは4.358 ...

[[-1.8353258 -0.9176629  0.       ]
 [ 0.9176629  1.8353258  2.7529888]]

[[-2. -1.  0.]
 [ 1.  2.  3.]]

tf.clip_by_normではaxesを指定できます。
axesで指定した軸ごとのL2ノルムで値を正規化します。

example3.py
clip_norm3 = tf.clip_by_norm(p3, clip_norm=3, axes=1, name='clip_norm3')

with tf.Session() as sess:
    print(p3.eval(feed_dict={p3: num2}, session=sess))
    print(clip_norm3.eval(feed_dict={p3: num2}, session=sess))
console
[[-2. -1.  0.]    # 0列目のL2ノルムは2.236 ...
 [ 1.  2.  3.]]   # 1列目のL2ノルムは3.741 ...

[[-2.        -1.         0.       ]
 [ 0.8017837  1.6035674  2.4053512]]

なお、tf.clip_by_normは噛ませるノードが小数点を扱えないとTypeErrorとなります。
float◯◯やcomplex◯◯型を使ってください。

tf.clip_by_global_norm

tf.clip_by_global_norm(
    t_list,
    clip_norm,
    use_norm=None,
    name=None
)

tf.clip_by_normと違ってノードではなくノードのリストを渡します。
ノード自体を渡すとTypeErrorになります。

リストに格納されているノード全体でのL2ノルムをglobal_normとし、この値がclip_normよりも大きい場合L2ノルムがclip_normになるようにリスト内の全ての値を変更します。clip_normよりも小さい場合は変更されません。

また、返値はクリッピングした後のノードが格納されたリストlist_clippedと計算されたglobal_normの2つです。

example4.py
c1 = tf.constant([[0, 1, 2], [3, 4, 5]], dtype=tf.float32, name='c1')
c2 = tf.constant([[-2, -4], [2, 4]], dtype=tf.float32, name='c2')
C = [c1, c2]

clip_global_norm, global_norm = tf.clip_by_global_norm(C, clip_norm=9, name='clip_global_norm')

with tf.Session() as sess:
    for c in C:
        print(c.eval(session=sess))
    print(global_norm.eval(session=sess))
    for cgn in clip_global_norm1:
        print(cgn.eval(session=sess))
console
[[0. 1. 2.]
 [3. 4. 5.]]
[[-2. -4.]
 [ 2.  4.]]

9.746795
[[0.        0.9233805 1.846761 ]
 [2.7701416 3.693522  4.6169024]]
[[-1.846761 -3.693522]
 [ 1.846761  3.693522]]

tf.clip_by_normtf.clip_by_global_normメソッド自体は単純ではありますが、たとえばRNNにおける勾配爆発の対策「勾配クリッピング」に対応するために用いることができます。

以下が参考になります。

モデルを構築した後、いざ学習させようと思ったら誤差の伝播計算などでinfにぶっ飛ぶ時などはこの方法を取ると解決するかもしれません。

おまけ

クリッピングは実際にノードの保有値を変更していましたが、ノルムの計算のみする事もできます。

tf.norm

tf.norm(
    tensor,
    ord='euclidean',
    axis=None,
    keepdims=None,
    name=None
)

パラメタordでLpノルムのpの値を決定します。
L∞ノルムの場合はnp.infを指定します。

example4.py
p4 = tf.placeholder(tf.float32, [3, 4], name='p4')

normalize1 = tf.norm(p4, name='normalize1')
normalize2 = tf.norm(p4, ord=1.5, axis=0, name='normalize2')
normalize3 = tf.norm(p4, ord=np.inf, axis=1, name='normalize3')

num3 = np.linspace(-10, 8, 12).reshape((3, 4))

with tf.Session() as sess:
    print(p4.eval(feed_dict={p4: num3}, session=sess))
    print(normalize1.eval(feed_dict={p4: num3}, session=sess))
    print(normalize2.eval(feed_dict={p4: num3}, session=sess))
    print(normalize3.eval(feed_dict={p4: num3}, session=sess))
console
[[-10.          -8.363636    -6.7272725   -5.090909  ]
 [ -3.4545455   -1.8181819   -0.18181819   1.4545455 ]
 [  3.090909     4.7272725    6.3636365    8.        ]]

19.87232

[12.364525  11.0871725 10.408293  10.876119 ]

[10.         3.4545455  8.       ]

参考

TensorFlow > API > TensorFlow Core r2.0 > Python > tf.cilp_by_value
TensorFlow > API > TensorFlow Core r2.0 > Python > tf.clip_by_norm
TensorFlow > API > TensorFlow Core r2.0 > Python > tf.clip_by_global_norm
TensorFlow > API > TensorFlow Core r2.0 > Python > tf.norm


明日は、@yoshishinさんによる「Rubyを使って予約システムをハックする」です。
引き続き、GMOアドマーケティング Advent Calendar 2019をお楽しみください!

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした