LoginSignup
58
10

RWKVをjuliaで試してみる

Last updated at Posted at 2023-11-30

はじめに

こんにちは、RetailAIの@long10langです。さて、advent calendar今年は何書こう?というわけで、あれこれ考えた結果、あまり理解が進んでいなかったRWKVについてちゃんと調べてみようと思い立ちまして、それなら単にRWKVの説明をしたところで面白くないので、juliaでいっちょやってみっかということで、やって行きたいと思います。

目次

  1. RWKV
  2. 実装あれこれ
  3. 苦労話
  4. まとめ
  5. 参考文献

RWKVについて

やはり、今年はなんと言ってもChatGPT元年ということで、いろんなことがありました。もうすっかり市民権を獲得したと言っても過言ではないtransformerですが、その根幹をなすSelf-Attentionという仕組みは、もちろん画期的な発明ではありましたが、学習量がすんごいことになるため、一般人がそう簡単に手出しできるような代物ではありませんよね。

とはいえ、RNNベースのモデルだと、メモリと計算要件の面で線形にスケールするものの、並列化と拡張性の制限からtransformerと同等の性能を達成できるかといえばそれもまた困難です。

そこで,transformerの効率的な並列学習と、RNNの効率的な推論の両方を兼ね備えたモデルとしてRWKV(Receptance Weighted Key Value)という新たなモデルアーキテクチャーが提案されました。

詳しくは論文を参照して頂ければと思うのですが、ざっくり言ってしまえば、このモデルは、数百億のパラメータまでスケールする初の非Transformerアーキテクチャでありながら、同じサイズのTransformerと同等の性能を発揮することが論文内で示されています。

そんなわけで、現在GPTをはじめとしたtransformerベースのモデルよりも、高速に推論可能(?)なモデルとして注目されている、というわけです。

image.png

左がtransformerで、右がRWKVです。Channel MixingとTime Mixingが大きく異なる点ですが、そこに登場するモデル要素にちなんで、RWKVと呼ばれています。

  • R:過去の情報の受容度を表現するReceptanceベクトル。
  • W:位置の重み減衰ベクトル。訓練可能なモデルパラメータ。
  • K:一般的な注意機構におけるK(Key)に類似のベクトル。
  • V:一般的な注意機構におけるV(Value)に類似のベクトル。

めちゃくちゃざっくりと言ってしまえば、tranformerもRWKVも、どちらも計算上で、文章などの情報を理解する方法という点では一緒なのですが、tranformerはSelf-Attention、つまり文章のすべての部分に注意を払い、複雑な内容も理解できる反面たくさんの計算が必要になる。一方、RWKVはよりシンプルで、少ない計算で済むためサクサク動作します。というわけで、長い文章を扱うときRWKVの方が効率的ですが、transformerの方がより深く内容を理解できる、乱暴ですがそんな理解です。

実装あれこれ

それでは、四の五の言わず、juliaで実装を見ていきましょう。まずは、ネットワークから。

1.Layer Normalization (LN) 構造体と関数:層正規化を行うカスタム構造体とその動作を定義しています。

struct LN
    γ
    β
    ϵ
    dims
end
LN(γ, β) = LN(γ, β, 1e-5, 1)
LN(n::Integer) = LN(ones(Float32, n), zeros(Float32, n))
@Flux.functor LN
Flux.trainable(m::LN) = (γ=m.γ, β=m.β)
function (m::LN)(x::AbstractArray{T,N}) where {T,N}
    μ = mean(x, dims=m.dims)
    σ² = var(x, dims=m.dims)
    (T).(m.γ .* (x .- μ) ./ (σ² .+ m.ϵ).^0.5 .+ m.β)
end

2.次に、ユーティリティ関数:時間の混合、二乗ReLU、指数混合などの関数を定義していきます。

time_mix(x, x_prev, mix) = @. x * mix + x_prev * (1 - mix)
square_relu(x::T) where T = max(zero(T), x)^2
function exp_mix(v1::AbstractArray{T}, v2::AbstractArray{T}, p1::AbstractArray{T}, p2::AbstractArray{T}) where T
    p = max.(p1, p2)
    (@. exp(p1 - p) * v1 + exp(p2 - p) * v2, p)
