Ben's Blog

Retentive Networks and RWKV

A short, hand-wavy explainer for the mathematical intuition behind faster attention mechanisms.

Sep 16, 2023

I recently heard about a new paper entitled Retentive Network: A Successor to Transformer for Large Language Models which purports to deliver Transformer-level performance with constant memory costs at inference time. It is pretty similar to the RWKV paper, which I’ve described in a previous post. I figured I’d write a short post walking through the problem that these two papers are trying to solve:

The cost of doing inference with a self-attention model scales linearly with the number of input timesteps.

To rephrase the above issue, as the number of timesteps in our sequence increases, so does the computational cost of computing the next predicted token. The RetNet and RWKV language models attempt to ameliorate this issue using two different approaches. In this post I’ll give a hand-wavy explanation for how you might come up with these models if you were interested in changing the attention architecture to reduce the runtime computational cost.

First consider the vanilla self-attention equations:

$$
\text{SoftmaxAttention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
$$

The $Q$, $K$ and $V$ matrices have shape $(\text{t}, \text{d})$ where $\text{t}$ is the number of tokens in the input sequence and $\text{d}$ is the dimensionality of the embedding. We’ll ignore the $\text{d}$ dimension to make it easier to visualize. The $Q K^T$ matrix multiplication looks like this:

$$
\begin{bmatrix} q_{1} k_{1} & q_{1} k_{2} & q_{1} k_{3} & \\ q_{2} k_{1} & q_{2} k_{2} & q_{2} k_{3} & \cdots \\ q_{3} k_{1} & q_{3} k_{2} & q_{3} k_{3} & \\ & \vdots & & \end{bmatrix}
$$

When doing causal self-attention, we mask out any cells where $t_k > t_q$ (keys that come after the query) by setting them to $-\infty$, so the matrix looks like this:

$$
\begin{bmatrix} q_{1} k_{1} & -\infty & -\infty & \\ q_{2} k_{1} & q_{2} k_{2} & -\infty & \cdots \\ q_{3} k_{1} & q_{3} k_{2} & q_{3} k_{3} & \\ & \vdots & & \end{bmatrix}
$$

We then take the row-wise softmax to get the attention weights matrix:

$$
\begin{bmatrix} \frac{e^{q_{1} k_{1}}}{\sum_{t}^{1} e^{q_{1} k_{t}}} & 0 & 0 & \\ \frac{e^{q_{2} k_{1}}}{\sum_{t}^{2} e^{q_{2} k_{t}}} & \frac{e^{q_{2} k_{2}}}{\sum_{t}^{2} e^{q_{2} k_{t}}} & 0 & \cdots \\ \frac{e^{q_{3} k_{1}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} & \frac{e^{q_{3} k_{2}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} & \frac{e^{q_{3} k_{3}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} & \\ & \vdots & & \end{bmatrix}
$$

Finally (ignoring the scaling factor $\sqrt{d_k}$), we multiply by $V$ to get the output:

$$
\begin{bmatrix} \frac{e^{q_{1} k_{1}}}{\sum_{t}^{1} e^{q_{1} k_{t}}} v_1 \\ \frac{e^{q_{2} k_{1}}}{\sum_{t}^{2} e^{q_{2} k_{t}}} v_1 + \frac{e^{q_{2} k_{2}}}{\sum_{t}^{2} e^{q_{2} k_{t}}} v_2 \\ \frac{e^{q_{3} k_{1}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} v_1 + \frac{e^{q_{3} k_{2}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} v_2 + \frac{e^{q_{3} k_{3}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} v_3 \\ \vdots \end{bmatrix}
$$

We can write each of the outputs as:

$$
\begin{aligned} o_1 & = \frac{\sum_{t}^{1} v_t e^{q_{1} k_{t}}}{\sum_{t}^{1} e^{q_{1} k_{t}}} \\ o_2 & = \frac{\sum_{t}^{2} v_t e^{q_{2} k_{t}}}{\sum_{t}^{2} e^{q_{2} k_{t}}} \\ o_3 & = \frac{\sum_{t}^{3} v_t e^{q_{3} k_{t}}}{\sum_{t}^{3} e^{q_{3} k_{t}}} \\ & \vdots \end{aligned}
$$

From the equations above, in order to compute output $t'$, we need to keep $k_{1:t'}$ and $v_{1:t'}$ around in memory (although we only need $q_{t'}$). This means that the memory cost of the cache grows linearly as we compute more tokens, as does the computational cost of each new token. There aren’t any “cacheable” computations shared between rows, because every row uses a different query.
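To make the caching cost concrete, here is a minimal Python sketch of the per-token computation. It uses scalar queries, keys and values, matching the simplification above of ignoring the $\text{d}$ dimension. The function name `causal_attention_step` and the toy numbers are made up for illustration and do not come from either paper.

```python
import math

def causal_attention_step(q_new, k_cache, v_cache):
    """Compute one output o_{t'} from the newest query and ALL cached keys/values."""
    scores = [q_new * k for k in k_cache]
    # Subtracting the max score before exponentiating is a standard
    # numerical-stability trick; it does not change the softmax result.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, v_cache)) / total

# Toy scalar sequence (illustrative numbers only).
qs = [0.5, -1.0, 2.0]
ks = [1.0, 0.3, -0.7]
vs = [10.0, 20.0, 30.0]

k_cache, v_cache, outputs = [], [], []
for q, k, v in zip(qs, ks, vs):
    k_cache.append(k)  # the cache grows by one entry per generated token
    v_cache.append(v)
    # Each step loops over the whole cache, so the work per token is O(t').
    outputs.append(causal_attention_step(q, k_cache, v_cache))
```

Both the length of `k_cache` and the loop inside `causal_attention_step` grow with the number of tokens seen so far, which is exactly the linear cost described above.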

The high-level approach that both the RWKV and RetNet papers take is to avoid using $Q K^T$ by introducing some form of decay over time. Suppose we modified the above computation to look like this instead:

$$
\begin{aligned} o_1 & = \frac{\sum_{t}^{1} v_t e^{q_{t} k_{t}}}{\sum_{t}^{1} e^{q_{t} k_{t}}} \\ o_2 & = \frac{\sum_{t}^{2} v_t e^{q_{t} k_{t}}}{\sum_{t}^{2} e^{q_{t} k_{t}}} \\ o_3 & = \frac{\sum_{t}^{3} v_t e^{q_{t} k_{t}}}{\sum_{t}^{3} e^{q_{t} k_{t}}} \\ & \vdots \end{aligned}
$$

Note the slight difference - we use $q_t$ instead of $q_{t'}$ in both the numerator and denominator. This probably wouldn’t give us very good performance as a model, since we’ve done away with the entire concept of “attention”, but it would give us some nice computational properties. If we write $o_{t'}$ as a numerator $n_{t'}$ over a denominator $d_{t'}$:

$$
\begin{aligned} o_{t'} & = \frac{n_{t'}}{d_{t'}} \\ n_{t'} & = \sum_{t}^{t'} v_t e^{q_{t} k_{t}} = v_{t'} e^{q_{t'} k_{t'}} + n_{t' - 1} \\ d_{t'} & = \sum_{t}^{t'} e^{q_{t} k_{t}} = e^{q_{t'} k_{t'}} + d_{t' - 1} \end{aligned}
$$

We could just cache the $n_{t' - 1}$ and $d_{t' - 1}$ tensors instead of keeping everything around in memory, giving us constant-time and constant-memory computations on each update step.
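As a sanity check on the recurrence, the following Python sketch computes the modified outputs twice: once by re-summing over all previous timesteps, and once by carrying only the running numerator and denominator. It again uses scalars, and the function names `full_recompute` and `running_update` are hypothetical names chosen for this example.

```python
import math

def full_recompute(qs, ks, vs):
    """Evaluate each o_{t'} from scratch by summing over t = 1..t'."""
    outs = []
    for t_prime in range(1, len(qs) + 1):
        num = sum(vs[t] * math.exp(qs[t] * ks[t]) for t in range(t_prime))
        den = sum(math.exp(qs[t] * ks[t]) for t in range(t_prime))
        outs.append(num / den)
    return outs

def running_update(qs, ks, vs):
    """Evaluate each o_{t'} while keeping only n_{t'-1} and d_{t'-1} as state."""
    n, d = 0.0, 0.0
    outs = []
    for q, k, v in zip(qs, ks, vs):
        w = math.exp(q * k)  # depends only on the current timestep
        n += v * w           # n_{t'} = v_{t'} e^{q_{t'} k_{t'}} + n_{t'-1}
        d += w               # d_{t'} = e^{q_{t'} k_{t'}} + d_{t'-1}
        outs.append(n / d)
    return outs

# Toy scalar sequence (illustrative numbers only).
qs = [0.5, -1.0, 2.0, 0.1]
ks = [1.0, 0.3, -0.7, 1.5]
vs = [10.0, 20.0, 30.0, 40.0]

slow = full_recompute(qs, ks, vs)
fast = running_update(qs, ks, vs)
```

The two lists agree up to floating-point error, but `running_update` stores two numbers no matter how long the sequence gets.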

There is actually another approach we can use to convert the original attention equations to use constant memory, if we are willing to do away with the softmax. Consider the alternate equation below, which takes advantage of the associative property of matrix multiplication:

$$
\begin{aligned} \text{UnnormalizedAttention}(Q, K, V) & = (Q K^T) V \\ & = Q (K^T V) \end{aligned}
$$

Without re-deriving all the equations above (including adding back our causal mask), we can write the outputs as:

$$
\begin{aligned} o_1 & = q_1 k_1 v_1 \\ o_2 & = q_2 (k_1 v_1 + k_2 v_2) \\ o_3 & = q_3 (k_1 v_1 + k_2 v_2 + k_3 v_3) \\ & \vdots \end{aligned}
$$
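The sketch below checks this regrouping numerically. Unlike the scalar equations above, it uses small vectors for each $q_t$, $k_t$ and $v_t$, to show that the running sum of $k_t v_t$ terms becomes a fixed-size matrix of outer products. The dimensions, the values and the function names are assumptions made for this illustration.

```python
def outer(k, v):
    """Outer product of a length-d_k key and a length-d_v value."""
    return [[ki * vj for vj in v] for ki in k]

def vec_mat(q, S):
    """Multiply a length-d_k row vector by a d_k x d_v matrix."""
    return [sum(q[i] * S[i][j] for i in range(len(q))) for j in range(len(S[0]))]

def parallel_form(Q, K, V):
    """o_n = sum over m <= n of (q_n . k_m) v_m, i.e. causally masked (Q K^T) V."""
    outs = []
    for n in range(len(Q)):
        o = [0.0] * len(V[0])
        for m in range(n + 1):  # the causal mask: only keys at or before n
            score = sum(a * b for a, b in zip(Q[n], K[m]))
            o = [oj + score * vj for oj, vj in zip(o, V[m])]
        outs.append(o)
    return outs

def recurrent_form(Q, K, V):
    """o_n = q_n S_n, where S_n = S_{n-1} + outer(k_n, v_n) is the only state."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    outs = []
    for q, k, v in zip(Q, K, V):
        kv = outer(k, v)
        S = [[S[i][j] + kv[i][j] for j in range(d_v)] for i in range(d_k)]
        outs.append(vec_mat(q, S))
    return outs

# Toy inputs with d_k = d_v = 2 (illustrative numbers only).
Q = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0]]
K = [[0.5, 2.0], [1.0, 1.0], [-1.0, 0.5]]
V = [[1.0, -1.0], [2.0, 0.0], [0.0, 3.0]]

par = parallel_form(Q, K, V)
rec = recurrent_form(Q, K, V)
```

The state `S` has shape `d_k x d_v` regardless of sequence length, so the memory needed at inference time stays constant.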

Actually, this is essentially the approach that the RetNet paper takes, except that it additionally includes a decay term $\gamma$ (really, multiple decay terms) to represent time:

$$
\begin{aligned} o_1 & = q_1 k_1 v_1 \\ o_2 & = q_2 (\gamma k_1 v_1 + k_2 v_2) \\ o_3 & = q_3 (\gamma^2 k_1 v_1 + \gamma k_2 v_2 + k_3 v_3) \\ & \vdots \end{aligned}
$$

This can be trivially rewritten as a recurrence relation (I’ll leave this as an exercise for the reader).
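For readers who want to check their answer to that exercise, one way to write the recurrence is $s_n = \gamma s_{n-1} + k_n v_n$ with $o_n = q_n s_n$ and $s_0 = 0$. The Python sketch below compares this against the expanded sums, using scalars and a single decay value. The state name `s`, the function names and the choice $\gamma = 0.9$ are assumptions for illustration, not values from the paper.

```python
def retention_parallel(qs, ks, vs, gamma):
    """o_n = q_n * sum over m <= n of gamma^(n-m) * k_m * v_m (expanded form)."""
    outs = []
    for n in range(len(qs)):
        total = sum(gamma ** (n - m) * ks[m] * vs[m] for m in range(n + 1))
        outs.append(qs[n] * total)
    return outs

def retention_recurrent(qs, ks, vs, gamma):
    """Same outputs, carrying one decayed state: s_n = gamma * s_{n-1} + k_n * v_n."""
    s = 0.0
    outs = []
    for q, k, v in zip(qs, ks, vs):
        s = gamma * s + k * v  # older contributions shrink by gamma each step
        outs.append(q * s)
    return outs

# Toy scalar sequence (illustrative numbers only).
qs = [0.5, -1.0, 2.0]
ks = [1.0, 0.3, -0.7]
vs = [10.0, 20.0, 30.0]
gamma = 0.9

par = retention_parallel(qs, ks, vs, gamma)
rec = retention_recurrent(qs, ks, vs, gamma)
```

Setting `gamma = 1.0` recovers the undecayed equations from the previous step.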

The above equations look like they shouldn’t work. There’s an additional key component in the RetNet paper, which applies a RoPE-like transformation to the keys and queries. Intuitively, this lets the queries sort of search inside the recurrent vector space using the RoPE approach, which is a neat idea to think about.
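As a rough illustration of what a RoPE-style transformation buys you, the sketch below rotates a scalar query and key by a position-dependent complex phase and shows that their product depends only on the relative offset between positions. This is a generic rotary-embedding sketch under simplifying assumptions (scalars, a single frequency `theta`), not the exact formulation used in the RetNet paper, and the function names are hypothetical.

```python
import cmath

def rotate(x, position, theta):
    """Multiply x by the complex phase e^{i * position * theta}."""
    return x * cmath.exp(1j * position * theta)

def score(q, n, k, m, theta):
    """Product of a query rotated at position n with the conjugate of a key rotated at m."""
    return rotate(q, n, theta) * rotate(k, m, theta).conjugate()

theta = 0.1
q, k = 0.5, 1.2

# Same relative offset (n - m = 2) at two different absolute positions.
near_start = score(q, 5, k, 3, theta)
further_on = score(q, 12, k, 10, theta)
```

Because the phases combine into $e^{i(n-m)\theta}$, `near_start` and `further_on` agree up to floating-point error, so the score carries relative position information without storing absolute positions.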