Ben's Blog


RWKV Language Model Math

In-depth explanation of the math behind the RWKV model, with PyTorch implementations, plus a discussion of numerical stability.

Jun 16, 2023

Lately I’ve found myself spending a lot of time messing around with the RWKV model. It’s a cool model, but it’s a bit more involved to wrap my head around than vanilla transformers or their variants. I found this blog to be quite helpful for understanding the mechanics, as well as the corresponding simplified inference implementation here.

In this post, I write out the equations for the core WKV part of the RWKV model and derive two numerically stable versions of it (one following the official implementation, the other transforming the state variables to log space), with PyTorch implementations of each. I also derive the gradients, including for the log-space version, and provide Triton kernels for training a numerically stable RWKV model.

In most cases the gradients were checked with Wolfram Alpha, but there may still be typos in the math. The PyTorch implementations are verified by comparing each manual backward pass against the gradients computed by autograd. See this repo for the full code and unit tests.

Math

This section covers the basic math for the WKV operator. If you’re already familiar with the math, you can skip to the PyTorch implementation or to the next section, which extends the vanilla computation to be numerically stable. The gradients for this computation are derived in a later section.

The main “attention” component in the RWKV model is the WKV computation. The equation for this is:

$$\text{wkv}_i = \frac{ e^{u+k_i} v_i + \sum_{j=1}^{i-1} e^{-(i-1-j)w+k_j} v_j}{e^{u+k_i} + \sum_{j=1}^{i-1} e^{-(i-1-j)w+k_j} }$$

We can rewrite this using two recursive state variables for the numerator and denominator, which we’ll call $\alpha_i$ and $\beta_i$ respectively:

$$\begin{aligned} \alpha_i & = \sum_{j=1}^i e^{-(i-j)w+k_j} v_j \\ & = e^{-w} \alpha_{i-1} + e^{k_i} v_i \\[1em] \beta_i & = \sum_{j=1}^i e^{-(i-j)w+k_j} \\ & = e^{-w} \beta_{i - 1} + e^{k_i} \\ \end{aligned}$$

We can then rewrite the WKV computation as:

$$\text{wkv}_i = \frac{ e^{u+k_i} v_i + \alpha_{i - 1} }{ e^{u+k_i} + \beta_{i - 1} }$$
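As a quick sanity check on the recurrence, here is a minimal scalar sketch in plain Python (one channel, one sequence, no PyTorch) that evaluates $\text{wkv}_i$ both from the summation formula and from the $\alpha_i$/$\beta_i$ recurrence. The function names are hypothetical, and the sketch assumes the state starts at $\alpha_0 = \beta_0 = 0$ (empty sums).

```python
import math


def wkv_direct(w, u, k, v):
    """Evaluates wkv_i straight from the summation formula (single channel)."""
    out = []
    for i in range(len(k)):
        cur = math.exp(u + k[i])
        # Past tokens j < i are decayed by e^{-(i-1-j)w}; shifting both indices
        # to 0-based leaves the exponent unchanged.
        past = [math.exp(-(i - 1 - j) * w + k[j]) for j in range(i)]
        num = cur * v[i] + sum(p * v[j] for j, p in enumerate(past))
        den = cur + sum(past)
        out.append(num / den)
    return out


def wkv_recurrent(w, u, k, v):
    """Evaluates wkv_i with the alpha/beta recurrence."""
    alpha, beta = 0.0, 0.0  # assumed empty initial state
    out = []
    for kt, vt in zip(k, v):
        euk = math.exp(u + kt)
        out.append((euk * vt + alpha) / (euk + beta))
        alpha = math.exp(-w) * alpha + math.exp(kt) * vt
        beta = math.exp(-w) * beta + math.exp(kt)
    return out


w, u = 0.5, 0.3
k = [0.1, -0.2, 0.4, 0.0]
v = [1.0, -2.0, 0.5, 3.0]
print(wkv_direct(w, u, k, v))
print(wkv_recurrent(w, u, k, v))
```

The two functions should agree to floating-point precision. The recurrent form is the one the implementations below build on, since it needs constant work per step instead of work proportional to $i$.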

PyTorch Implementation

A pure-PyTorch implementation of the above WKV computation is below:

import torch
from torch import Tensor


def wkv_vanilla_forward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
) -> tuple[Tensor, Tensor]:
    bsz, tsz, chans = k.shape
    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 2, 1, chans)

    alpha, beta = state[:, :, -1].chunk(2, dim=1)  # (B, 1, D), (B, 1, D)

    ew = torch.exp(-w)

    wkvs = []
    alphas = [alpha]
    betas = [beta]

    for t in range(tsz):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        euk = torch.exp(u + kt)
        wkv = (alpha + euk * vt) / (beta + euk)
        wkvs.append(wkv)

        ek = torch.exp(kt)
        alpha = ew * alpha + ek * vt
        beta = ew * beta + ek

        alphas.append(alpha)
        betas.append(beta)

    alpha = torch.stack(alphas, dim=2)
    beta = torch.stack(betas, dim=2)

    return torch.cat(wkvs, 1), torch.cat((alpha, beta), dim=1)
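To see why this vanilla form needs help, here is a hypothetical scalar version of the loop above in plain Python. Python floats are float64, so `math.exp` raises `OverflowError` once its argument exceeds about 709.78; PyTorch's default float32 overflows much sooner, near 88.72, and silently produces `inf`, which becomes `nan` after `inf / inf`.

```python
import math
import sys


def wkv_vanilla_scalar(w, u, k, v):
    """Single-channel version of wkv_vanilla_forward, for illustration only."""
    alpha, beta = 0.0, 0.0  # assumed empty initial state
    out = []
    for kt, vt in zip(k, v):
        euk = math.exp(u + kt)  # raises OverflowError once u + kt exceeds ~709.78
        out.append((euk * vt + alpha) / (euk + beta))
        alpha = math.exp(-w) * alpha + math.exp(kt) * vt
        beta = math.exp(-w) * beta + math.exp(kt)
    return out


# Largest exponent x for which e^x is still finite in each format.
FLOAT64_EXP_LIMIT = math.log(sys.float_info.max)  # ~709.78
FLOAT32_EXP_LIMIT = math.log(3.4028234663852886e38)  # ~88.72 (float32 max)

print(wkv_vanilla_scalar(0.5, 0.3, [0.1, 0.2], [1.0, 2.0]))  # small keys work
try:
    wkv_vanilla_scalar(0.5, 0.3, [0.1, 800.0], [1.0, 2.0])
except OverflowError as err:
    print("overflow:", err)
```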

Numerical Stability

This section makes the vanilla WKV computation numerically stable by adding a scaling factor to the exponents. This is the approach used in the official implementation; the next section describes an alternative that uses log-space state variables instead. Some readers may find the PyTorch implementation easier to follow than the equations; the variable names in the code follow the same conventions as the math.

Traditional RNNs commonly suffer from exploding or vanishing gradients when the Jacobian of the hidden-state update has eigenvalues far from 1 in magnitude. Over long sequences the same transformation is applied many times, which compounds any deviation from 1.

With the RWKV model, if $w$ or $k_i$ is large in magnitude, exponentials such as $e^{k_i}$ can exceed the numerical limits of our floating-point type. We can solve this using another state variable, which we’ll call $\epsilon_i$:

$$\begin{aligned} \alpha_i' & = e^{-\epsilon_i} \alpha_i \\ & = e^{-w - \epsilon_i} \alpha_{i - 1} + e^{k_i - \epsilon_i} v_i \\ & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \alpha_{i - 1}' + e^{k_i - \epsilon_i} v_i \\[1em] \beta_i' & = e^{-\epsilon_i} \beta_i \\ & = e^{-w - \epsilon_i} \beta_{i - 1} + e^{k_i - \epsilon_i} \\ & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \beta_{i - 1}' + e^{k_i - \epsilon_i} \\ \end{aligned}$$

This allows us to rewrite the WKV computation as:

$$\begin{aligned} \text{wkv}_i & = \frac{ e^{u+k_i} v_i + e^{\epsilon_{i - 1}}\alpha_{i - 1}' }{ e^{u+k_i} + e^{\epsilon_{i - 1}}\beta_{i - 1}' } \\ & = \frac{ e^{u+k_i-\epsilon_{i - 1}} v_i + \alpha_{i - 1}' }{ e^{u+k_i-\epsilon_{i - 1}} + \beta_{i - 1}' } \\ \end{aligned}$$

We can add an additional scaling factor $\tau_i$ and multiply by $\frac{e^{-\tau_i}}{e^{-\tau_i}}$ to get:

$$\text{wkv}_i = \frac{ e^{u+k_i-\tau_i} v_i + e^{\epsilon_{i - 1}-\tau_i}\alpha_{i - 1}' }{ e^{u+k_i-\tau_i} + e^{\epsilon_{i - 1}-\tau_i}\beta_{i - 1}' }$$

The value of $\epsilon_i$ is arbitrary. Since we want to keep $e^{-w + \epsilon_{i - 1} - \epsilon_i}$ and $e^{k_i - \epsilon_i}$ at most 1 to prevent them from growing large, we can set it as:

$$\epsilon_{i} = \max(-w + \epsilon_{i - 1}, k_i)$$

Then, to keep $e^{u + k_i - \tau_i}$ and $e^{\epsilon_{i - 1} - \tau_i}$ at most 1, we can set $\tau_i$ as:

$$\tau_i = \max(u + k_i, \epsilon_{i - 1})$$

So ultimately we have three state variables: $\alpha_i'$, $\beta_i'$ and $\epsilon_i$. $\alpha_i'$ and $\beta_i'$ accumulate the rescaled sums over time, while $\epsilon_i$ is the running scale factor passed on to the next step (in other words, $\text{wkv}_i$ depends on $\epsilon_{i-1}$ but not on $\epsilon_i$).
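The same idea in scalar form: a minimal plain-Python sketch of the $\epsilon_i$/$\tau_i$ recurrence above, mirroring the loop of the PyTorch implementation. It assumes an empty initial state of $\alpha_0' = \beta_0' = 0$ and $\epsilon_0 = -\infty$ (so the first step only sees the current token); that initialization is an illustrative choice, not necessarily the one the official implementation uses.

```python
import math


def wkv_with_eps_scalar(w, u, k, v):
    """Scalar sketch of the epsilon-rescaled WKV recurrence (one channel)."""
    # Assumed empty state; math.exp(-math.inf) == 0.0, so the history drops out.
    alpha, beta, eps = 0.0, 0.0, -math.inf
    out = []
    for kt, vt in zip(k, v):
        tau = max(u + kt, eps)
        e1 = math.exp(eps - tau)  # e^{eps_{i-1} - tau_i} <= 1
        e2 = math.exp(u + kt - tau)  # e^{u + k_i - tau_i} <= 1
        out.append((e1 * alpha + e2 * vt) / (e1 * beta + e2))
        new_eps = max(eps - w, kt)
        e1 = math.exp(eps - w - new_eps)  # e^{-w + eps_{i-1} - eps_i} <= 1
        e2 = math.exp(kt - new_eps)  # e^{k_i - eps_i} <= 1
        alpha = e1 * alpha + e2 * vt
        beta = e1 * beta + e2
        eps = new_eps
    return out


# A key of 800 would overflow the vanilla form; here every exponent is <= 0.
print(wkv_with_eps_scalar(0.5, 0.3, [0.1, 800.0, 0.2], [1.0, 2.0, 3.0]))  # → [1.0, 2.0, 2.0]
```

The second and third outputs are dominated by the token with the huge key, and the small terms simply underflow to zero instead of the large ones overflowing.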

PyTorch Implementation

The PyTorch implementation of the more stable form of the WKV computation follows fairly directly from the equations above.

def wkv_with_eps_forward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
) -> tuple[Tensor, Tensor]:
    assert w.dim() == u.dim() == 1
    assert k.dim() == v.dim() == 3
    assert state.dim() == 4

    alpha, beta, eps = state[:, :, -1].chunk(3, dim=1)  # (B, 1, D), (B, 1, D), (B, 1, D)

    _, tsz, _ = k.shape

    wkvs = []
    alphas = [alpha]
    betas = [beta]
    epss = [eps]

    for t in range(tsz):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        # Output: rescale by tau = max(u + k_t, eps_{t-1}) so both exponents are <= 0.
        ukt = u + kt
        tau = torch.maximum(ukt, eps)
        e1 = torch.exp(eps - tau)
        e2 = torch.exp(ukt - tau)
        wkv = (e1 * alpha + e2 * vt) / (e1 * beta + e2)
        wkvs.append(wkv)

        # State update: eps_t = max(eps_{t-1} - w, k_t) keeps both exponents <= 0.
        ww = eps - w
        eps = torch.maximum(ww, kt)
        e1 = torch.exp(ww - eps)
        e2 = torch.exp(kt - eps)
        alpha = e1 * alpha + e2 * vt
        beta = e1 * beta + e2

        alphas.append(alpha)
        betas.append(beta)
        epss.append(eps)

    alpha = torch.stack(alphas, dim=2)
    beta = torch.stack(betas, dim=2)
    eps = torch.stack(epss, dim=2)

    return torch.cat(wkvs, 1), torch.cat((alpha, beta, eps), dim=1)

Log-Space Operations

This section describes an alternative way to make the WKV computation numerically stable: performing the operations in log space. The approach should be familiar to anyone who has worked with log-domain Viterbi algorithms or graphical models, and is included here mainly for comparison with the approach described above. Readers who are more comfortable with code than equations can skip directly to the PyTorch implementation.

The prevalence of exponentials in the WKV computation suggests that it might be a good idea to perform some operations in log space. For example, consider the vanilla $\alpha_i$ and $\beta_i$ equations:

$$\begin{aligned} \alpha_i & = e^{-w} \alpha_{i-1} + e^{k_i} v_i \\ \beta_i & = e^{-w} \beta_{i - 1} + e^{k_i} \\ \end{aligned}$$

If we were guaranteed that the values of $\alpha_{i-1}$, $\beta_{i-1}$ and $v_i$ were always positive, we could use the log-sum-exp trick to compute the log of a sum of exponentials in a numerically stable way:

$$\begin{aligned} LSE(x, y) & = \log(e^x + e^y) \\ & = \max(x, y) + \log(e^{x - \max(x, y)} + e^{y - \max(x, y)}) \\ & = \max(x, y) + \log(1 + e^{-|x - y|}) \\ \end{aligned}$$
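For scalars, the identity above can be written directly. This plain-Python helper is a hypothetical stand-in for `torch.logaddexp` (the PyTorch code later in this post defines its own `logaddexp` instead). It uses `math.log1p` for the $\log(1 + e^{-|x - y|})$ term and special-cases two $-\infty$ inputs, where $x - y$ would be NaN.

```python
import math


def logaddexp(x, y):
    """log(e^x + e^y) computed as max(x, y) + log(1 + e^{-|x - y|})."""
    if x == y == -math.inf:
        return -math.inf  # log(0 + 0); x - y would be nan here
    return max(x, y) + math.log1p(math.exp(-abs(x - y)))


print(logaddexp(1000.0, 1000.0))  # 1000 + log(2), although math.exp(1000.0) alone overflows
print(logaddexp(-math.inf, 3.0))  # 3.0, since log(0 + e^3) = 3
```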

Re-writing the equations in log-space, we get:

$$\begin{aligned} \log \alpha_i & = \log(e^{-w} \alpha_{i-1} + e^{k_i} v_i) \\ & = LSE(-w + \log \alpha_{i-1}, k_i + \log v_i) \\[1em] \log \beta_i & = \log(e^{-w} \beta_{i - 1} + e^{k_i}) \\ & = LSE(-w + \log \beta_{i - 1}, k_i) \\ \end{aligned}$$

Revisiting our WKV equation:

$$\text{wkv}_i = \frac{ e^{u+k_i} v_i + \alpha_{i - 1} }{ e^{u+k_i} + \beta_{i - 1} }$$

We can re-write this in log-space as:

$$\begin{aligned} \log \text{wkv}_i & = \log \left( \frac{ e^{u+k_i} v_i + \alpha_{i - 1} }{ e^{u+k_i} + \beta_{i - 1} } \right) \\ & = \log(e^{u+k_i} v_i + \alpha_{i - 1}) - \log(e^{u+k_i} + \beta_{i - 1}) \\ & = LSE(u + k_i + \log v_i, \log \alpha_{i - 1}) - LSE(u + k_i, \log \beta_{i - 1}) \\ \end{aligned}$$

The advantage here is that we no longer need to store $\epsilon_i$ between steps.
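Restricted to strictly positive $v_i$, the log-space recurrence above fits in a few lines of plain Python. This is an illustrative scalar sketch only (hypothetical function names); it assumes an empty initial state $\log \alpha_0 = \log \beta_0 = -\infty$, and the only state carried between steps is $\log \alpha_i$ and $\log \beta_i$.

```python
import math


def logaddexp(x, y):
    """log(e^x + e^y); returns -inf when both inputs are -inf."""
    if x == y == -math.inf:
        return -math.inf
    return max(x, y) + math.log1p(math.exp(-abs(x - y)))


def wkv_log_space_positive(w, u, k, v):
    """Log-space WKV for strictly positive v (single channel)."""
    ln_alpha, ln_beta = -math.inf, -math.inf  # logs of empty sums
    out = []
    for kt, vt in zip(k, v):
        ln_v = math.log(vt)  # only valid for vt > 0
        ln_wkv = logaddexp(u + kt + ln_v, ln_alpha) - logaddexp(u + kt, ln_beta)
        out.append(math.exp(ln_wkv))
        ln_alpha = logaddexp(-w + ln_alpha, kt + ln_v)
        ln_beta = logaddexp(-w + ln_beta, kt)
    return out


print(wkv_log_space_positive(0.5, 0.3, [0.1, 800.0, 0.2], [1.0, 2.0, 3.0]))
```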

Reparametrization

In order to avoid taking the log of a negative value, everything we take the log of must be strictly positive, which $v_i$ is not in general. We can fix this by reparametrizing $v_i$ as the difference of its positive and negative parts:

$$\begin{aligned} v_i^- & = -\min(v_i, 0) + \epsilon \\ v_i^+ & = \max(v_i, 0) + \epsilon \\[1em] v_i & = v_i^+ - v_i^- \\ \end{aligned}$$

Note that the $\epsilon$ in these equations is a small constant added for numerical stability, not the $\epsilon_i$ from earlier. This reparametrization ensures that $v_i^+$ and $v_i^-$ always lie in the range $[\epsilon, \infty)$, so their logarithms are finite and real-valued.

We can take advantage of this fact to rewrite our equation for $\alpha_i$:

$$\begin{aligned} \alpha_i & = \sum_{j=1}^i e^{-(i-j)w+k_j} (v_j^+ - v_j^-) \\ & = \sum_{j=1}^i e^{-(i-j)w+k_j} v_j^+ - \sum_{j=1}^i e^{-(i-j)w+k_j} v_j^- \\ \end{aligned}$$

Separating out $\alpha_i = \alpha_i^+ - \alpha_i^-$, we get:

$$\begin{aligned} \alpha_i^+ & = e^{-w} \alpha_{i - 1}^+ + e^{k_i} v_i^+ \\ \alpha_i^- & = e^{-w} \alpha_{i - 1}^- + e^{k_i} v_i^- \\ \end{aligned}$$

Note that we can renormalize $\alpha_i^+$ and $\alpha_i^-$ by subtracting $\min(\alpha_i^+, \alpha_i^-) - \epsilon$ from both values. This leaves their difference $\alpha_i = \alpha_i^+ - \alpha_i^-$ unchanged while keeping both values from exploding. Since we are working in the log domain, we should use the subtraction version of the log-sum-exp trick, which is only defined for $x > y$:

$$\begin{aligned} \log(e^x - e^y) & = \max(x, y) + \log(e^{x - \max(x, y)} - e^{y - \max(x, y)}) \\ & = \max(x, y) + \log(1 - e^{-|x - y|}) \\ \end{aligned}$$

This only works because we know that $\alpha_i^+$ and $\alpha_i^-$ are both strictly greater than $\min(\alpha_i^+, \alpha_i^-) - \epsilon$.
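A scalar sketch of the subtraction trick and of the renormalization it enables, in plain Python. The helper name is hypothetical, and the margin of `1e-4` is an arbitrary small value. Like the `normalize` option of the PyTorch implementation, the sketch subtracts the margin in log space.

```python
import math


def logsubexp(x, y):
    """log(e^x - e^y); only defined when x > y."""
    if not x > y:
        raise ValueError("logsubexp requires x > y, otherwise e^x - e^y <= 0")
    # max(x, y) == x here, so this is x + log(1 - e^{-(x - y)}).
    return x + math.log1p(-math.exp(y - x))


# Renormalize a pair of log-space accumulators by a common offset ln_c that is
# strictly below both of them.
ln_alpha_p, ln_alpha_m = 50.0, 49.0
ln_c = min(ln_alpha_p, ln_alpha_m) - 1e-4
new_p = logsubexp(ln_alpha_p, ln_c)
new_m = logsubexp(ln_alpha_m, ln_c)

# Both values shrink, but alpha+ - alpha- is unchanged (up to rounding).
print(new_p, new_m)
print(math.exp(new_p) - math.exp(new_m), math.exp(50.0) - math.exp(49.0))
```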

Finally, we can incorporate $v_i^+$, $v_i^-$, $\alpha_i^+$ and $\alpha_i^-$ into our log-space equations for $\text{wkv}_i$:

$$\begin{aligned} \text{wkv}_i & = \frac{ e^{u+k_i} (v_i^+ - v_i^-) + (\alpha_{i - 1}^+ - \alpha_{i - 1}^-) }{ e^{u+k_i} + \beta_{i - 1} } \\ & = \frac{ e^{u+k_i} v_i^+ + \alpha_{i - 1}^+ }{ e^{u+k_i} + \beta_{i - 1} } - \frac{ e^{u+k_i} v_i^- + \alpha_{i - 1}^- }{ e^{u+k_i} + \beta_{i - 1} } \\ \end{aligned}$$

Now we do have an equation (or rather, two equations) for $\text{wkv}_i$ with strictly positive values. Separating out $\text{wkv}_i = \text{wkv}_i^+ - \text{wkv}_i^-$, we get:

$$\begin{aligned} \text{wkv}_i^+ & = \frac{ e^{u+k_i} v_i^+ + \alpha_{i - 1}^+ }{ e^{u+k_i} + \beta_{i - 1} } \\ \log \text{wkv}_i^+ & = LSE(u + k_i + \log v_i^+, \log \alpha_{i - 1}^+) - LSE(u + k_i, \log \beta_{i - 1}) \\[1em] \text{wkv}_i^- & = \frac{ e^{u+k_i} v_i^- + \alpha_{i - 1}^- }{ e^{u+k_i} + \beta_{i - 1} } \\ \log \text{wkv}_i^- & = LSE(u + k_i + \log v_i^-, \log \alpha_{i - 1}^-) - LSE(u + k_i, \log \beta_{i - 1}) \\ \end{aligned}$$

Note that while we no longer need to use $\epsilon_i$ as a state variable, we now need to carry both $\alpha_i^+$ and $\alpha_i^-$.
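Putting the pieces together, here is a scalar plain-Python sketch of the signed log-space forward pass that carries $\log \alpha_i^+$, $\log \alpha_i^-$ and $\log \beta_i$. It is a simplified illustration, not the post's implementation: it omits the optional renormalization, starts from an empty history, and uses a hypothetical `EPS = 1e-4` because the post's constant is not shown here.

```python
import math

EPS = 1e-4  # hypothetical stand-in for the post's EPS constant (value not shown)


def logaddexp(x, y):
    """log(e^x + e^y); returns -inf when both inputs are -inf."""
    if x == y == -math.inf:
        return -math.inf
    return max(x, y) + math.log1p(math.exp(-abs(x - y)))


def wkv_log_space_scalar(w, u, k, v, eps=EPS):
    """Signed log-space WKV for one channel, starting from an empty history."""
    ln_ap = ln_am = ln_b = -math.inf  # log alpha+, log alpha-, log beta
    out = []
    for kt, vt in zip(k, v):
        # v = v+ - v-, with both parts >= eps so their logs are finite.
        ln_vp = math.log(max(vt, 0.0) + eps)
        ln_vm = math.log(-min(vt, 0.0) + eps)
        ln_den = logaddexp(u + kt, ln_b)  # shared denominator
        ln_wkv_p = logaddexp(u + kt + ln_vp, ln_ap) - ln_den
        ln_wkv_m = logaddexp(u + kt + ln_vm, ln_am) - ln_den
        out.append(math.exp(ln_wkv_p) - math.exp(ln_wkv_m))
        # State update: one LSE per accumulator.
        ln_ap = logaddexp(-w + ln_ap, kt + ln_vp)
        ln_am = logaddexp(-w + ln_am, kt + ln_vm)
        ln_b = logaddexp(-w + ln_b, kt)
    return out


print(wkv_log_space_scalar(0.5, 0.3, [0.1, 800.0, 0.2], [1.0, -2.0, 3.0]))
```

Because $v_i^+ - v_i^- = v_i$ exactly, the $\epsilon$ offsets cancel in the output, and the result matches the vanilla computation wherever the latter does not overflow.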

PyTorch Implementation

We can implement the above equations in PyTorch as follows:

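import math

import torch
from torch import Tensor

EPS = 1e-4  # hypothetical value; the EPS constant from the original repo is not shown in this post


def wkv_log_space_forward(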
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    eps: float = EPS,
    normalize: bool = False,
) -> tuple[Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 3, 1, chans)

    ln_alpha_p, ln_alpha_m, ln_beta = state[:, :, -1].chunk(3, dim=1)

    log_eps = math.log(eps)

    wkvs = []
    ln_alpha_ps = [ln_alpha_p]
    ln_alpha_ms = [ln_alpha_m]
    ln_betas = [ln_beta]

    def logaddexp(a: Tensor, b: Tensor) -> Tensor:
        max_av = torch.maximum(a, b)
        return max_av + torch.log(torch.exp(a - max_av) + torch.exp(b - max_av))

    def logsubexp(a: Tensor, b: Tensor) -> Tensor:
        max_av = torch.maximum(torch.maximum(a, b), torch.full_like(a, log_eps))
        return max_av + torch.log(torch.exp(a - max_av) - torch.exp(b - max_av))

    for t in range(tsz):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        vt_p, vt_m = torch.clamp_min(vt, 0) + eps, torch.clamp_min(-vt, 0) + eps
        ln_v_p, ln_v_m = torch.log(vt_p), torch.log(vt_m)

        if normalize:
            ln_alpha_pm = torch.minimum(ln_alpha_p, ln_alpha_m) - eps
            ln_alpha_p = logsubexp(ln_alpha_p, ln_alpha_pm)
            ln_alpha_m = logsubexp(ln_alpha_m, ln_alpha_pm)

        ln_wkv_p = logaddexp(u + kt + ln_v_p, ln_alpha_p) - logaddexp(u + kt, ln_beta)
        ln_wkv_m = logaddexp(u + kt + ln_v_m, ln_alpha_m) - logaddexp(u + kt, ln_beta)

        wkv = torch.exp(ln_wkv_p) - torch.exp(ln_wkv_m)
        wkvs.append(wkv)

        ln_alpha_p = logaddexp(-w + ln_alpha_p, kt + ln_v_p)
        ln_alpha_m = logaddexp(-w + ln_alpha_m, kt + ln_v_m)
        ln_beta = logaddexp(-w + ln_beta, kt)

        ln_alpha_ps.append(ln_alpha_p)
        ln_alpha_ms.append(ln_alpha_m)
        ln_betas.append(ln_beta)

    ln_alpha_p = torch.stack(ln_alpha_ps, dim=2)
    ln_alpha_m = torch.stack(ln_alpha_ms, dim=2)
    ln_beta = torch.stack(ln_betas, dim=2)

    return torch.cat(wkvs, 1), torch.cat((ln_alpha_p, ln_alpha_m, ln_beta), dim=1)

Gradients

This section derives the gradients of the vanilla WKV computation from the equations above and then implements a manual backward pass in PyTorch.

When we actually implement this in PyTorch, we will want to write an optimized kernel for the WKV computation. The downside is that we can’t use autograd to compute the gradients for us, so we need to derive the gradient equations ourselves.

Revisiting our original equation for the WKV computation:

$$\text{wkv}_i = \frac{ e^{u+k_i} v_i + \alpha_{i - 1} }{ e^{u+k_i} + \beta_{i - 1} }$$

The partial derivatives of the WKV computation with respect to $u$, $k_i$, $v_i$, $\alpha_{i-1}$ and $\beta_{i-1}$ are as follows:

$$\begin{aligned} \frac{\partial \text{wkv}_i}{\partial u} = \frac{\partial \text{wkv}_i}{\partial k_i} & = \frac{ e^{u + k_i} v_i}{e^{u + k_i} + \beta_{i - 1}} - \frac{ e^{u + k_i} (e^{u + k_i} v_i + \alpha_{i - 1})}{(e^{u + k_i} + \beta_{i - 1})^2} \\ & = \frac{ e^{u + k_i} (\beta_{i - 1} v_i - \alpha_{i - 1})}{(\beta_{i - 1} + e^{u + k_i})^2} \\[1.5em] \frac{\partial \text{wkv}_i}{\partial v_i} & = \frac{ e^{u + k_i}}{e^{u + k_i} + \beta_{i-1}} \\[1.5em] \frac{\partial \text{wkv}_i}{\partial \alpha_{i-1}} & = \frac{1}{e^{u + k_i} + \beta_{i-1}} \\ \frac{\partial \text{wkv}_i}{\partial \beta_{i-1}} & = -\frac{v_i e^{u + k_i} + \alpha_{i-1}}{(e^{u + k_i} + \beta_{i-1})^2} \\ \end{aligned}$$
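These partial derivatives are easy to spot-check numerically. This plain-Python sketch (hypothetical helper names, scalar inputs) compares each analytic expression for a single WKV step against a central finite difference, which is the same idea as the post's unit tests against autograd, just without PyTorch.

```python
import math


def wkv_step(u, k, v, alpha_prev, beta_prev):
    """One vanilla WKV output given the previous state."""
    euk = math.exp(u + k)
    return (euk * v + alpha_prev) / (euk + beta_prev)


def grad_wkv_step(u, k, v, alpha_prev, beta_prev):
    """Analytic partials from the equations above."""
    euk = math.exp(u + k)
    den = euk + beta_prev
    d_uk = euk * (beta_prev * v - alpha_prev) / den**2  # d/du == d/dk_i
    d_v = euk / den
    d_alpha = 1.0 / den
    d_beta = -(v * euk + alpha_prev) / den**2
    return d_uk, d_v, d_alpha, d_beta


def finite_difference_check(args, h=1e-6):
    """Returns {input name: (analytic, central difference)}."""
    d_uk, d_v, d_alpha, d_beta = grad_wkv_step(**args)
    expected = {"u": d_uk, "k": d_uk, "v": d_v, "alpha_prev": d_alpha, "beta_prev": d_beta}
    results = {}
    for name, exact in expected.items():
        plus = wkv_step(**dict(args, **{name: args[name] + h}))
        minus = wkv_step(**dict(args, **{name: args[name] - h}))
        results[name] = (exact, (plus - minus) / (2 * h))
    return results


for name, (exact, numeric) in finite_difference_check(
    dict(u=0.3, k=-0.7, v=1.5, alpha_prev=0.8, beta_prev=2.0)
).items():
    print(name, exact, numeric)
```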

We also need to compute the partial derivatives of $\alpha_i$ and $\beta_i$. Fortunately these are comparatively simple. Revisiting our original equations:

$$\begin{aligned} \alpha_i & = e^{-w} \alpha_{i - 1} + e^{k_i} v_i \\ \beta_i & = e^{-w} \beta_{i - 1} + e^{k_i} \\ \end{aligned}$$

For $\alpha_i$ we have:

$$\begin{aligned} \frac{\partial \alpha_i}{\partial w} & = -e^{-w} \alpha_{i - 1} \; & \frac{\partial \alpha_i}{\partial k_i} & = e^{k_i} v_i \\[1.5em] \frac{\partial \alpha_i}{\partial \alpha_{i-1}} & = e^{-w} \; & \frac{\partial \alpha_i}{\partial v_i} & = e^{k_i} \\ \end{aligned}$$

For $\beta_i$ we have:

$$\begin{aligned} \frac{\partial \beta_i}{\partial w} & = -e^{-w} \beta_{i - 1} \; & \frac{\partial \beta_i}{\partial k_i} & = e^{k_i} \\[1.5em] \frac{\partial \beta_i}{\partial \beta_{i-1}} & = e^{-w} \; & & \\ \end{aligned}$$
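The same finite-difference approach works for the state update. This hypothetical plain-Python sketch checks every entry of the two tables above; note the minus sign in the $w$ partials, which the PyTorch backward pass handles by accumulating the positive term and negating `grad_w` once on return.

```python
import math


def state_step(w, k, v, alpha_prev, beta_prev):
    """Vanilla state update; returns (alpha_i, beta_i)."""
    ew, ek = math.exp(-w), math.exp(k)
    return ew * alpha_prev + ek * v, ew * beta_prev + ek


def state_step_partials(w, k, v, alpha_prev, beta_prev):
    """Partials from the tables above, keyed by (output index, input name)."""
    ew, ek = math.exp(-w), math.exp(k)
    return {
        (0, "w"): -ew * alpha_prev,
        (0, "k"): ek * v,
        (0, "alpha_prev"): ew,
        (0, "v"): ek,
        (1, "w"): -ew * beta_prev,
        (1, "k"): ek,
        (1, "beta_prev"): ew,
    }


def finite_difference_check(args, h=1e-6):
    """Returns {(output index, input name): (analytic, central difference)}."""
    results = {}
    for (idx, name), exact in state_step_partials(**args).items():
        plus = state_step(**dict(args, **{name: args[name] + h}))[idx]
        minus = state_step(**dict(args, **{name: args[name] - h}))[idx]
        results[(idx, name)] = (exact, (plus - minus) / (2 * h))
    return results


for key, (exact, numeric) in finite_difference_check(
    dict(w=0.4, k=-0.3, v=1.7, alpha_prev=0.9, beta_prev=1.6)
).items():
    print(key, exact, numeric)
```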

PyTorch Implementation

We can manually implement the above equations in PyTorch. This implementation is more academic than practical, since it’s a straightforward function to implement as a CUDA or Triton kernel, but sometimes it is easier to read code than equations (also, and perhaps more importantly, it lets us write unit tests to make sure the equations are correct).

def wkv_vanilla_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    bsz, tsz, chans = k.shape
    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 2, tsz + 1, chans)
    assert grad_wkv.shape == (bsz, tsz, chans)
    assert grad_state.shape == (bsz, 2, 1, chans)

    alpha, beta = state.chunk(2, dim=1)  # (B, 1, T + 1, D), (B, 1, T + 1, D)
    grad_alpha, grad_beta = grad_state[:, :, 0].chunk(2, dim=1)  # (B, 1, D), (B, 1, D)

    ew = torch.exp(-w)

    grad_w = torch.zeros_like(w)
    grad_u = torch.zeros_like(u)
    grad_k = torch.zeros_like(k)
    grad_v = torch.zeros_like(v)

    for t in reversed(range(tsz)):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        alpha_prev, beta_prev = alpha[:, :, t], beta[:, :, t]
        euk = torch.exp(u + kt)
        ek = torch.exp(kt)

        grad_wkvt = grad_wkv[:, t : t + 1]

        # Backpropagates wkv gradients.
        grad_uk = grad_wkvt * euk * (beta_prev * vt - alpha_prev) / (beta_prev + euk) ** 2
        grad_u += grad_uk.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_uk
        grad_v[:, t : t + 1] += grad_wkvt * euk / (beta_prev + euk)

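        # Backpropagate alpha gradients. grad_w accumulates +e^{-w} terms and is
        # negated on return, since d(alpha)/dw = -e^{-w} * alpha_prev.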
        grad_w += (grad_alpha * ew * alpha_prev).flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_alpha * ek * vt
        grad_v[:, t : t + 1] += grad_alpha * ek

        # Backpropagate beta gradients.
        grad_w += (grad_beta * ew * beta_prev).flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_beta * ek

        # Compute gradients for alpha and beta.
        grad_alpha = grad_alpha * ew + grad_wkvt / (beta_prev + euk)
        grad_beta = grad_beta * ew - grad_wkvt * (euk * vt + alpha_prev) / (beta_prev + euk) ** 2

    return -grad_w, grad_u, grad_k, grad_v, torch.stack((grad_alpha, grad_beta), dim=1)

Numerically Stable Gradients

This section derives the gradients of the numerically stable WKV computation from the equations above and then implements a manual backward pass in PyTorch.

Recall the numerically stable version of our WKV computation:

$$\begin{aligned} \alpha_i' & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \alpha_{i - 1}' + e^{k_i - \epsilon_i} v_i \\ \beta_i' & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \beta_{i - 1}' + e^{k_i - \epsilon_i} \\[1em] \text{wkv}_i & = \frac{ e^{u+k_i-\tau_i} v_i + e^{\epsilon_{i - 1}-\tau_i}\alpha_{i - 1}' }{ e^{u+k_i-\tau_i} + e^{\epsilon_{i - 1}-\tau_i}\beta_{i - 1}' } \\[1em] \epsilon_{i} & = \max(-w + \epsilon_{i - 1}, k_i) \\ \tau_i & = \max(u + k_i, \epsilon_{i - 1}) \\ \end{aligned}$$

The partial derivatives for the above computation are similar to those of the vanilla WKV computation, but adjusted for the $\epsilon$ terms:

$$\begin{aligned} \frac{\partial \text{wkv}_i}{\partial u} = \frac{\partial \text{wkv}_i}{\partial k_i} & = \frac{ e^{u + k_i - \tau_i} v_i}{e^{u + k_i - \tau_i} + e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}'} - \frac{ e^{u + k_i - \tau_i} (e^{u + k_i - \tau_i} v_i + e^{\epsilon_{i - 1} - \tau_i} \alpha_{i - 1}')}{(e^{u + k_i - \tau_i} + e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}')^2} \\ & = \frac{ e^{u + k_i - \tau_i} (e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}' v_i - e^{\epsilon_{i - 1} - \tau_i} \alpha_{i - 1}')}{(e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}' + e^{u + k_i - \tau_i})^2} \\[1.5em] \frac{\partial \text{wkv}_i}{\partial v_i} & = \frac{ e^{u + k_i - \tau_i}}{e^{u + k_i - \tau_i} + e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}'} \\[1.5em] \frac{\partial \text{wkv}_i}{\partial \alpha_{i-1}'} & = \frac{e^{\epsilon_{i - 1} - \tau_i}}{e^{u + k_i - \tau_i} + e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}'} \\[1.5em] \frac{\partial \text{wkv}_i}{\partial \beta_{i-1}'} & = -\frac{e^{\epsilon_{i-1} - \tau_i}(v_i e^{u + k_i - \tau_i} + e^{\epsilon_{i-1} - \tau_i}\alpha_{i-1}')}{(e^{u + k_i - \tau_i} + e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}')^2} \\[1.5em] \frac{\partial \text{wkv}_i}{\partial \epsilon_{i - 1}} & = \frac{ e^{u + k_i + \epsilon_{i - 1}} (\alpha_{i - 1}' - v_i \beta_{i - 1}')}{(e^{\epsilon_{i - 1}} \beta_{i - 1}' + e^{u + k_i})^2} = \frac{ e^{u + k_i + \epsilon_{i - 1} - 2 \tau_i} (\alpha_{i - 1}' - v_i \beta_{i - 1}')}{(e^{\epsilon_{i - 1} - \tau_i} \beta_{i - 1}' + e^{u + k_i - \tau_i})^2} \\ \end{aligned}$$
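Two properties of these expressions can be checked numerically with a small plain-Python sketch (hypothetical helper names). First, $\text{wkv}_i$ does not depend on the choice of $\tau_i$, because $e^{-\tau_i}$ cancels between numerator and denominator, so no gradient needs to flow through $\tau_i$. Second, the $\epsilon_{i-1}$ partial matches a finite difference.

```python
import math


def wkv_scaled(u, k, v, alpha_prev, beta_prev, eps_prev, tau=None):
    """wkv_i from the rescaled state; tau defaults to max(u + k_i, eps_{i-1})."""
    if tau is None:
        tau = max(u + k, eps_prev)
    e1 = math.exp(eps_prev - tau)
    e2 = math.exp(u + k - tau)
    return (e2 * v + e1 * alpha_prev) / (e2 + e1 * beta_prev)


def d_wkv_d_eps_prev(u, k, v, alpha_prev, beta_prev, eps_prev):
    """Analytic d wkv_i / d eps_{i-1}, in the tau-scaled form above."""
    tau = max(u + k, eps_prev)
    e1 = math.exp(eps_prev - tau)
    e2 = math.exp(u + k - tau)
    scale = math.exp(u + k + eps_prev - 2 * tau)
    return scale * (alpha_prev - v * beta_prev) / (e1 * beta_prev + e2) ** 2


args = dict(u=0.3, k=1.2, v=-0.5, alpha_prev=0.7, beta_prev=1.9, eps_prev=2.0)

# 1. Any tau gives the same wkv.
print(wkv_scaled(**args), wkv_scaled(**args, tau=5.0))

# 2. Central finite difference in eps_{i-1} vs. the analytic partial.
h = 1e-6
numeric = (
    wkv_scaled(**dict(args, eps_prev=args["eps_prev"] + h))
    - wkv_scaled(**dict(args, eps_prev=args["eps_prev"] - h))
) / (2 * h)
print(numeric, d_wkv_d_eps_prev(**args))
```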

For $\alpha_i'$ we have:

$$\begin{aligned} \frac{\partial \alpha_i'}{\partial w} & = -e^{-w + \epsilon_{i - 1} - \epsilon_i} \alpha_{i - 1}' \; & \frac{\partial \alpha_i'}{\partial \epsilon_{i-1}} & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \alpha_{i - 1}' \\[1.5em] \frac{\partial \alpha_i'}{\partial k_i} & = e^{k_i - \epsilon_i} v_i \; & \frac{\partial \alpha_i'}{\partial \epsilon_i} & = -\alpha_i' \\[1.5em] \frac{\partial \alpha_i'}{\partial \alpha_{i-1}'} & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \; & \frac{\partial \alpha_i'}{\partial v_i} & = e^{k_i - \epsilon_i} \\ \end{aligned}$$

For $\beta_i'$ we have:

$$\begin{aligned} \frac{\partial \beta_i'}{\partial w} & = -e^{-w + \epsilon_{i - 1} - \epsilon_i} \beta_{i - 1}' \; & \frac{\partial \beta_i'}{\partial \epsilon_{i-1}} & = e^{-w + \epsilon_{i - 1} - \epsilon_i} \beta_{i - 1}'\\[1.5em] \frac{\partial \beta_i'}{\partial k_i} & = e^{k_i - \epsilon_i} \; & \frac{\partial \beta_i'}{\partial \epsilon_i} & = -\beta_i' \\[1.5em] \frac{\partial \beta_i'}{\partial \beta_{i-1}'} & = e^{-w + \epsilon_{i - 1} - \epsilon_i} & & \\ \end{aligned}$$

For $\epsilon_i$ we have:

$$\begin{aligned} \frac{\partial \epsilon_i}{\partial w} & = \begin{cases} -1 & \text{if } -w + \epsilon_{i - 1} > k_i \\ 0 & \text{otherwise} \end{cases} \\ \frac{\partial \epsilon_i}{\partial \epsilon_{i - 1}} & = \begin{cases} 1 & \text{if } -w + \epsilon_{i - 1} > k_i \\ 0 & \text{otherwise} \end{cases} \\ \frac{\partial \epsilon_i}{\partial k_i} & = \begin{cases} 1 & \text{if } -w + \epsilon_{i - 1} \leq k_i \\ 0 & \text{otherwise} \end{cases} \\ \end{aligned}$$
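The $\alpha_i'$ and $\beta_i'$ tables can be spot-checked the same way. In this plain-Python sketch (hypothetical names), $\epsilon_i$ is passed in as an independent input, matching how the tables treat it; its own dependence on $w$, $\epsilon_{i-1}$ and $k_i$ goes through the subgradient of the max, with ties routed to $k_i$ as in the mask used by `wkv_with_eps_backward`.

```python
import math


def scaled_state_step(w, k, v, alpha_prev, beta_prev, eps_prev, eps_cur):
    """(alpha_i', beta_i') with eps_i treated as an independent input."""
    e1 = math.exp(-w + eps_prev - eps_cur)
    e2 = math.exp(k - eps_cur)
    return e1 * alpha_prev + e2 * v, e1 * beta_prev + e2


def scaled_state_partials(w, k, v, alpha_prev, beta_prev, eps_prev, eps_cur):
    """Partials from the tables above, keyed by (output index, input name)."""
    e1 = math.exp(-w + eps_prev - eps_cur)
    e2 = math.exp(k - eps_cur)
    alpha, beta = scaled_state_step(w, k, v, alpha_prev, beta_prev, eps_prev, eps_cur)
    return {
        (0, "w"): -e1 * alpha_prev, (0, "eps_prev"): e1 * alpha_prev,
        (0, "k"): e2 * v, (0, "eps_cur"): -alpha,
        (0, "alpha_prev"): e1, (0, "v"): e2,
        (1, "w"): -e1 * beta_prev, (1, "eps_prev"): e1 * beta_prev,
        (1, "k"): e2, (1, "eps_cur"): -beta, (1, "beta_prev"): e1,
    }


def eps_partials(w, k, eps_prev):
    """Subgradients of eps_i = max(-w + eps_{i-1}, k_i); ties go to k_i."""
    if -w + eps_prev > k:
        return {"w": -1.0, "eps_prev": 1.0, "k": 0.0}
    return {"w": 0.0, "eps_prev": 0.0, "k": 1.0}


def finite_difference_check(args, h=1e-6):
    """Returns {(output index, input name): (analytic, central difference)}."""
    results = {}
    for (idx, name), exact in scaled_state_partials(**args).items():
        plus = scaled_state_step(**dict(args, **{name: args[name] + h}))[idx]
        minus = scaled_state_step(**dict(args, **{name: args[name] - h}))[idx]
        results[(idx, name)] = (exact, (plus - minus) / (2 * h))
    return results


args = dict(w=0.4, k=0.9, v=-1.3, alpha_prev=0.6, beta_prev=1.4, eps_prev=1.1, eps_cur=1.0)
for key, (exact, numeric) in finite_difference_check(args).items():
    print(key, exact, numeric)
```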

PyTorch Implementation

The PyTorch implementation for the numerically stable gradients will be similar to the vanilla gradients, but with the addition of the $\epsilon$ terms.

def wkv_with_eps_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    bsz, tsz, chans = k.shape
    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 3, tsz + 1, chans)
    assert grad_wkv.shape == (bsz, tsz, chans)
    assert grad_state.shape == (bsz, 3, 1, chans)

    alpha, beta, eps = state.chunk(3, dim=1)  # (B, 1, T + 1, D), (B, 1, T + 1, D), (B, 1, T + 1, D)
    grad_alpha, grad_beta, grad_eps = grad_state[:, :, 0].chunk(3, dim=1)  # (B, 1, D), (B, 1, D), (B, 1, D)
    grad_eps = grad_eps.clone()

    grad_w = torch.zeros_like(w)
    grad_u = torch.zeros_like(u)
    grad_k = torch.zeros_like(k)
    grad_v = torch.zeros_like(v)

    for t in reversed(range(tsz)):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        alpha_prev, beta_prev, eps_prev = alpha[:, :, t], beta[:, :, t], eps[:, :, t]
        alpha_curr, beta_curr, eps_curr = alpha[:, :, t + 1], beta[:, :, t + 1], eps[:, :, t + 1]
        ukt = u + kt
        tau = torch.maximum(ukt, eps_prev)
        e1 = torch.exp(eps_prev - tau)
        e2 = torch.exp(ukt - tau)

        euke = torch.exp(ukt + eps_prev - 2 * tau)

        denom = e1 * beta_prev + e2
        denom_sq = denom**2

        grad_wkvt = grad_wkv[:, t : t + 1]

        # Backpropagates wkv gradients.
        grad_uk = grad_wkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
        grad_u += grad_uk.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_uk
        grad_v[:, t : t + 1] += grad_wkvt * e2 / denom

        grad_alpha_wkv = grad_wkvt * e1 / denom
        grad_beta_wkv = -grad_wkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
        grad_eps_wkv = grad_wkvt * euke * (alpha_prev - vt * beta_prev) / (e1 * beta_prev + e2) ** 2

        e1 = torch.exp(-w + eps_prev - eps_curr)
        e2 = torch.exp(kt - eps_curr)

        # Backpropagates alpha gradients.
        grad_alpha_we = grad_alpha * e1 * alpha_prev
        grad_w -= grad_alpha_we.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_alpha * e2 * vt
        grad_v[:, t : t + 1] += grad_alpha * e2
        grad_eps += grad_alpha * -alpha_curr

        # Backpropagates beta gradients.
        grad_beta_we = grad_beta * e1 * beta_prev
        grad_w -= grad_beta_we.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_beta * e2
        grad_eps += grad_beta * -beta_curr

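        # Backpropagates epsilon gradients through eps_curr = max(eps_prev - w, kt):
        # where the first argument wins, the gradient goes to w (negated) and eps_prev; otherwise to k.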
        eps_grad_mask = -w + eps_prev > kt
        grad_eps_we = torch.where(eps_grad_mask, grad_eps, torch.zeros_like(grad_eps))
        grad_w -= grad_eps_we.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += torch.where(eps_grad_mask, torch.zeros_like(grad_eps), grad_eps)

        # Computes gradients for alpha, beta and epsilon.
        grad_alpha = grad_alpha * e1 + grad_alpha_wkv
        grad_beta = grad_beta * e1 + grad_beta_wkv
        grad_eps = grad_alpha_we + grad_beta_we + grad_eps_we + grad_eps_wkv

    return grad_w, grad_u, grad_k, grad_v, torch.stack((grad_alpha, grad_beta, grad_eps), dim=1)

Log-Space Gradients

In our log-space implementation of the WKV computation, we have:

$$\text{wkv}_i = e^{\log \text{wkv}_i^+} - e^{\log \text{wkv}_i^-}$$

This gives us the following partial derivatives:

$$\begin{aligned} \frac{\partial \text{wkv}_i}{\partial \log \text{wkv}_i^+} & = \text{wkv}_i^+ \\[1.5em] \frac{\partial \text{wkv}_i}{\partial \log \text{wkv}_i^-} & = -\text{wkv}_i^- \\ \end{aligned}$$

Next, we need to compute the gradients of $\log \text{wkv}_i^+$ and $\log \text{wkv}_i^-$ with respect to each of our inputs. We have:

$$\begin{aligned} \log \text{wkv}_i^+ & = LSE(u + k_i + \log v_i^+, \log \alpha_{i - 1}^+) - LSE(u + k_i, \log \beta_{i - 1}) \\[1em] \log \text{wkv}_i^- & = LSE(u + k_i + \log v_i^-, \log \alpha_{i - 1}^-) - LSE(u + k_i, \log \beta_{i - 1}) \\ \end{aligned}$$

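These two equations have the same form (they differ only in whether they use $v_i^+$ and $\alpha_{i-1}^+$ or $v_i^-$ and $\alpha_{i-1}^-$), so in the equations below we omit the sign and simply write $\log \text{wkv}_i$.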

The gradients of the log-sum-exp function are given by:

$$\begin{aligned} \frac{\partial LSE(a, b)}{\partial a} & = \frac{e^a}{e^a + e^b} = \frac{1}{1 + e^{b - a}} \\[1em] \frac{\partial LSE(a, b)}{\partial b} & = \frac{e^b}{e^a + e^b} = \frac{1}{1 + e^{a - b}} \\ \end{aligned}$$
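Each of these partials is a logistic sigmoid of the difference of the two arguments, and the two sum to 1 (they are the softmax weights of $a$ and $b$). A small plain-Python sketch with hypothetical helper names; the sigmoid is written in a branch-based form so that `math.exp` never overflows for large differences.

```python
import math


def lse(a, b):
    """log(e^a + e^b) with the max shift."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def sigmoid(x):
    """1 / (1 + e^{-x}), branching so the exponent is never positive."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def lse_grads(a, b):
    """(dLSE/da, dLSE/db) = (1 / (1 + e^{b - a}), 1 / (1 + e^{a - b}))."""
    return sigmoid(a - b), sigmoid(b - a)


ga, gb = lse_grads(0.4, -1.1)
h = 1e-6
print(ga, (lse(0.4 + h, -1.1) - lse(0.4 - h, -1.1)) / (2 * h))
print(gb, (lse(0.4, -1.1 + h) - lse(0.4, -1.1 - h)) / (2 * h))
print(lse_grads(0.0, 1000.0))  # no overflow even for a huge gap
```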

We can use this to find the partial derivatives. Note that since we are using log-space state variables, we need the partial derivatives with respect to $\log \alpha_{i-1}$ and $\log \beta_{i-1}$ rather than $\alpha_{i-1}$ and $\beta_{i-1}$. We leave the expressions unsimplified because this form more closely matches the implementation in code.

$$\begin{aligned} \frac{\partial \log \text{wkv}_i}{\partial u} = \frac{\partial \log \text{wkv}_i}{\partial k_i} & = \frac{1}{1 + e^{\log \alpha_{i - 1} - (u + k_i + \log v_i)}} - \frac{1}{1 + e^{\log \beta_{i - 1} - (u + k_i)}} \\[1em] \frac{\partial \log \text{wkv}_i}{\partial v_i} & = \frac{1}{v_i (1 + e^{\log \alpha_{i - 1} - (u + k_i + \log v_i)})} \\[1em] \frac{\partial \log \text{wkv}_i}{\partial \log \alpha_{i - 1}} & = \frac{1}{1 + e^{(u + k_i + \log v_i) - \log \alpha_{i - 1}}} \\[1em] \frac{\partial \log \text{wkv}_i}{\partial \log \beta_{i - 1}} & = -\frac{1}{1 + e^{(u + k_i) - \log \beta_{i - 1}}} \\ \end{aligned}$$
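A plain-Python spot check of these expressions for one sign branch (so $v_i > 0$), again with hypothetical helper names and a central finite difference. It evaluates the unsimplified forms exactly as written above.

```python
import math


def lse(a, b):
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def log_wkv(u, k, v, ln_alpha_prev, ln_beta_prev):
    """log wkv_i for one sign branch (v > 0), in the LSE form above."""
    return lse(u + k + math.log(v), ln_alpha_prev) - lse(u + k, ln_beta_prev)


def log_wkv_partials(u, k, v, ln_alpha_prev, ln_beta_prev):
    """The unsimplified partials from the equations above."""
    a = u + k + math.log(v)  # first argument of the numerator LSE
    b = u + k  # first argument of the denominator LSE
    d_uk = 1.0 / (1.0 + math.exp(ln_alpha_prev - a)) - 1.0 / (1.0 + math.exp(ln_beta_prev - b))
    return {
        "u": d_uk,
        "k": d_uk,
        "v": 1.0 / (v * (1.0 + math.exp(ln_alpha_prev - a))),
        "ln_alpha_prev": 1.0 / (1.0 + math.exp(a - ln_alpha_prev)),
        "ln_beta_prev": -1.0 / (1.0 + math.exp(b - ln_beta_prev)),
    }


def finite_difference_check(args, h=1e-6):
    """Returns {input name: (analytic, central difference)}."""
    results = {}
    for name, exact in log_wkv_partials(**args).items():
        plus = log_wkv(**dict(args, **{name: args[name] + h}))
        minus = log_wkv(**dict(args, **{name: args[name] - h}))
        results[name] = (exact, (plus - minus) / (2 * h))
    return results


args = dict(u=0.2, k=-0.4, v=1.3, ln_alpha_prev=0.5, ln_beta_prev=0.9)
for name, (exact, numeric) in finite_difference_check(args).items():
    print(name, exact, numeric)
```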

Additionally, we need to find the partial derivatives of $\log \alpha_{i}$ and $\log \beta_{i}$. Recall the log-space update rule:

$$\begin{aligned} \log \alpha_i & = LSE(-w + \log \alpha_{i-1}, k_i + \log v_i) \\ \log \beta_i & = LSE(-w + \log \beta_{i - 1}, k_i) \\ \end{aligned}$$

The partial derivatives of $\log \alpha_i$ are:

$$\begin{aligned} \frac{\partial \log \alpha_i}{\partial w} & = \frac{-1}{1 + e^{(k_i + \log v_i) - (-w + \log \alpha_{i - 1})}} \\[1em] \frac{\partial \log \alpha_i}{\partial \log \alpha_{i - 1}} & = \frac{1}{1 + e^{(k_i + \log v_i) - (-w + \log \alpha_{i - 1})}} \\[1em] \frac{\partial \log \alpha_i}{\partial k_i} & = \frac{1}{1 + e^{(\log \alpha_{i - 1} - w) - (k_i + \log v_i)}} \\[1em] \frac{\partial \log \alpha_i}{\partial v_i} & = \frac{1}{v_i (1 + e^{(\log \alpha_{i - 1} - w) - (k_i + \log v_i)})} \\ \end{aligned}$$

The partial derivatives of $\log \beta_i$ are:

$$
\begin{aligned}
\frac{\partial \log \beta_i}{\partial w} & = \frac{-1}{1 + e^{k_i - (-w + \log \beta_{i - 1})}} \\[1em]
\frac{\partial \log \beta_i}{\partial \log \beta_{i - 1}} & = \frac{1}{1 + e^{k_i - (-w + \log \beta_{i - 1})}} \\[1em]
\frac{\partial \log \beta_i}{\partial k_i} & = \frac{1}{1 + e^{(-w + \log \beta_{i - 1}) - k_i}}
\end{aligned}
$$
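These three expressions share a structure worth making explicit. First, $\partial \log \beta_i / \partial w = -\partial \log \beta_i / \partial \log \beta_{i-1}$. Second, $\partial \log \beta_i / \partial \log \beta_{i-1} + \partial \log \beta_i / \partial k_i = 1$, because the two LSE terms act like softmax weights. This is why the PyTorch code computes a single exponential, `e_beta`, and derives every beta gradient from it. The pure-Python sketch below checks both identities and the sign of the $w$ term with finite differences. The input values and the helper name `log_beta_next` are arbitrary.

```python
import math


def log_beta_next(w: float, ln_beta_prev: float, k: float) -> float:
    """log beta_i = LSE(-w + log beta_{i-1}, k_i), via max + log1p(exp(-|a - b|))."""
    a, b = -w + ln_beta_prev, k
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


w, ln_beta_prev, k = 0.7, 1.2, 0.1  # arbitrary scalar inputs

# One exponential (like e_beta in wkv_log_space_backward) gives all three partials.
e_beta = math.exp(k - (-w + ln_beta_prev))
d_ln_beta_prev = 1 / (1 + e_beta)
d_w = -d_ln_beta_prev
d_k = 1 / (1 + 1 / e_beta)

# Central finite differences as a cross-check.
h = 1e-6
fd_w = (log_beta_next(w + h, ln_beta_prev, k) - log_beta_next(w - h, ln_beta_prev, k)) / (2 * h)
fd_k = (log_beta_next(w, ln_beta_prev, k + h) - log_beta_next(w, ln_beta_prev, k - h)) / (2 * h)
print(f"d/dw: {d_w:+.8f} vs {fd_w:+.8f}")
print(f"d/dk: {d_k:+.8f} vs {fd_k:+.8f}")
print(f"d/dlog beta + d/dk = {d_ln_beta_prev + d_k}")
```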

Lastly, a note on the partial derivatives of the positive and negative parts of $v_i$, $v_i^+ = \max(v_i, 0)$ and $v_i^- = \max(-v_i, 0)$, with respect to $v_i$:

$$
\begin{aligned}
\frac{\partial v_i^+}{\partial v_i} & = \begin{cases} 1 & \text{if } v_i > 0 \\ 0 & \text{otherwise} \end{cases} \\
\frac{\partial v_i^-}{\partial v_i} & = \begin{cases} -1 & \text{if } v_i < 0 \\ 0 & \text{otherwise} \end{cases}
\end{aligned}
$$
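Routing gradients through the split takes only a few lines of plain Python. Here `split_v` and `route_grad_v` are hypothetical helper names. The value `eps = 1e-4` is an assumed stand-in for the `EPS` constant used by the PyTorch code.

```python
def split_v(v: float, eps: float = 1e-4) -> tuple[float, float]:
    """Returns (v+, v-) = (max(v, 0), max(-v, 0)), each shifted by eps so log() stays finite."""
    return max(v, 0.0) + eps, max(-v, 0.0) + eps


def route_grad_v(v: float, grad_v_p: float, grad_v_m: float) -> float:
    """Chain rule through the split: dv+/dv = 1 if v > 0, dv-/dv = -1 if v < 0, else 0."""
    if v > 0:
        return grad_v_p
    if v < 0:
        return -grad_v_m
    return 0.0


for v in (2.5, -1.25, 0.0):
    v_p, v_m = split_v(v)
    # The eps offsets cancel, so v+ - v- recovers v (up to rounding).
    print(f"v={v:+.2f}  v+={v_p:.4f}  v-={v_m:.4f}  v+ - v-={v_p - v_m:+.4f}")

print(route_grad_v(2.5, 3.0, 5.0), route_grad_v(-1.25, 3.0, 5.0), route_grad_v(0.0, 3.0, 5.0))  # 3.0 -5.0 0.0
```

There is one small difference from the PyTorch code. The code uses `torch.where(vt > 0, ...)`, which sends $v_i = 0$ to the negative branch, while this sketch follows the piecewise definition above and returns zero.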

PyTorch Implementation

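The PyTorch implementation follows from the equations above, although the positive and negative parts of $v$ need some care. Each part has its own numerator state ($\log \alpha^+$ and $\log \alpha^-$), both share the denominator state $\log \beta$, and the gradients of the two parts are routed back to $v$ using the piecewise derivatives above. The loop runs backwards in time, carrying the gradients of the log-space state from each timestep to the previous one.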

```python
import torch
from torch import Tensor

EPS = 1e-4  # Assumed value for this excerpt; the full code defines EPS elsewhere.


def wkv_log_space_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
    eps: float = EPS,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Returns gradients for w, u, k, v and the initial (log-space) state."""
    bsz, tsz, chans = k.shape

    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    # state[:, :, t] holds (log alpha+, log alpha-, log beta) from before timestep t.
    assert state.shape == (bsz, 3, tsz + 1, chans)
    assert grad_wkv.shape == (bsz, tsz, chans)
    assert grad_state.shape == (bsz, 3, 1, chans)

    # Gradients with respect to the final state, carried backwards through time in the loop.
    grad_ln_alpha_p, grad_ln_alpha_m, grad_ln_beta = grad_state[:, :, 0].chunk(3, dim=1)

    grad_w = torch.zeros_like(w)
    grad_u = torch.zeros_like(u)
    grad_k = torch.zeros_like(k)
    grad_v = torch.zeros_like(v)

    for t in reversed(range(tsz)):
        # Split v into positive and negative parts; eps keeps the logs finite.
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        vt_p, vt_m = torch.clamp_min(vt, 0) + eps, torch.clamp_min(-vt, 0) + eps
        ln_v_p, ln_v_m = torch.log(vt_p), torch.log(vt_m)

        ln_alpha_p_prev, ln_alpha_m_prev, ln_beta_prev = state[:, :, t].chunk(3, dim=1)

        uk = u + kt
        ukv_p, ukv_m = uk + ln_v_p, uk + ln_v_m

        # Recompute wkv+ and wkv-, which convert dL/dwkv into gradients of log wkv.
        ukb = torch.logaddexp(uk, ln_beta_prev)
        wkv_p = torch.exp(torch.logaddexp(ukv_p, ln_alpha_p_prev) - ukb)
        wkv_m = torch.exp(torch.logaddexp(ukv_m, ln_alpha_m_prev) - ukb)

        # wkv = wkv+ - wkv-, so dL/dlog(wkv+) = dL/dwkv * wkv+ and dL/dlog(wkv-) = -dL/dwkv * wkv-.
        grad_wkvt = grad_wkv[:, t : t + 1]
        grad_ln_wkv_p, grad_ln_wkv_m = grad_wkvt * wkv_p, grad_wkvt * -wkv_m

        # Backpropagates wkv gradients.
        e_num_p = torch.exp(ln_alpha_p_prev - ukv_p)
        e_num_m = torch.exp(ln_alpha_m_prev - ukv_m)
        e_den = torch.exp(ln_beta_prev - uk)
        grad_wkv_den_p, grad_wkv_den_m = grad_ln_wkv_p / (1 + e_den), grad_ln_wkv_m / (1 + e_den)
        grad_kv_p, grad_kv_m = grad_ln_wkv_p / (1 + e_num_p), grad_ln_wkv_m / (1 + e_num_m)
        grad_uk = grad_kv_p + grad_kv_m - grad_wkv_den_p - grad_wkv_den_m
        grad_u += grad_uk.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_uk
        grad_v[:, t : t + 1] += torch.where(vt > 0, grad_kv_p / vt_p, grad_kv_m / -vt_m)

        # Gradients flowing from wkv_t into the previous state.
        grad_ln_alpha_wkv_p = grad_ln_wkv_p / (1 + (1 / e_num_p))
        grad_ln_alpha_wkv_m = grad_ln_wkv_m / (1 + (1 / e_num_m))
        grad_ln_beta_wkv = -grad_ln_wkv_p / (1 + (1 / e_den)) - grad_ln_wkv_m / (1 + (1 / e_den))

        # Backpropagates alpha gradients.
        e_alpha_p = torch.exp(kt + ln_v_p - (-w + ln_alpha_p_prev))
        e_alpha_m = torch.exp(kt + ln_v_m - (-w + ln_alpha_m_prev))
        grad_wa_p = grad_ln_alpha_p / (1 + e_alpha_p)
        grad_wa_m = grad_ln_alpha_m / (1 + e_alpha_m)
        grad_w -= (grad_wa_p + grad_wa_m).flatten(0, -2).sum(0)
        grad_kv_p, grad_kv_m = grad_ln_alpha_p / (1 + (1 / e_alpha_p)), grad_ln_alpha_m / (1 + (1 / e_alpha_m))
        grad_k[:, t : t + 1] += grad_kv_p + grad_kv_m
        grad_v[:, t : t + 1] += torch.where(vt > 0, grad_kv_p / vt_p, -grad_kv_m / vt_m)

        # Backpropagates beta gradients.
        e_beta = torch.exp(kt - (-w + ln_beta_prev))
        grad_wb = grad_ln_beta / (1 + e_beta)
        grad_w -= grad_wb.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_ln_beta / (1 + (1 / e_beta))

        # Carry the state gradients back to timestep t - 1.
        grad_ln_alpha_p = grad_wa_p + grad_ln_alpha_wkv_p
        grad_ln_alpha_m = grad_wa_m + grad_ln_alpha_wkv_m
        grad_ln_beta = grad_wb + grad_ln_beta_wkv

    # After the loop, the carried values are the gradients of the initial state.
    return grad_w, grad_u, grad_k, grad_v, torch.stack((grad_ln_alpha_p, grad_ln_alpha_m, grad_ln_beta), dim=1)
```
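The tensor code interleaves the positive and negative halves with the batch and channel dimensions. Stripping the same recursion down to scalars makes it easier to follow. The sketch below is a hypothetical pure-Python reduction of the loop above, and the helper names (`lse`, `log_states`, `loss`, `backward`) are illustrative, not from the original repo. It makes four simplifying assumptions:

- a single channel;
- strictly positive $v$, so only the $\alpha^+$ path exists;
- no gradient on the final state;
- a loss $L = \sum_t c_t \text{wkv}_t$, whose fixed weights $c_t$ stand in for `grad_wkv`.

It then compares the reverse pass with central finite differences.

```python
import math
import random


def lse(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def log_states(w, ks, vs, ln_a0, ln_b0):
    """(log alpha, log beta) before each step, followed by the final state."""
    states = [(ln_a0, ln_b0)]
    for k, v in zip(ks, vs):
        ln_a, ln_b = states[-1]
        states.append((lse(-w + ln_a, k + math.log(v)), lse(-w + ln_b, k)))
    return states


def loss(w, u, ks, vs, ln_a0, ln_b0, cs):
    """L = sum_t c_t * wkv_t; the fixed weights c_t play the role of grad_wkv."""
    states = log_states(w, ks, vs, ln_a0, ln_b0)
    total = 0.0
    for t, (k, v) in enumerate(zip(ks, vs)):
        ln_a, ln_b = states[t]
        total += cs[t] * math.exp(lse(u + k + math.log(v), ln_a) - lse(u + k, ln_b))
    return total


def backward(w, u, ks, vs, ln_a0, ln_b0, cs):
    """Scalar, positive-v-only version of the reverse loop above."""
    states = log_states(w, ks, vs, ln_a0, ln_b0)
    grad_w = grad_u = 0.0
    grad_k, grad_v = [0.0] * len(ks), [0.0] * len(ks)
    grad_ln_a = grad_ln_b = 0.0  # no gradient flows into the final state here
    for t in reversed(range(len(ks))):
        k, v = ks[t], vs[t]
        ln_a, ln_b = states[t]  # state before step t
        ukv, uk = u + k + math.log(v), u + k
        grad_ln_wkv = cs[t] * math.exp(lse(ukv, ln_a) - lse(uk, ln_b))

        # Gradients through wkv_t.
        e_num, e_den = math.exp(ln_a - ukv), math.exp(ln_b - uk)
        grad_kv = grad_ln_wkv / (1 + e_num)
        grad_uk = grad_kv - grad_ln_wkv / (1 + e_den)
        grad_u += grad_uk
        grad_k[t] += grad_uk
        grad_v[t] += grad_kv / v
        grad_ln_a_wkv = grad_ln_wkv / (1 + 1 / e_num)
        grad_ln_b_wkv = -grad_ln_wkv / (1 + 1 / e_den)

        # Gradients through the alpha update.
        e_a = math.exp(k + math.log(v) - (-w + ln_a))
        grad_wa = grad_ln_a / (1 + e_a)
        grad_w -= grad_wa
        grad_kv = grad_ln_a / (1 + 1 / e_a)
        grad_k[t] += grad_kv
        grad_v[t] += grad_kv / v

        # Gradients through the beta update.
        e_b = math.exp(k - (-w + ln_b))
        grad_wb = grad_ln_b / (1 + e_b)
        grad_w -= grad_wb
        grad_k[t] += grad_ln_b / (1 + 1 / e_b)

        # Carry the state gradients back one step.
        grad_ln_a = grad_wa + grad_ln_a_wkv
        grad_ln_b = grad_wb + grad_ln_b_wkv
    return [grad_w, grad_u, *grad_k, *grad_v]


random.seed(0)
T = 5
ks = [random.uniform(-1, 1) for _ in range(T)]
vs = [random.uniform(0.5, 2) for _ in range(T)]  # positive only in this sketch
cs = [random.uniform(-1, 1) for _ in range(T)]
ln_a0 = ln_b0 = 0.0  # finite initial state, so the 1 / e terms never divide by zero
params = [0.6, 0.3, *ks, *vs]  # [w, u, k_0..k_{T-1}, v_0..v_{T-1}]


def flat_loss(p):
    return loss(p[0], p[1], p[2 : 2 + T], p[2 + T :], ln_a0, ln_b0, cs)


analytic = backward(params[0], params[1], ks, vs, ln_a0, ln_b0, cs)
numeric = []
h = 1e-6
for i in range(len(params)):
    hi, lo = params.copy(), params.copy()
    hi[i] += h
    lo[i] -= h
    numeric.append((flat_loss(hi) - flat_loss(lo)) / (2 * h))

max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"max |analytic - numeric| = {max_err:.2e}")
```

The `1 / (1 + 1 / e)` terms are copied from the PyTorch code. In PyTorch, `1 / 0` evaluates to `inf`, and such a term correctly becomes zero. Plain Python floats raise `ZeroDivisionError` instead, which is why this sketch starts from a finite log-space state ($\log \alpha_0 = \log \beta_0 = 0$) rather than $-\infty$.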