I recently heard about a new paper entitled Retentive Network: A Successor to Transformer for Large Language Models which purports to deliver Transformer-level performance with constant memory costs at inference time. It is pretty similar to the RWKV paper, which I've described in a previous post. I figured I'd write a short post walking through the problem that these two papers are trying to solve:
The cost of doing inference with a self-attention model scales linearly with the number of input timesteps.
To rephrase the above issue: as the number of timesteps in our sequence increases, so does the computational cost of computing the next predicted token. The RetNet and RWKV language models attempt to ameliorate this issue using two approaches. In this post I'll give a hand-wavy explanation of how you might come up with these models if you were interested in changing the attention architecture to reduce the runtime computational cost.
First consider the vanilla self-attention equations:
$$\text{SoftmaxAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
The $Q$, $K$ and $V$ matrices have shape $(t, d)$, where $t$ is the number of tokens in the input sequence and $d$ is the dimensionality of the embedding. We'll ignore the $d$ dimension to make it easier to visualize. The $QK^T$ matrix multiplication looks something like this:
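$$
\begin{pmatrix}
q_1 \cdot k_1 & & & \\
q_2 \cdot k_1 & q_2 \cdot k_2 & & \\
\vdots & \vdots & \ddots & \\
q_t \cdot k_1 & q_t \cdot k_2 & \cdots & q_t \cdot k_t
\end{pmatrix}
$$

where the entries above the diagonal are masked out, since we're doing causal language modeling. Dropping the $\sqrt{d_k}$ scaling for readability, each output then looks roughly like:

$$
o_{t'} = \frac{\sum_{t=1}^{t'} \exp(q_{t'} \cdot k_t)\, v_t}{\sum_{t=1}^{t'} \exp(q_{t'} \cdot k_t)}
$$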
From the equations above, in order to compute output $o_{t'}$, we need to keep $k_{1:t'}$ and $v_{1:t'}$ around in memory (although we only need $q_{t'}$ itself). This means that the memory cost of the cache grows linearly as we compute more tokens, and so does the computational cost: there are no cacheable intermediate results that would let us avoid computing a full row of scores against every previous key.
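To make the scaling concrete, here's a minimal NumPy sketch of single-head cached attention at inference time (the class and method names here are mine); each `step` has to touch every cached key and value, so both memory and per-token compute grow with $t$:

```python
import numpy as np

class KVCacheAttention:
    """Single-head causal attention at inference time, with a growing KV cache."""

    def __init__(self, d: int):
        self.d = d
        self.keys = []    # k_1 ... k_t, one (d,) vector per past token
        self.values = []  # v_1 ... v_t

    def step(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
        # Append the new key/value; the cache grows by one entry per token.
        self.keys.append(k)
        self.values.append(v)
        K = np.stack(self.keys)    # (t, d): memory grows linearly with t
        V = np.stack(self.values)  # (t, d)
        # Scores against *every* cached key: O(t * d) work per new token.
        scores = K @ q / np.sqrt(self.d)           # (t,)
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        return weights @ V                         # (d,)
```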
The high-level approach that both the RWKV and RetNet papers take is to avoid computing the full $QK^T$ matrix, in part by introducing some form of decay over time. Suppose we modified the above computation to look something like this instead:
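$$
o_{t'} = \frac{\sum_{t=1}^{t'} \exp(q_t \cdot k_t)\, v_t}{\sum_{t=1}^{t'} \exp(q_t \cdot k_t)}
$$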
Note the slight difference: we use $q_t$ instead of $q_{t'}$ in both the numerator and the denominator. This probably wouldn't give us very good performance as a model, since we've done away with the entire concept of "attention", but it would give us some nice computational properties. If we write $o_{t'}$ as a numerator $n_{t'}$ over a denominator $d_{t'}$, we get something like:
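$$
n_{t'} = \sum_{t=1}^{t'} \exp(q_t \cdot k_t)\, v_t = n_{t'-1} + \exp(q_{t'} \cdot k_{t'})\, v_{t'}
$$

$$
d_{t'} = \sum_{t=1}^{t'} \exp(q_t \cdot k_t) = d_{t'-1} + \exp(q_{t'} \cdot k_{t'})
$$

$$
o_{t'} = \frac{n_{t'}}{d_{t'}}
$$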
We could then just cache the $n_{t'-1}$ and $d_{t'-1}$ tensors instead of keeping everything around in memory, giving us constant memory and constant compute on each update step.
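Here's a minimal sketch of that constant-memory recurrence in NumPy (again, the names are mine, and a real implementation would need to worry about numerical stability of the exponentials):

```python
import numpy as np

class RecurrentSoftmaxish:
    """Cache only a running numerator and denominator instead of all past keys/values."""

    def __init__(self, d: int):
        self.num = np.zeros(d)  # n_{t'}: running weighted sum of values
        self.den = 0.0          # d_{t'}: running sum of weights

    def step(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
        w = np.exp(q @ k)       # exp(q_t . k_t): only the current query is needed
        self.num += w * v       # n_{t'} = n_{t'-1} + exp(q_t . k_t) v_t
        self.den += w           # d_{t'} = d_{t'-1} + exp(q_t . k_t)
        return self.num / self.den
```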
There is actually another approach we can use to convert the original attention equations to use constant memory, if we are willing to do away with the softmax. Consider the alternate equation below, which takes advantage of the associative property of matrix multiplication:
$$\text{UnnormalizedAttention}(Q, K, V) = (QK^T)V = Q(K^TV)$$
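A quick NumPy sanity check of that associativity; the interesting part is that the left grouping materializes a $(t, t)$ matrix (cost $O(t^2 d)$), while the right grouping only ever materializes a $(d, d)$ matrix (cost $O(t d^2)$), whose size doesn't depend on the sequence length:

```python
import numpy as np

t, d = 128, 16
Q, K, V = (np.random.randn(t, d) for _ in range(3))

left = (Q @ K.T) @ V   # materializes a (t, t) score matrix
right = Q @ (K.T @ V)  # materializes only a (d, d) matrix

assert np.allclose(left, right)
```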
Without re-deriving all of the equations above (and adding back our causal mask), we can write the outputs roughly as:
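$$
o_{t'} = q_{t'} \sum_{t=1}^{t'} k_t^T v_t
$$

(treating $q_{t'}$, $k_t$ and $v_t$ as row vectors, so that $k_t^T v_t$ is a $d \times d$ matrix)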
Actually, this is essentially the approach that the RetNet paper takes, except that it additionally includes a decay term (really, multiple decay terms) to represent time, giving something roughly like:
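$$
o_{t'} = \sum_{t=1}^{t'} \gamma^{t'-t} (q_{t'} \cdot k_t)\, v_t
$$

for some decay rate $\gamma \in (0, 1)$.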
This can be trivially rewritten as a recurrent relationship (I'll leave the derivation as an exercise for the reader, though there's a sketch of the recurrent form below).
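Here's a minimal sketch of that recurrent form for a single head and a single decay rate (the class and variable names are mine, not from the paper); the state is a fixed-size $(d, d)$ matrix no matter how long the sequence gets:

```python
import numpy as np

class RecurrentRetention:
    """Single-head retention-style recurrence with a fixed-size (d, d) state."""

    def __init__(self, d: int, gamma: float = 0.9):
        self.gamma = gamma
        self.S = np.zeros((d, d))  # running sum of decayed k^T v outer products

    def step(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
        # S_{t'} = gamma * S_{t'-1} + k_{t'}^T v_{t'}
        self.S = self.gamma * self.S + np.outer(k, v)
        # o_{t'} = q_{t'} S_{t'}
        return q @ self.S
```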
The above equations look like they shouldn't work on their own. There's an additional key component in the RetNet paper: a RoPE-like transformation applied to the keys and queries. Intuitively, this means you're letting the queries sort of search inside this recurrent vector space using the RoPE machinery, which is a neat idea to think about.
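If you want to play with that intuition, here's a rough sketch of the kind of position-dependent rotation involved; this is a generic RoPE-style rotation in NumPy, not necessarily the exact formulation used in the paper:

```python
import numpy as np

def rope_rotate(x: np.ndarray, position: int, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive pairs of dimensions of x by position-dependent angles.

    Assumes x has an even number of dimensions.
    """
    d = x.shape[-1]
    # One frequency per pair of dimensions, as in standard RoPE.
    freqs = base ** (-np.arange(0, d, 2) / d)
    angles = position * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

# Queries and keys each get rotated by their own position, so q_i . k_j ends up
# depending on the relative offset i - j rather than on absolute positions.
```

Because each key is rotated by its own position before being folded into the recurrent state, and each query is rotated by its own position before reading from it, the dot products still recover relative-position information even though the state itself is just a running sum.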