0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AI要素④ ソフトマックス関数

Posted at

AIの要素技術について記述します。
 

理解

ソフトマックス関数は、最終出力層で「確率」に変換するために使われる特別な活性化関数
シグモイド関数を多次元に拡張
多クラス分類問題において、ニューラルネットワークの出力を確立分布に変換する
各成分は区間 (0, 1) に収まり、全ての成分の和が 1 になるため、「確率」として解釈できる

入力ベクトル z の各成分 zi に自然指数関数を適用し、これらすべての指数の合計で割ることによって、値を正規化する

K = 2(二値分類問題)、z = z1 – z2 と置くと、標準シグモイド関数になる

サンプルプログラム

softmax_demo.py

040_softmax_demo.py
import numpy as np

def softmax(x, axis=-1):
    """
    数値安定化付きソフトマックス
    x: np.ndarray(任意形状)
    axis: ソフトマックスを取る軸(クラス軸)
    """
    x = np.asarray(x, dtype=float)
    x_shift = x - np.max(x, axis=axis, keepdims=True)  # ★安定化:最大値を引く
    e = np.exp(x_shift)
    return e / np.sum(e, axis=axis, keepdims=True)

def main():
    rng = np.random.default_rng(42)

    # 1) ベクトル入力(クラス数=5)
    logits_1d = rng.normal(loc=0.0, scale=2.0, size=5)
    probs_1d = softmax(logits_1d)
    print("==== 1D ====")
    print("logits:", np.round(logits_1d, 3))
    print("probs :", np.round(probs_1d, 6))
    print("sum   :", probs_1d.sum(), "\n")  # → 1.0 になる

    # 2) バッチ入力(バッチ=3, クラス=4)
    logits_2d = rng.normal(size=(3, 4)) * 3.0  # ばらつきを少し大きく
    probs_2d = softmax(logits_2d, axis=1)
    print("==== 2D (batch) ====")
    print("logits:\n", np.round(logits_2d, 3))
    print("probs :\n", np.round(probs_2d, 6))
    print("row-wise sum:", np.round(probs_2d.sum(axis=1), 6), "\n")  # 各行=1.0

    # 3) 大きな値でも安定かチェック(安定化がないとオーバーフローしやすい)
    big_logits = np.array([1000.0, 1001.0, 999.0])
    print("==== large logits (stability check) ====")
    print("softmax:", np.round(softmax(big_logits), 6))  # 問題なく計算できる

    # 4) 温度パラメータ(任意):分布の鋭さを調整
    T = 0.5  # T<1 で鋭く、T>1 でなだらか
    probs_temp = softmax(logits_1d / T)
    print("\n==== temperature (T=0.5) ====")
    print("probs (T=0.5):", np.round(probs_temp, 6))

if __name__ == "__main__":
    main()

結果
==== 1D ====
logits: [ 0.609 -2.08   1.501  1.881 -3.902]
probs : [0.141153 0.009587 0.344231 0.503478 0.00155 ]
sum   : 0.9999999999999999

==== 2D (batch) ====
logits:
 [[-3.907  0.384 -0.949 -0.05 ]
 [-2.559  2.638  2.333  0.198]
 [ 3.382  1.403 -2.578  1.106]]
probs :
 [[0.007117 0.519333 0.137043 0.336507]
 [0.003023 0.546466 0.402886 0.047626]
 [0.804175 0.111121 0.002076 0.082628]]
row-wise sum: [1. 1. 1.]

==== large logits (stability check) ====
softmax: [0.244728 0.665241 0.090031]

==== temperature (T=0.5) ====
probs (T=0.5): [5.08270e-02 2.34000e-04 3.02280e-01 6.46652e-01 6.00000e-06]

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?