AIの要素技術について記述します。
理解
ソフトマックス関数は、最終出力層で「確率」に変換するために使われる特別な活性化関数
シグモイド関数を多次元に拡張
多クラス分類問題において、ニューラルネットワークの出力を確立分布に変換する
各成分は区間 (0, 1) に収まり、全ての成分の和が 1 になるため、「確率」として解釈できる
入力ベクトル z の各成分 zi に自然指数関数を適用し、これらすべての指数の合計で割ることによって、値を正規化する
K = 2(二値分類問題)、z = z1 – z2 と置くと、標準シグモイド関数になる
サンプルプログラム
softmax_demo.py
040_softmax_demo.py
import numpy as np
def softmax(x, axis=-1):
"""
数値安定化付きソフトマックス
x: np.ndarray(任意形状)
axis: ソフトマックスを取る軸(クラス軸)
"""
x = np.asarray(x, dtype=float)
x_shift = x - np.max(x, axis=axis, keepdims=True) # ★安定化:最大値を引く
e = np.exp(x_shift)
return e / np.sum(e, axis=axis, keepdims=True)
def main():
rng = np.random.default_rng(42)
# 1) ベクトル入力(クラス数=5)
logits_1d = rng.normal(loc=0.0, scale=2.0, size=5)
probs_1d = softmax(logits_1d)
print("==== 1D ====")
print("logits:", np.round(logits_1d, 3))
print("probs :", np.round(probs_1d, 6))
print("sum :", probs_1d.sum(), "\n") # → 1.0 になる
# 2) バッチ入力(バッチ=3, クラス=4)
logits_2d = rng.normal(size=(3, 4)) * 3.0 # ばらつきを少し大きく
probs_2d = softmax(logits_2d, axis=1)
print("==== 2D (batch) ====")
print("logits:\n", np.round(logits_2d, 3))
print("probs :\n", np.round(probs_2d, 6))
print("row-wise sum:", np.round(probs_2d.sum(axis=1), 6), "\n") # 各行=1.0
# 3) 大きな値でも安定かチェック(安定化がないとオーバーフローしやすい)
big_logits = np.array([1000.0, 1001.0, 999.0])
print("==== large logits (stability check) ====")
print("softmax:", np.round(softmax(big_logits), 6)) # 問題なく計算できる
# 4) 温度パラメータ(任意):分布の鋭さを調整
T = 0.5 # T<1 で鋭く、T>1 でなだらか
probs_temp = softmax(logits_1d / T)
print("\n==== temperature (T=0.5) ====")
print("probs (T=0.5):", np.round(probs_temp, 6))
if __name__ == "__main__":
main()
結果
==== 1D ====
logits: [ 0.609 -2.08 1.501 1.881 -3.902]
probs : [0.141153 0.009587 0.344231 0.503478 0.00155 ]
sum : 0.9999999999999999
==== 2D (batch) ====
logits:
[[-3.907 0.384 -0.949 -0.05 ]
[-2.559 2.638 2.333 0.198]
[ 3.382 1.403 -2.578 1.106]]
probs :
[[0.007117 0.519333 0.137043 0.336507]
[0.003023 0.546466 0.402886 0.047626]
[0.804175 0.111121 0.002076 0.082628]]
row-wise sum: [1. 1. 1.]
==== large logits (stability check) ====
softmax: [0.244728 0.665241 0.090031]
==== temperature (T=0.5) ====
probs (T=0.5): [5.08270e-02 2.34000e-04 3.02280e-01 6.46652e-01 6.00000e-06]