end

function exp_selfadj(v1, v2, p)
    p_new = @. max(min(abs(asinh(v1)), p), min(abs(asinh(v2)), p))
    factor = @. exp(p - p_new)
    @. (v1*factor, v2*factor, p_new)
end

3.State構造体:ネットワークの状態を保持するための構造体を定義します。

mutable struct State
    x_tm  # token mixing
    x_cm  # channel mixing
    a # numerator
    b # denominator
    p # largest exponent seen
end

State(n_embed::Integer, n_layer::Integer) = begin
    dim = (n_embed, n_layer)
    State(zeros(Float32, dim), zeros(Float32, dim), zeros(Float32, dim), zeros(Float32, dim), zeros(Float32, dim))
end
@Flux.functor State

4.TokenMixing と ChannelMixing 構造体:トークンとチャネルの混合を行うための層を定義しています。

function recur_step(left::Vector, right::Vector; w)
    a_prev, b_prev, p_prev = left
    expkv, expk, p = right
    a_new, p_new = exp_mix(a_prev, expkv, p_prev .+ w, p)
    b_new, _ = exp_mix(b_prev, expk, p_prev .+ w, p)
    [a_new, b_new, p_new]
end

struct TokenMixing{T}
    Tₖ::AbstractArray{T, 1}
    Tᵥ::AbstractArray{T, 1}
    Tᵣ::AbstractArray{T, 1}
    r_proj
    k_proj
    v_proj
    out_proj 
    time_first::AbstractArray{T, 1}
    time_decay::AbstractArray{T, 1}  # <-- w
end

@Flux.functor TokenMixing

TokenMixing(n_embed::Integer) = TokenMixing(
    zeros(Float32, n_embed), # Tₖ
    zeros(Float32, n_embed), # Tᵥ
    zeros(Float32, n_embed), # Tᵣ
    Dense(n_embed, n_embed, bias=false), # r_proj
    Dense(n_embed, n_embed, bias=false), # k_proj
    Dense(n_embed, n_embed, bias=false), # v_proj
    Dense(n_embed, n_embed, bias=false), # out_proj
    zeros(Float32, n_embed), # time first
    ones(Float32, n_embed),  # time_decay
)

function (m::TokenMixing)(x::AbstractArray{T,2}, state::State; i) where T
    n_embed, n_seq = size(x)

    x_prev = hcat(@views(state.x_tm[:, i]), @views(x[:, 1:end-1]))

    xₖ = time_mix(x, x_prev, m.Tₖ)
    xᵥ = time_mix(x, x_prev, m.Tᵥ)
    xᵣ = time_mix(x, x_prev, m.Tᵣ)

    r = m.r_proj(xᵣ) .|> sigmoid
    k = m.k_proj(xₖ)
    v = m.v_proj(xᵥ)

    p = k
    expk = k.*0 .+ 1
    expkv = v

    step_f = (a, b) -> recur_step(a, b; w=m.time_decay)
    a_prev, b_prev, p_prev = @views(state.a[:, i]), @views(state.b[:, i]), @views(state.p[:, i])
    substrate = [[@views(expkv[:,i]), @views(expk[:,i]), @views(p[:,i])] for i = 1:n_seq]
    abp = accumulate(step_f, substrate; init=[a_prev, b_prev, p_prev])

    a_prev = batch([a_prev, [abp[i][1] for i = 1:n_seq-1]...])
    b_prev = batch([b_prev, [abp[i][2] for i = 1:n_seq-1]...])
    p_prev = batch([p_prev, [abp[i][3] for i = 1:n_seq-1]...])

    c, _ = exp_mix(a_prev, expkv, p_prev, p .+ m.time_first)
    d, _ = exp_mix(b_prev, expk,  p_prev, p .+ m.time_first)
    rwkv = @. r * c / d

    # update state
    @views state.x_tm[:, i] .= x[:, end]
    @views state.a[:, i] .= abp[end][1]
    @views state.b[:, i] .= abp[end][2]
    @views state.p[:, i] .= abp[end][3]

    m.out_proj(rwkv), state
