I've recently been writing a CUDA kernel for a project I'm working on. I haven't done much CUDA programming before, but it's been an interesting journey, and I thought it would be a useful exercise to share with other people. I'm going to assume some familiarity with PyTorch and the general idea of what CUDA is.

If you're like me, you probably use PyTorch or TensorFlow as a deep learning API, and never really think too much about what is happening under the hood. I think this is a testament to how great these APIs are, and how much they improve productivity. However, that makes it easy to forget that these APIs are under active development, and in fact there is some low-hanging fruit, performance-wise. PyTorch's `logsumexp` is a good example of a function which is used liberally for some applications it is not optimal for.

This idea was largely inspired by this repo from Harvard NLP, which provided a kernel for speeding up the log-sum-exp part of a CRF or HMM model, and it prompted me to investigate this in greater detail.

In past posts I've given an introduction to the **forward-backward algorithm** and the **Viterbi algorithm**, two algorithms which are used for performing inference in Hidden Markov Models and Conditional Random Fields. In this post, I'm going to talk about one of the core concepts for making these models work, which is **log-space operations**. Doing inference in these models usually involves multiplying together very small numbers a large number of times, which can quickly push values outside the range that floating-point numbers can represent. Double-precision numbers are stored in 64 bits, using 1 bit for the sign, 11 bits for the exponent, and 52 bits for the fraction.

This means we can exceed the range of the exponent relatively easily: 11 bits only allow `2^{11} = 2048` distinct exponent values, so the largest finite double is roughly `2^{1024}`. Consider the simple C++ program below, where we naively compute the value `\frac{2^{2^{12}}}{2^{1 + 2^{12}}} = \frac{1}{2}` by computing the numerator and denominator, then dividing.
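Here's a minimal sketch of what such a program might look like, assuming the powers of two are built up by repeated doubling in `double` precision (`pow2_naive` and `naive_ratio` are illustrative names):

```cpp
#include <cassert>

// Naively computes 2^n by repeated doubling in double precision.
// Once n reaches 1024 the value exceeds the largest finite double and becomes +inf.
double pow2_naive(int n) {
    assert(n >= 0);
    double result = 1.0;
    for (int i = 0; i < n; ++i) {
        result *= 2.0;
    }
    return result;
}

// Computes 2^(2^12) / 2^(1 + 2^12), which should equal 0.5.
// Both operands overflow to +inf, and inf / inf is NaN.
double naive_ratio() {
    const int n = 1 << 12;  // 2^12 = 4096
    double numerator = pow2_naive(n);
    double denominator = pow2_naive(n + 1);
    return numerator / denominator;
}
```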

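We know the result should be `0.5`. However, when we run it, we don't get `0.5`: both `2^{2^{12}}` and `2^{1 + 2^{12}}` are far larger than the largest finite double, so the numerator and denominator both overflow to infinity, and dividing infinity by infinity gives NaN.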

Because performing these repeated multiplications can easily push values out of range for models such as CRFs and HMMs (overflow, as here, or underflow, when multiplying many small probabilities), it behoves us to find a more numerically stable solution. In this case, we can take advantage of the following identities:

```
\begin{aligned}
x' & = \log(x) \\
x & = \exp(x') \\
x_1x_2 ... x_n & = \exp(x'_1 + x'_2 + ... + x'_n)
\end{aligned}
```

Instead of having to perform the numerically unstable multiplications, we can perform numerically stable additions on the logs of these values, and apply the exponential function once we're done. If we want to add the original values while only having their logs, we can naively do the following:

`x_1 + x_2 + ... + x_n = \exp(x'_1) + \exp(x'_2) + ... + \exp(x'_n)`

However, if the left side of this equation is once again very large (say, because we are going to divide it by something else later), computing the exponentials directly can lead to unwanted overflow. Instead, in practice, it is better to use the identity below, which is known as the `log-sum-exp` function. By subtracting the maximum value out of each of the components of the addition, we can usually keep the exponential part from blowing up too much.

```
\begin{aligned}
x^* & = \max(x'_1, x'_2, ..., x'_n) \\
x_1 + x_2 + ... + x_n & = \exp(x^* + \log(\exp(x'_1 - x^*) + ... + \exp(x'_n - x^*)))
\end{aligned}
```
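To make the identity concrete, here's a small C++ sketch contrasting the naive formula with the max-subtracted version. `logsumexp_naive` and `logsumexp` are illustrative helper names, not library functions:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Naive log-sum-exp: exponentiate, sum, take the log.
// std::exp overflows to +inf for inputs above roughly 709.
double logsumexp_naive(const std::vector<double>& xs) {
    double sum = 0.0;
    for (double x : xs) sum += std::exp(x);
    return std::log(sum);
}

// Stable log-sum-exp: subtract the maximum x* before exponentiating,
// so every exponent is <= 0 and std::exp cannot overflow.
double logsumexp(const std::vector<double>& xs) {
    assert(!xs.empty());
    double x_max = *std::max_element(xs.begin(), xs.end());
    if (std::isinf(x_max)) return x_max;  // all terms are log(0), or one is +inf
    double sum = 0.0;
    for (double x : xs) sum += std::exp(x - x_max);
    return x_max + std::log(sum);
}
```

With two inputs of `1000`, `logsumexp_naive` returns infinity, while `logsumexp` returns `1000 + log(2)`.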

We can now re-write our C++ program from earlier:
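A sketch of that rewrite, assuming the powers of two are built up one factor at a time (`log_pow2` and `log_space_ratio` are illustrative names): multiplying by 2 becomes adding `log(2)`, dividing becomes subtracting, and we exponentiate only once at the end.

```cpp
#include <cassert>
#include <cmath>

// Computes log(2^n) by adding log(2) n times: multiplication in log space.
double log_pow2(int n) {
    assert(n >= 0);
    const double log_two = std::log(2.0);
    double result = 0.0;  // log(1) = 0
    for (int i = 0; i < n; ++i) {
        result += log_two;
    }
    return result;
}

// Computes 2^(2^12) / 2^(1 + 2^12) in log space: the division becomes a
// subtraction of logs, and std::exp is applied only to the small final value.
double log_space_ratio() {
    const int n = 1 << 12;  // 2^12 = 4096
    double log_numerator = log_pow2(n);
    double log_denominator = log_pow2(n + 1);
    return std::exp(log_numerator - log_denominator);
}
```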

This results in the correct answer.

In some literature, this is known as the log semiring; in particular, Goodman (1999) showed how a number of common algorithms can simply be thought of as derivatives of value functions computed over different semirings (it's a really interesting mathematical paper and a great way to conceptualize CRFs and HMMs).

To provide some mathematical formalism for the examples above, it's important to expand on the semiring concept. It's actually pretty straightforward, even if it sounds a bit complicated at first. The pair of functions (`sum`, `logsumexp`) is an example of a semiring, meaning that it generalizes the multiplication and addition functions. This is some mathematical jargon which is easier to explain with an example. Let's define two semirings:

```
\begin{aligned}
a \oplus_{\text{normal}} b & = a + b\\
a \otimes_{\text{normal}} b & = a b \\
a \oplus_{\text{log}} b & = \text{logsumexp}(a, b) \\
a \otimes_{\text{log}} b & = a + b
\end{aligned}
```

We can then switch our operations between the two semirings:

```
\begin{aligned}
a \oplus_{\text{normal}} b & = \exp(\log a \oplus_{\text{log}} \log b) \\
a \otimes_{\text{normal}} b & = \exp(\log a \otimes_{\text{log}} \log b)
\end{aligned}
```

This is the heart of what we're doing. Since the log semiring is much more numerically stable than the normal semiring when we're dealing with probabilities, we convert our data to log space, do the computations, then convert back.
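Here's a minimal C++ sketch of the same idea: a dot product written once against a generic semiring interface, then evaluated in both semirings. The `NormalSemiring` and `LogSemiring` structs and their `zero`/`plus`/`times` members are illustrative names, not from any library:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// The "normal" semiring: plus is ordinary addition, times is multiplication.
struct NormalSemiring {
    static double zero() { return 0.0; }  // additive identity
    static double plus(double a, double b) { return a + b; }
    static double times(double a, double b) { return a * b; }
};

// The log semiring: plus is logsumexp, times is ordinary addition.
// Its additive identity is log(0) = -infinity.
struct LogSemiring {
    static double zero() { return -std::numeric_limits<double>::infinity(); }
    static double plus(double a, double b) {
        if (a == zero()) return b;
        if (b == zero()) return a;
        double m = std::max(a, b);
        return m + std::log(std::exp(a - m) + std::exp(b - m));
    }
    static double times(double a, double b) { return a + b; }
};

// A dot product written once, generic over the semiring S.
template <typename S>
double dot(const std::vector<double>& x, const std::vector<double>& y) {
    assert(x.size() == y.size());
    double acc = S::zero();
    for (std::size_t i = 0; i < x.size(); ++i) {
        acc = S::plus(acc, S::times(x[i], y[i]));
    }
    return acc;
}
```

Exponentiating the log-semiring dot product of the logs recovers the ordinary dot product (up to rounding), which is exactly the conversion in the equations above.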

When implementing a CRF model in PyTorch, one of the core building blocks is being able to do the `log-sum-exp` operation over pairs of matrices. Specifically, when we are building our dynamic programming tables, on each timestep `i` we multiply (in log space, add) the potentials with the states from the previous timestep `i-1`, then add (in log space, log-sum-exp) the results together to get the new state. Fortunately, in PyTorch, the `logsumexp` function has already been implemented. Here is the PyTorch version of the function we're trying to optimize, plus the code for benchmarking:

Here is our simple implementation of the log-sum-exp function in PyTorch:

I'm a big fan of using click for command line tools, rather than `argparse`; I find that it simplifies building hierarchical tools and looks nicer as code. Here's the benchmarking code:

Lastly, at the end of the file it's important to add some boilerplate to run the script:

The naive implementation of this function gives us a useful baseline. Running the benchmark code above for this function gives us the following memory usage stats (note that the y-axis is log-scaled):

Here is the corresponding chart for the runtime:

There is a very simple speed-up we can make to the above function by preserving memory contiguity. Note that the reduction in the forward pass takes place over the last dimension, so we want to make sure the last dimension is contiguous in memory. If we remove the `transpose` part, the forward pass can be performed faster.

The memory usage for this function is identical to that of the naive implementation:

However, the forward and backward passes run slightly faster:

This is a useful lesson: **where possible, avoid transposes**. Note that the above function will return a different result from our canonical implementation for the same inputs, so it is the caller's responsibility to make sure the inputs are laid out correctly.
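To see why the transpose hurts, here's a small C++ sketch of row-major strides, loosely mirroring what PyTorch reports through `Tensor.stride()` and `Tensor.is_contiguous()` (the helper functions themselves are hypothetical). Transposing the last two dimensions only swaps sizes and strides without moving data, so the last dimension stops being stride-1:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Strides (in elements) of a densely packed row-major tensor: the last
// dimension has stride 1, and each earlier stride is the product of the
// sizes after it.
std::vector<std::size_t> row_major_strides(const std::vector<std::size_t>& shape) {
    std::vector<std::size_t> strides(shape.size());
    std::size_t step = 1;
    for (std::size_t d = shape.size(); d-- > 0;) {
        strides[d] = step;
        step *= shape[d];
    }
    return strides;
}

// Transposing the last two dims swaps their sizes and strides; no data moves.
void transpose_last_two(std::vector<std::size_t>& shape, std::vector<std::size_t>& strides) {
    assert(shape.size() >= 2 && shape.size() == strides.size());
    std::size_t n = shape.size();
    std::swap(shape[n - 1], shape[n - 2]);
    std::swap(strides[n - 1], strides[n - 2]);
}

// A view is contiguous if its strides match the row-major strides of its shape.
bool is_contiguous(const std::vector<std::size_t>& shape, const std::vector<std::size_t>& strides) {
    return strides == row_major_strides(shape);
}
```

For a `{2, 3, 4}` tensor the strides are `{12, 4, 1}`; after the transpose the view has shape `{2, 4, 3}` and strides `{12, 1, 4}`, so adjacent elements along the last (reduced) dimension sit 4 elements apart in memory.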

Let's write a CUDA implementation of the above function, to see if we can improve the performance.

We can write the `log_bmm` function as a matrix-matrix operation (the batch part can be added trivially in the CUDA implementation). For a regular batch matrix multiplication function, we expect as our inputs two matrices with elements `a_{i, j}` and `b_{i, j}`. We will output a matrix with elements `o_{i, j}`, which is defined as the following:

`o_{i, j} = \sum_k a_{i, k} b_{k, j}`

The log-space version of this function is instead:

`o_{i, j} = \log \sum_k \exp(a_{i, k} + b_{k, j})`

Note that, to make this function numerically stable, we use the `logsumexp` trick above, rather than naively summing over the exponents.
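As a CPU reference for what the kernel should compute, here's a C++ sketch of the log-space product for a single (unbatched) pair of row-major matrices, using a max pass followed by a shifted-sum pass. `log_matmul` is a hypothetical name; `a` is `n x p` and `b` is `p x m`, where `p` is the reduced dimension:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Log-space matrix product: o[i][j] = log sum_k exp(a[i][k] + b[k][j]).
// a is n x p and b is p x m, both stored row-major in flat vectors.
std::vector<double> log_matmul(const std::vector<double>& a, const std::vector<double>& b,
                               std::size_t n, std::size_t p, std::size_t m) {
    assert(a.size() == n * p && b.size() == p * m);
    std::vector<double> o(n * m);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < m; ++j) {
            // First pass: the largest term a[i][k] + b[k][j].
            double x_max = -std::numeric_limits<double>::infinity();
            for (std::size_t k = 0; k < p; ++k) {
                x_max = std::max(x_max, a[i * p + k] + b[k * m + j]);
            }
            if (std::isinf(x_max)) {  // all terms are log(0), or one is +inf
                o[i * m + j] = x_max;
                continue;
            }
            // Second pass: sum exponentials shifted by the maximum (each exp <= 1).
            double sum = 0.0;
            for (std::size_t k = 0; k < p; ++k) {
                sum += std::exp(a[i * p + k] + b[k * m + j] - x_max);
            }
            o[i * m + j] = x_max + std::log(sum);
        }
    }
    return o;
}
```

Feeding in the logs of `[[1, 2], [3, 4]]` and `[[5, 6], [7, 8]]` and exponentiating the output recovers the ordinary product `[[19, 22], [43, 50]]`.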

We can differentiate the above function with respect to each `a_{i, k}` and `b_{k, j}`, giving:

```
\begin{aligned}
\frac{\delta o_{i, j}}{\delta a_{i, k}} = \frac{\delta o_{i, j}}{\delta b_{k, j}} = & \frac{\exp(a_{i, k} + b_{k, j})}{\sum_{k'} \exp(a_{i, k'} + b_{k', j})} \\
= & \frac{\exp(a_{i, k} + b_{k, j})}{\exp(o_{i, j})} \\
= & \exp(a_{i, k} + b_{k, j} - o_{i, j})
\end{aligned}
```

This means that the gradient of the loss function with respect to `a_{i, k}` can be written as the accumulation of all of the gradients `\frac{\delta L}{\delta o_{i, j}}`:

`\frac{\delta L}{\delta a_{i, k}} = \sum_j \exp(a_{i, k} + b_{k, j} - o_{i, j}) \frac{\delta L}{\delta o_{i, j}}`

Similarly, the gradient with respect to `b_{k, j}` can be written as:

`\frac{\delta L}{\delta b_{k, j}} = \sum_i \exp(a_{i, k} + b_{k, j} - o_{i, j}) \frac{\delta L}{\delta o_{i, j}}`
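Written out as loops, these two sums give a reference backward pass that is handy for checking a kernel's gradients. This is a sketch with a hypothetical `log_matmul_backward` helper; `a` is `n x p`, `b` is `p x m`, and `o` and `grad_o` (holding `\frac{\delta L}{\delta o_{i, j}}`) are `n x m`, all row-major:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gradients of o[i][j] = log sum_k exp(a[i][k] + b[k][j]).
// Given the forward output o and grad_o = dL/do, fills grad_a = dL/da
// (n x p) and grad_b = dL/db (p x m).
void log_matmul_backward(const std::vector<double>& a, const std::vector<double>& b,
                         const std::vector<double>& o, const std::vector<double>& grad_o,
                         std::size_t n, std::size_t p, std::size_t m,
                         std::vector<double>& grad_a, std::vector<double>& grad_b) {
    assert(a.size() == n * p && b.size() == p * m);
    assert(o.size() == n * m && grad_o.size() == n * m);
    grad_a.assign(n * p, 0.0);
    grad_b.assign(p * m, 0.0);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < m; ++j) {
            for (std::size_t k = 0; k < p; ++k) {
                // do[i][j]/da[i][k] = do[i][j]/db[k][j] = exp(a[i][k] + b[k][j] - o[i][j])
                double w = std::exp(a[i * p + k] + b[k * m + j] - o[i * m + j]);
                grad_a[i * p + k] += w * grad_o[i * m + j];  // accumulates the sum over j
                grad_b[k * m + j] += w * grad_o[i * m + j];  // accumulates the sum over i
            }
        }
    }
}
```

This loop fills both gradients in a single pass over `(i, j, k)`; the CUDA version splits them into separate kernels, one per input.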

For the CUDA implementation below, I'm using some of the constants defined in this post. Here are the relevant headers and aliases:

Let's write the forward pass of the algorithm. Here are some notes about this implementation:

- In log space, zero is represented as negative infinity. For practical purposes we can just choose a large negative number.
- We first find the maximum value along the reduction dimension, then sum the exponentials of all the elements minus this maximum. This is mathematically identical to the formulation above (although, as we'll see below, this can be improved on).
- We can fiddle with the number of threads. For the number of blocks, we use the identity `(x + y - 1) / y` to do the integer division `x / y` rounding up, to ensure that we are allocating a sufficient number of blocks.
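The rounding-up identity and the matching bounds check can be tried out on the host. In this C++ sketch, `ceil_div` and `simulate_launch` are hypothetical helpers, and the 1-D indexing mimics the usual `blockIdx.x * blockDim.x + threadIdx.x` pattern:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Integer division x / y, rounded up: (x + y - 1) / y.
std::size_t ceil_div(std::size_t x, std::size_t y) {
    assert(y > 0);
    return (x + y - 1) / y;
}

// Simulates a 1-D launch on the host: each (block, thread) pair maps to a
// flat output index, and threads past the end of the output do nothing.
// Returns how many times each output element was visited.
std::vector<int> simulate_launch(std::size_t total, std::size_t threads_per_block) {
    std::size_t num_blocks = ceil_div(total, threads_per_block);
    std::vector<int> visits(total, 0);
    for (std::size_t block = 0; block < num_blocks; ++block) {
        for (std::size_t thread = 0; thread < threads_per_block; ++thread) {
            std::size_t idx = block * threads_per_block + thread;
            if (idx < total) {  // bounds check for the partially filled last block
                visits[idx] += 1;
            }
        }
    }
    return visits;
}
```

With 1000 outputs and 256 threads per block, `ceil_div` gives 4 blocks (1024 threads); every output is visited exactly once, and the last 24 threads fail the bounds check and do nothing.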

Let's see the code:

Great! Let's see what the backward pass looks like. Note that we have to backpropagate to both input tensors, so we need two kernels running in parallel streams. While there is more code, it largely uses the same general idea as the forward pass kernel.

Lastly, we'll add some boilerplate for pybind11 to be able to access our functions from the Python side. I think it usually makes sense to have the forward and backward passes in their own submodule, for coherence.

There is some additional boilerplate that is missing from the above code. See here for a complete tutorial on how to compile this CUDA kernel for your extension. The general idea for incorporating this as a PyTorch function is:

Plugging this function into our performance benchmarking script gives the following chart for memory usage:

Here is the corresponding chart for the runtime:

This is quite a significant improvement in memory usage and runtime! It turns out that we can save a substantial amount of memory by writing a pure CUDA implementation.

There are a few things we can do to improve on this baseline CUDA implementation.

We can cut the number of memory accesses in the forward function in half by performing the `logsumexp` element-wise, in a single pass, instead of first finding the global max. We can do that with the following function:
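One way to get a single pass, sketched here in plain C++ as a hypothetical `online_logsumexp` (this rescaling formulation is chosen for illustration and is not necessarily how the kernel does it), is to keep a running maximum and rescale the running sum whenever the maximum changes:

```cpp
#include <cmath>
#include <limits>
#include <vector>

// Single-pass ("online") log-sum-exp: track a running maximum m and a running
// sum s of exp(x - m). When a new maximum appears, rescale s by exp(old_m - new_m).
double online_logsumexp(const std::vector<double>& xs) {
    const double neg_inf = -std::numeric_limits<double>::infinity();
    double m = neg_inf;  // running max; log(0) before any input
    double s = 0.0;      // running sum of exp(x - m)
    for (double x : xs) {
        if (x == neg_inf) continue;  // exp(-inf) contributes nothing
        if (x > m) {
            s = s * std::exp(m - x) + 1.0;  // rescale old sum to the new max; exp(x - x) = 1
            m = x;
        } else {
            s += std::exp(x - m);
        }
    }
    return m == neg_inf ? neg_inf : m + std::log(s);
}
```

Each element is read once and costs a single `exp` call, whereas the max-then-sum version reads every element twice.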

This gives the following graph for memory usage:

Here is the corresponding chart for the runtime:

There has been a lot of work on how to optimize reduction operations on GPUs, including a great tutorial by NVIDIA. There are a lot of tricks involved in doing this efficiently (for more info, see that post), but the basic idea is that we want to avoid doing the reduction operations serially, like in the image below.

Instead, when the reduction operation is **associative** and **commutative** (which, fortunately, is the case for the addition operation of every semiring, not just the one in question), we can perform it in `O(\log(N))` parallel steps, as in the tree below.
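Here's a host-side C++ sketch of the tree pattern applied to log-sum-exp (`logaddexp` and `tree_logsumexp` are illustrative names). The inner loop of each step is the work a GPU would spread across threads; this is the simple interleaved-addressing variant, and the NVIDIA tutorial covers faster ones:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Pairwise log-sum-exp of two values (the log semiring's "plus").
double logaddexp(double a, double b) {
    const double neg_inf = -std::numeric_limits<double>::infinity();
    if (a == neg_inf) return b;
    if (b == neg_inf) return a;
    double m = std::max(a, b);
    return m + std::log(std::exp(a - m) + std::exp(b - m));
}

// Tree reduction: at each step, element i absorbs element i + stride.
// Combines within one step are independent, so on a GPU they could run in
// parallel, and n elements need ceil(log2(n)) steps instead of n - 1.
// Writes the number of steps taken to *steps (if non-null).
double tree_logsumexp(std::vector<double> xs, int* steps) {
    assert(!xs.empty());
    int count = 0;
    for (std::size_t stride = 1; stride < xs.size(); stride *= 2) {
        for (std::size_t i = 0; i + stride < xs.size(); i += 2 * stride) {
            xs[i] = logaddexp(xs[i], xs[i + stride]);
        }
        ++count;
    }
    if (steps != nullptr) *steps = count;
    return xs[0];
}
```

For 8 inputs the tree finishes in 3 steps and agrees with a serial left-to-right reduction up to rounding; this reordering is only valid because the operation is associative and commutative.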

These performance boosts don't really apply to reductions that are relatively small (and by "small", I mean fewer than ~1000 elements along the reduced dimension).
