This article describes elemental technologies of AI. LSTM (Long Short-Term Memory) is a model that improves on the difficulty plain RNNs have with long-term dependencies (earlier information being forgotten over time).
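To make that difference concrete, here is a minimal sketch (illustrative values only, not part of the sample program below): a plain RNN overwrites its hidden state through a squashing nonlinearity at every step, while the LSTM carries its cell state forward additively through a forget gate, so stored information decays only as fast as the gate allows.

import numpy as np

h_rnn = np.ones(3)                   # plain RNN hidden state
c_lstm = np.ones(3)                  # LSTM cell state
W = np.full((3, 3), 0.1)             # small recurrent weights (assumed values)
f, i, g = 0.95, 0.05, np.zeros(3)    # forget gate near 1, almost no new input

for _ in range(50):
    h_rnn = np.tanh(W @ h_rnn)       # repeated squashing: the old signal shrinks toward 0
    c_lstm = f * c_lstm + i * g      # additive path: decays only by the gate factor

print(h_rnn.round(4))                # ~[0. 0. 0.]
print(c_lstm.round(4))               # ~[0.0769 ...] (= 0.95**50), still clearly nonzero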
Sample program
lstm_numpy.py
Input
text = "hello lstm demo."
Convert the entire text into integer indices
data = [4, 3, 5, 5, 7, 0, 5, 8, 9, 6, 0, 2, 3, 6, 7, 1]
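As a quick check (using the same text and the same character-to-index scheme as the listing below), the indices come from sorting the unique characters and numbering them by position:

text = "hello lstm demo."
chars = sorted(set(text))            # [' ', '.', 'd', 'e', 'h', 'l', 'm', 'o', 's', 't']
stoi = {ch: i for i, ch in enumerate(chars)}
print([stoi[ch] for ch in text])     # [4, 3, 5, 5, 7, 0, 5, 8, 9, 6, 0, 2, 3, 6, 7, 1]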
Settings
Hidden state vector: 32 dimensions
lr 0.1 (learning rate: weights are updated by a step of 0.1 times the gradient; stored but not actually used in this demo, since backpropagation is omitted)
Random seed: 42
Number of training iterations: 5 epochs
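These settings correspond to the constructor call and the initial states in the listing below (a sketch of just those lines; the vocabulary of "hello lstm demo." has 10 distinct characters):

lstm = LSTM(input_dim=10, hidden_dim=32, output_dim=10, lr=0.1, seed=42)
h = np.zeros((32, 1))   # initial hidden state
c = np.zeros((32, 1))   # initial cell state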
Output
1. Choose 'h' as the character to start generation from
2. Generate a 20-character sequence with the LSTM
Set the LSTM's initial hidden state h to a zero vector
Set the LSTM's initial cell state c to a zero vector
(The next character is chosen at random according to the probability distribution p computed with softmax.)
Generated string (it differs from run to run because the sampling is not seeded)
hthmeto mo st mse .dm
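A minimal sketch (scores assumed for illustration) of the parenthetical above: the output scores are turned into a probability distribution with softmax, and the next character index is drawn from that distribution at random.

import numpy as np

logits = np.array([1.0, 0.2, -0.5])        # hypothetical scores for 3 characters
p = np.exp(logits - logits.max())
p = p / p.sum()                            # softmax: probabilities summing to 1
next_idx = np.random.choice(len(p), p=p)   # sample the next character index
print(p.round(3), next_idx)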
070_lstm_numpy.py
import numpy as np


def one_hot(idx, vocab_size):
    # Return a (vocab_size, 1) column vector with a 1 at position idx.
    v = np.zeros((vocab_size,), dtype=float)
    v[idx] = 1.0
    return v.reshape(-1, 1)


def softmax(x):
    # Numerically stable softmax along axis 0 (columns).
    x = x - np.max(x, axis=0, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=0, keepdims=True)
class LSTM:
    """
    LSTM with one hidden layer.
    h_t, c_t update:
      i = σ(Wxi x_t + Whi h_{t-1} + bi)
      f = σ(Wxf x_t + Whf h_{t-1} + bf)
      o = σ(Wxo x_t + Who h_{t-1} + bo)
      g = tanh(Wxg x_t + Whg h_{t-1} + bg)
      c_t = f * c_{t-1} + i * g
      h_t = o * tanh(c_t)
      y_t = softmax(Why h_t + by)
    """
    def __init__(self, input_dim, hidden_dim, output_dim, lr=0.1, seed=0):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.lr = lr
        rng = np.random.default_rng(seed)

        # Parameter initialization (small Gaussian values, a rough Xavier-like scale)
        def init(shape):
            return rng.normal(0, 0.1, size=shape)

        # Input -> gates
        self.Wxi = init((hidden_dim, input_dim))
        self.Wxf = init((hidden_dim, input_dim))
        self.Wxo = init((hidden_dim, input_dim))
        self.Wxg = init((hidden_dim, input_dim))
        # Hidden -> gates
        self.Whi = init((hidden_dim, hidden_dim))
        self.Whf = init((hidden_dim, hidden_dim))
        self.Who = init((hidden_dim, hidden_dim))
        self.Whg = init((hidden_dim, hidden_dim))
        # Biases
        self.bi = np.zeros((hidden_dim, 1))
        self.bf = np.zeros((hidden_dim, 1))
        self.bo = np.zeros((hidden_dim, 1))
        self.bg = np.zeros((hidden_dim, 1))
        # Output layer
        self.Why = init((output_dim, hidden_dim))
        self.by = np.zeros((output_dim, 1))
    def forward(self, inputs, h_prev, c_prev):
        """
        inputs: list of indices (sequence)
        returns: loss, cache, h_last, c_last
        """
        xs, hs, cs, os, ps = {}, {}, {}, {}, {}
        hs[-1] = np.copy(h_prev)
        cs[-1] = np.copy(c_prev)
        loss = 0
        for t in range(len(inputs) - 1):
            x = one_hot(inputs[t], self.input_dim)
            xs[t] = x
            h_prev, c_prev = hs[t - 1], cs[t - 1]
            i = self._sigmoid(self.Wxi @ x + self.Whi @ h_prev + self.bi)
            f = self._sigmoid(self.Wxf @ x + self.Whf @ h_prev + self.bf)
            o = self._sigmoid(self.Wxo @ x + self.Who @ h_prev + self.bo)
            g = np.tanh(self.Wxg @ x + self.Whg @ h_prev + self.bg)
            c = f * c_prev + i * g
            h = o * np.tanh(c)
            y = self.Why @ h + self.by
            p = softmax(y)
            target = inputs[t + 1]                  # the next character is the prediction target
            loss += -np.log(p[target, 0] + 1e-12)   # cross-entropy loss
            hs[t], cs[t] = h, c
            os[t], ps[t] = o, p
        cache = (xs, hs, cs, os, ps, inputs)
        return loss / (len(inputs) - 1), cache, hs[len(inputs) - 2], cs[len(inputs) - 2]
    def sample(self, start_idx, length, h, c):
        # Generate `length` characters, starting from start_idx.
        x = one_hot(start_idx, self.input_dim)
        idxs = [start_idx]
        for _ in range(length):
            i = self._sigmoid(self.Wxi @ x + self.Whi @ h + self.bi)
            f = self._sigmoid(self.Wxf @ x + self.Whf @ h + self.bf)
            o = self._sigmoid(self.Wxo @ x + self.Who @ h + self.bo)
            g = np.tanh(self.Wxg @ x + self.Whg @ h + self.bg)
            c = f * c + i * g
            h = o * np.tanh(c)
            y = self.Why @ h + self.by
            p = softmax(y).ravel()
            idx = int(np.random.choice(len(p), p=p))  # sample from the distribution
            # idx = int(p.argmax())                   # alternative: greedy decoding
            x = one_hot(idx, self.input_dim)
            idxs.append(idx)
        return idxs

    @staticmethod
    def _sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))
# ===== Demo: next-character prediction on a tiny text =====
if __name__ == "__main__":
    text = "hello lstm demo."
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    data = np.array([stoi[ch] for ch in text])

    input_dim = len(chars)
    hidden_dim = 32
    output_dim = len(chars)
    lstm = LSTM(input_dim, hidden_dim, output_dim, lr=0.1, seed=42)

    h = np.zeros((hidden_dim, 1))
    c = np.zeros((hidden_dim, 1))

    # Training loop (greatly simplified: backpropagation is omitted, so only the
    # forward-pass loss is reported and the weights are never updated)
    for epoch in range(5):
        loss, cache, h, c = lstm.forward(data, h, c)
        print(f"epoch {epoch+1}, loss={loss:.4f}")

    # Sampling
    start = stoi["h"]
    out_idx = lstm.sample(start, length=20, h=np.zeros_like(h), c=np.zeros_like(c))
    print("Generated:", "".join(itos[i] for i in out_idx))
Results
epoch 1, loss=2.3020
epoch 2, loss=2.3032
epoch 3, loss=2.3032
epoch 4, loss=2.3032
epoch 5, loss=2.3032
Generated: heell.dh. to ohh.esms
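The loss barely moves because this demo omits backpropagation: the weights keep their small random initial values, so the predicted distribution stays close to uniform over the 10 distinct characters, and the per-character cross-entropy stays near ln(10). The generated string also differs from the one in the Output section above because sampling is not seeded.

import numpy as np
print(np.log(10))   # 2.3025..., matching the reported loss of about 2.30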