end
    
# stateless for training
function (m::TokenMixing)(x::AbstractArray{T,3}) where T
    n_embed, n_batch, n_seq = size(x)

    @views x_prev = pad_zeros(x, (0,0,0,0,1,0))[:, :, 1:end-1]

    xₖ = time_mix(x, x_prev, m.Tₖ)
    xᵥ = time_mix(x, x_prev, m.Tᵥ)
    xᵣ = time_mix(x, x_prev, m.Tᵣ)

    r = m.r_proj(xᵣ) .|> sigmoid
    k = m.k_proj(xₖ)
    v = m.v_proj(xᵥ)

    p = k
    expk = zeros_like(k) .+ 1
    expkv = v

    step_f = (a, b) -> recur_step(a, b; w=m.time_decay)
    a_prev = b_prev = p_prev = zeros_like(k, eltype(v), (n_embed, n_batch))
    substrate = @views [[expkv[:,:,i], expk[:,:,i], p[:,:,i]] for i = 1:n_seq]
    abp = accumulate(step_f, substrate; init=[a_prev, b_prev, p_prev])

    a_prev = batch([a_prev, [abp[i][1] for i = 1:n_seq-1]...])
    b_prev = batch([b_prev, [abp[i][2] for i = 1:n_seq-1]...])
    p_prev = batch([p_prev, [abp[i][3] for i = 1:n_seq-1]...])

    c, _ = exp_mix(a_prev, expkv, p_prev, p .+ m.time_first)
    d, _ = exp_mix(b_prev, expk,  p_prev, p .+ m.time_first)
    rwkv = @. r * c / d

    m.out_proj(rwkv)
end

struct ChannelMixing{T}
    Tₖ::AbstractArray{T, 1}  # will be taken out in the future
    Tᵣ::AbstractArray{T, 1}  # will be taken out in the future
    r_proj
    k_proj
    v_proj
end

@Flux.functor ChannelMixing

ChannelMixing(n_embed::Integer) = ChannelMixing(
    zeros(Float32, n_embed), # Tₖ
    zeros(Float32, n_embed), # Tᵣ
    Dense(n_embed, n_embed, bias=false), # r_proj
    Dense(n_embed, n_embed*4, bias=false), # k_proj
    Dense(n_embed*4, n_embed, bias=false), # v_proj
)

function (m::ChannelMixing)(x::AbstractArray{T, 2}, state::State; i) where T
    n_embed, n_seq = size(x)

    x_prev = @views(state.x_cm[:, i])
    if size(x, 2) > 1
        x_prev = hcat(x_prev, @views(x[:, 1:end-1]))
    end
    xₖ = time_mix(x, x_prev, m.Tₖ)
    xᵣ = time_mix(x, x_prev, m.Tᵣ)

    r = m.r_proj(xᵣ) .|> sigmoid
    k = m.k_proj(xₖ) .|> square_relu

    # update state
    @views state.x_cm[:, i] .= x[:, end]

    r .* (m.v_proj(k)), state
end

function (m::ChannelMixing)(x::AbstractArray{T, 3}) where T
    # n_embed, n_batch, n_seq = size(x)
    @views x_prev = pad_zeros(x, (0,0,0,0,1,0))[:, :, 1:end-1]

    xₖ = time_mix(x, x_prev, m.Tₖ)
    xᵣ = time_mix(x, x_prev, m.Tᵣ)

    r = m.r_proj(xᵣ) .|> sigmoid
    k = m.k_proj(xₖ) .|> square_relu

    r .* (m.v_proj(k))
end

5.Block 構造体:LN、TokenMixing、ChannelMixingを組み合わせたブロック構造体を定義します。

struct Block
    ln1
    token_mixing
    ln2
    channel_mixing
end

@Flux.functor Block

Block(n_embed::Integer) = Block(
    LN(n_embed),
    TokenMixing(n_embed),
    LN(n_embed),
    ChannelMixing(n_embed),
)

function (m::Block)(x, state::State; i)
    xp, state = m.token_mixing(m.ln1(x), state; i=i)
    x = x + xp
    xp, state = m.channel_mixing(m.ln2(x), state; i=i)
    x = x + xp 
    x, state
end

function (m::Block)(x)
    xp = m.token_mixing(m.ln1(x))
    x = x + xp
    xp = m.channel_mixing(m.ln2(x))
    x = x + xp
    x
end

6.RWKVモデル:言語モデルのメイン構造体で、埋め込み層、LN層、複数のブロック、最終LN層、言語モデルのヘッドを含んでいます。

struct RWKV
    ln_init
    embedding
    blocks
    ln_final
    lm_head
end

@Flux.functor RWKV

RWKV(n_embed::Integer, n_blocks::Integer, n_vocab::Integer) = RWKV(
    Embedding(n_vocab, n_embed),
    LN(n_embed),
    [Block(n_embed) for _ in 1:n_blocks],
    LN(n_embed),
    Embedding(n_embed, n_vocab)
)

7.モデルのフォワードパス:状態を持つか持たないかに基づいて、モデルのフォワードパスを定義しています。

(m::RWKV)(x::AbstractArray{T, 1}, state::State) where T = begin
    x = m.embedding(x)
    x = m.ln_init(x)
    for i in 1:length(m.blocks)
        x, state = m.blocks[i](x, state; i=i)
    end
    x = m.ln_final(x)

    # x: [n_embed, n_seq]
    x = m.lm_head.weight' * x

    x, state
end
(m::RWKV)(x::Integer, state::State) = begin
    out, state = m([x], state)
    out[:, end], state
end

(m::RWKV)(x::AbstractArray{T, 2}) where T = begin
    x = m.embedding(x)
    # better performance
    x = permutedims(x, (1, 3, 2))
    x = m.ln_init(x)
    for i in 1:length(m.blocks)
        x = m.blocks[i](x)
    end
    x = m.ln_final(x)
    # better performance
    x = permutedims(x, (1, 3, 2))
    # x: [n_embed, n_seq, n_batch]
    batched_mul(m.lm_head.weight', x)
end

以上で、RWKVのネットワーク実装が終わりました。では、早速このネットワークを用いてトレーニングをしていきましょう。

include("data.jl") # <- データ
include("rwkv.jl") # <- 上で定義したネットワーク
include("utils.jl") # <- 下で定義するユーティリティ関数

using Flux.Losses
using OneHotArrays
using Dates, Logging
using CUDA
using BSON: @save
using Optimisers

function save_model(model, path)
    model = model |> cpu
    @save path model
end

cfg = (
    learning_rate = 3e-4,
    epochs = 5,
    dataset = "sample.jsonl",
    use_cuda = true,
    ctx_len = 256,
    batch_size = 8,
    log_per_nbatch = 5,
    save_per_nbatch = 1000,
    n_vocab = 50277,
    n_layer = 12,
    n_head = 12
)

data_file = cfg.dataset
tokenizer = get_tokenizer()
ts = TextSplitter(cfg.ctx_len+1, 16, tokenizer)

function loss_func(y_pred, y; n_vocab=cfg.n_vocab)
    return logitcrossentropy(y_pred, onehotbatch(y, 1:n_vocab))
end

if cfg.use_cuda && CUDA.functional()
    CUDA.allowscalar(false)
    @info "Using CUDA"
    device = gpu
else
    @info "Using CPU"
    device = identity
end

Flux.trainable(m::RWKV) = (blocks=m.blocks,)
Flux.trainable(m::Block) = (token_mixing=m.token_mixing,)

# device = identity  # debug
model = rwkv_from_pth("RWKV-4-Pile-169M-20220807-8023.pth"; cfg.n_layer) |> f32
model = model |> device

# setup optimizer
opt = Optimisers.setup(Optimisers.Adam(cfg.learning_rate), model)
i_batch = 1
for epoch in 1:cfg.epochs
    global i_batch
    batches = jsonl_reader(data_file) |> ch->batch_sampler(ch, ts; batch_size=cfg.batch_size)
    for (x, y) in batches
        x = x |> device
        y = y |> device
        val, grads = Flux.withgradient(model) do m
            y_pred = m(x)
            loss_func(y_pred, y)
        end
        Flux.update!(opt, model, grads[1])
        if i_batch % cfg.log_per_nbatch == 0
            @info "metrics" loss=val
            println("Ep $(epoch) Batch $(i_batch) Loss $(val)")
        end
        if i_batch % cfg.save_per_nbatch == 0
            @info "saving model"
            save_model(model, "rwkv-169m-$(now()).bson")
        end
        i_batch += 1
    end
end

ユーティリティ関数(utils.jl)がこちら↓

using PyCall
using Flux
using StatsBase

include("rwkv.jl")
torch = pyimport("torch")

function rwkv_from_pth(pth_path="RWKV-4-Pile-169M-20220807-8023.pth"; n_layer=12)
    data = torch.load(pth_path, "cpu")

    ln_init = LN(
        data["blocks.0.ln0.weight"].float().numpy(),
        data["blocks.0.ln0.bias"].float().numpy(),
    )
    ln_final = LN(
        data["ln_out.weight"].float().numpy(),
        data["ln_out.bias"].float().numpy()
    )
    embedding = Embedding(
        data["emb.weight"].float().numpy()'
    )
    lm_head = Embedding(
        data["head.weight"].float().numpy()'
    )
    blocks = []
    for i = 0:n_layer-1
        ln1 = LN(
            data["blocks.$i.ln1.weight"].float().numpy(),
            data["blocks.$i.ln1.bias"].float().numpy(),
        )
        ln2 = LN(
            data["blocks.$i.ln2.weight"].float().numpy(),
            data["blocks.$i.ln2.bias"].float().numpy(),
        )
        time_first = data["blocks.$i.att.time_first"].float().numpy()
        time_decay = -exp.(data["blocks.$i.att.time_decay"].float().numpy())
        token_mixing = TokenMixing(
            dropdims(data["blocks.$i.att.time_mix_k"].float().numpy(), dims=(1,2)),
            dropdims(data["blocks.$i.att.time_mix_v"].float().numpy(), dims=(1,2)),
            dropdims(data["blocks.$i.att.time_mix_r"].float().numpy(), dims=(1,2)),
            Dense(
                data["blocks.$i.att.receptance.weight"].float().numpy(),
                false
            ), # r_proj
            Dense(
                data["blocks.$i.att.key.weight"].float().numpy(),
                false
            ), # k_proj
            Dense(
                data["blocks.$i.att.value.weight"].float().numpy(),
                false
            ), # v_proj
            Dense(
                data["blocks.$i.att.output.weight"].float().numpy(),
                false
            ), # out_proj
            time_first,
            time_decay,
        )
        channel_mixing = ChannelMixing(
            dropdims(data["blocks.$i.ffn.time_mix_k"].float().numpy(), dims=(1,2)),
            dropdims(data["blocks.$i.ffn.time_mix_r"].float().numpy(), dims=(1,2)),
            Dense(
                data["blocks.$i.ffn.receptance.weight"].float().numpy(),
                false
            ), # r_proj
            Dense(
                data["blocks.$i.ffn.key.weight"].float().numpy(),
                false
            ), # k_proj
            Dense(
                data["blocks.$i.ffn.value.weight"].float().numpy(),
                false
            ), # v_proj
        )
        push!(blocks, Block(
            ln1,
            token_mixing,
            ln2,
            channel_mixing
        ))
    end
    RWKV(
        ln_init,
        embedding,
        blocks,
        ln_final,
        lm_head
    )
end

function get_tokenizer()
    py"""
    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file('20B_tokenizer.json')
    """
    py"tokenizer"
end

function sample_logits(logits; temperature=1.0, top_p=0.9, use_argmax=false)
    if use_argmax
        return argmax(logits)
    end
    probs = softmax(logits; dims=1)
    sorted_probs = sort(probs; rev=true)
    cum_probs = cumsum(sorted_probs)
    cutoff = sorted_probs[argmax(cum_probs .> top_p)]
    probs = (probs .> cutoff) .* probs
    if temperature != 1.0
        probs .^= 1/temperature
    end
    probs = probs / sum(probs)

    sample(collect(1:length(probs)), ProbabilityWeights(probs)) |> Int
end

上記をトレーニングをして、推論を実行すると、結果が表示されます。

include("rwkv.jl")
include("utils.jl")

device = gpu
tokenizer = get_tokenizer()
model = rwkv_from_pth("RWKV-4-Pile-430M-20220808-8066.pth"; n_layer=24) |> device

function generate(model, prompt, n_tokens=50; top_p=0.99, temperature=1.0, use_argmax=false, device=identity)
    input_ids = tokenizer.encode(prompt).ids .+ 1

    state = State(size(model.embedding.weight, 1), length(model.blocks)) |> device

    out, state = model(input_ids[1:end-1], state)

    println("-------------------------")
    println(prompt)
    println("-------------------------")

    input_id = input_ids[end]
    for i = 1:n_tokens
        out, state = model(input_id, state)
        out_id = sample_logits(out |> cpu; top_p=top_p, temperature=temperature, use_argmax=use_argmax)
        print(tokenizer.decode([out_id-1]))
        input_id = out_id
    end
    println()
    state
end

prompt = "Japan is an economic powerhouse, but its growth rate is slowing. In the days ahead, Japan will";
generate(model, prompt, 100; top_p=0.999, device=device);

プロンプトとして、「Japan is an economic powerhouse, but its growth rate is slowing. In the days ahead, Japan will(日本は経済大国だが成長率は鈍化している。この先日本は・・・)」と渡してみたところ、返答として「 be dragged down by the decrease in exports and by the fall in the official growth rate to 6% in FY 2011. If the Japanese government continues its policies in this area, growth rates will drop below the previous rate in all of FY 2011( 輸出の減少と、2011年度の公式成長率が6%に低下したことが足を引っ張った。日本政府がこの分野の政策を続ければ、2011年度の成長率はすべて以前の成長率を下回るだろう。)」と、正しいのかよくわかりませんがそれっぽいのが返ってきました。

苦労話

RWKVのネットワーク実装については、RWKV-LMをほぼパクってますので、実態としてはpyCallしまくってます(笑)juliaを使うメリットは、ちょっとjuliaだけでは実装が厳しいとなっても、すぐPythonでなんとかできる点が良い点です。(いざとなればCが呼べるのもの嬉しいですね。)

ただ、PyCallの呼び出しで、実行させたいPythonのバージョンを指定するのに、以下のようにREPLで環境変数を指定してあげた後に、 (@hoge) pkg> build PyCall といった感じでビルドしてあげる必要があります。

               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.4 (2023-11-14)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

julia> ENV["PYTHON"]="/home/long/.local/share/virtualenvs/XXXX/bin/python"

個人的には、juliaは、言語としてPythonよりも構文がはっきりと構造的だし、パイプ演算子があることで、状態変化の見通しがつきやすい点で、処理速度が速いというだけでなく、読みやすさという点でも、Pythonより上回ってると感じてます。どんどん使っていきたい。

まとめ

いかがでしたでしょうか?RWKVをより実装レベルで理解するために、Juliaで書き下してみることで、様々な気づきがありました。生成AI関連は来年もきっとすごいことになって行きそうですので、こんな感じで好きなプログラミング言語をあえて使いながら、面白がってキャッチアップしていきたいと思います。

さて明日は、サーバサイドエンジニアの@mizoguchi_ryosukeさんが、「初めてのpython cvxpy」というタイトルで発表します!ぜひ、お楽しみください〜

参考文献

58
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
58
10