I’ve recently been writing a CUDA kernel for one of my projects. I haven’t done much CUDA programming before, but it’s been an interesting journey, and I thought the exercise would be useful to share with other people. I’m going to assume some familiarity with PyTorch and the general idea of what CUDA is.

If you’re like me, you probably use PyTorch or TensorFlow as a deep learning API, and never really think too much about what is happening under-the-hood. I think this is a testament to how great these APIs are, and how much they improve productivity. However, that makes it easy to forget that these APIs are under active development, and in fact there is some low-hanging fruit, performance-wise. PyTorch’s logsumexp is a good example of a function which is used liberally for some applications that it is not optimal for.

This idea was largely inspired by this repo from Harvard NLP, which provides a kernel for speeding up the log-sum-exp part of a CRF or HMM model. I wanted to investigate the approach in greater detail.

Introduction

In past posts I’ve given an introduction to the forward-backward algorithm and the Viterbi algorithm, two algorithms which are used for performing inference in Hidden Markov Models and Conditional Random Fields. In this post, I’m going to talk about one of the core concepts for making these models work, which is log-space operations. Doing inference in these models usually involves multiplying together very small numbers a large number of times, which can quickly become computationally intractable. Double-precision numbers are stored in 64 bits using 1 bit to represent the sign, 11 bits for the exponent, and 52 bits for the fraction.

This means we can overflow the exponent relatively easily: the 11-bit exponent field can represent only 2^11 = 2048 distinct values, covering exponents from roughly -1022 to 1023. Consider the simple C++ program below, where we naively compute the value

\frac{2^{2^{12}}}{2^{1 + 2^{12}}} = \frac{1}{2}

by computing the numerator and denominator, then dividing.

We know the result should be 0.5. However, when we run it, both the numerator and the denominator overflow to infinity, and dividing them prints nan.

Log-Space Operations

Because performing these repeated multiplications can lead to this underflow problem relatively easily for models such as CRFs and HMMs, it behoves us to find a more numerically stable solution. In this case, we can take advantage of the following identities:

\begin{aligned}
x' & = \log(x) \\
x & = \exp(x') \\
x_1x_2 ... x_n & = \exp(x'_1 + x'_2 + ... + x'_n)
\end{aligned}

Instead of having to perform the numerically unstable multiplications, we can perform numerically stable additions on the logs of these values, and apply the exponential function once we’re done. If we want to perform additions on the logs of these values in a numerically stable way, we can naively do the following:

x_1 + x_2 + ... + x_n = \exp(x'_1) + \exp(x'_2) + ... + \exp(x'_n)

However, if the left side of this equation is once again very large (assuming we are going to divide it by something else later), this can lead to unwanted overflow. Instead, in practice, it is better to use the identity below, which is known as the log-sum-exp function. By subtracting the max value out of each of the components of the addition, we can usually keep the exponential part from blowing up too much.

\begin{aligned}
x^* & = \max(x'_1, x'_2, ..., x'_n) \\
x_1 + x_2 + ... + x_n & = \exp(x^* + \log(\exp(x'_1 - x^*) + ... + \exp(x'_n - x^*)))
\end{aligned}

We can now re-write our C++ program from earlier:

This results in the correct answer.

In some literature, this is known as the log semiring; in particular, (Goodman 1999) showed how a number of common algorithms can simply be thought of as derivatives of value functions computed over different semirings (it’s a really interesting mathematical paper and a great way to conceptualize CRFs and HMMs).

Mathematical Formalism

To provide some mathematical formalism for the examples above, it’s important to expand on the semiring concept. It’s actually pretty straightforward, even if it sounds a bit complicated at first. The pair of operations (logsumexp, +) forms a semiring, meaning that it generalizes the usual addition and multiplication functions. This is some mathematical jargon which is easier to explain with an example. Let’s define two semirings:

\begin{aligned}
a \oplus_{\text{normal}} b & = a + b\\
a \otimes_{\text{normal}} b & = a b \\
a \oplus_{\text{log}} b & = \text{logsumexp}(a, b) \\
a \otimes_{\text{log}} b & = a + b
\end{aligned}

We can then switch our operations between the two semirings:

\begin{aligned}
a \oplus_{\text{normal}} b & = \exp(\log a \oplus_{\text{log}} \log b) \\
a \otimes_{\text{normal}} b & = \exp(\log a \otimes_{\text{log}} \log b)
\end{aligned}

This is the heart of what we’re doing. Since the log semiring is much more numerically stable when we’re dealing with probabilities than the normal semiring, we convert our data to log-space, do the computations, then convert back.
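To make the correspondence concrete, here is a small Python sketch verifying both identities numerically:

```python
import math

def logsumexp(a, b):
    """Numerically stable log(exp(a) + exp(b)): the log semiring's "addition"."""
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))

a, b = 0.3, 0.2
# oplus_normal: a + b == exp(logsumexp(log a, log b))
print(a + b, math.exp(logsumexp(math.log(a), math.log(b))))
# otimes_normal: a * b == exp(log a + log b)
print(a * b, math.exp(math.log(a) + math.log(b)))
```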

Problem Statement

When implementing a CRF model in PyTorch, one of the core building blocks is being able to do the log-sum-exp operation over pairs of matrices. Specifically, when we are building our dynamic programming tables, on each timestep i we multiply (in log space, add) the potentials with the states from the previous timestep i-1, then add (in log space, log-sum-exp) the results together to get the new state. Fortunately, in PyTorch, the logsumexp function has already been implemented. Here is the PyTorch version of the function we’re trying to optimize, plus the code for benchmarking:

Here is our simple implementation of the log-sum-exp function in PyTorch:
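The original listing isn’t reproduced here; a straightforward version might look like the following sketch (the function name and shapes are my own conventions):

```python
import torch

def log_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Batched log-space matmul: out[n, i, j] = logsumexp_k(a[n, i, k] + b[n, k, j]).

    a: (batch, n, k), b: (batch, k, m) -> out: (batch, n, m)
    """
    # Transpose b so the broadcasted sum reduces over the last dimension.
    return torch.logsumexp(a.unsqueeze(2) + b.transpose(1, 2).unsqueeze(1), dim=-1)
```

This is mathematically equivalent to `torch.log(torch.bmm(a.exp(), b.exp()))`, but without exponentiating the raw values.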

I’m a big fan of using click for command line tools, rather than argparse - I find that it simplifies building hierarchical tools and looks nicer as code. Here’s the benchmarking code:
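The benchmarking harness isn’t reproduced here; a minimal sketch of what it might look like, timing the forward and backward passes and recording peak CUDA memory (names are my own), is:

```python
import time

import click
import torch

def benchmark(fn, batch, n, k, m, trials=10):
    """Time the forward and backward passes of a log_bmm-style function."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(batch, n, k, device=device, requires_grad=True)
    b = torch.randn(batch, k, m, device=device, requires_grad=True)
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(trials):
        out = fn(a, b)
        out.sum().backward()
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / trials
    peak = torch.cuda.max_memory_allocated() if device == "cuda" else 0
    return elapsed, peak

@click.command()
@click.option("--batch", default=8)
@click.option("--size", default=128)
def bench(batch, size):
    def log_bmm(a, b):
        return torch.logsumexp(a.unsqueeze(2) + b.transpose(1, 2).unsqueeze(1), dim=-1)
    elapsed, peak = benchmark(log_bmm, batch, size, size, size)
    click.echo(f"time/iter: {elapsed:.4f}s, peak memory: {peak} bytes")
```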

Lastly, at the end of the file it’s important to add some boilerplate to run the script:
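Assuming the benchmark commands above are registered on a click group named `cli` (the name is my assumption), the boilerplate is just:

```python
import click

@click.group()
def cli():
    """Entry point; benchmark commands are registered on this group."""

if __name__ == "__main__":
    cli()
```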

The naive implementation of this function gives us a useful baseline. Running the benchmark code above for this function gives us the following memory usage stats (note that the y-axis is log-scaled):

Here is the corresponding chart for the runtime:

Simple Speed-up

There is a very simple speed-up that we can apply to the above function by preserving memory contiguity. Note that the reduction part of the forward pass takes place over the last dimension, so we want to make sure the last dimension is contiguous in memory. If we remove the transpose, the forward pass can be performed much faster.
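One way to write the transpose-free version, assuming the caller passes the second matrix already transposed to shape (batch, m, k) (this calling convention is my assumption):

```python
import torch

def log_bmm_fast(a: torch.Tensor, bt: torch.Tensor) -> torch.Tensor:
    """Transpose-free variant: bt must already have shape (batch, m, k).

    The reduction runs over the last (contiguous) dimension, with no
    transpose inside the function; supplying bt in the right layout is
    the caller's responsibility.
    """
    return torch.logsumexp(a.unsqueeze(2) + bt.unsqueeze(1), dim=-1)
```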

The memory usage for this function is identical to that of the naive implementation:

However, the runtime for the forward and backward passes is slightly faster:

This is a useful lesson: where possible, avoid transposes. Note that the above function will return a different result from our canonical implementation, so it is the caller’s responsibility to make sure the inputs are correct.

CUDA Implementation

Let’s write a CUDA implementation of the above function, to see if we can improve the performance.

We can write the log_bmm function as a matrix-matrix operation (the batch part can be added trivially in the CUDA implementation). For a regular batch matrix multiplication function, we expect as our inputs two matrices with elements a_{i, j} and b_{i, j}. We will output a matrix with elements o_{i, j}, which is defined as the following:

o_{i, j} = \sum_k a_{i, k} b_{k, j}

The log-space version of this function is instead:

o_{i, j} = \log \sum_k \exp(a_{i, k} + b_{k, j})

Note that, to make this function numerically stable, we use the logsumexp trick above, rather than naively summing over the exponents.

We can differentiate the above function with respect to each a_{i, k} and b_{k, j}, giving:

\begin{aligned}
\frac{\delta o_{i, j}}{\delta a_{i, k}} = \frac{\delta o_{i, j}}{\delta b_{k, j}} = & \frac{\exp(a_{i, k} + b_{k, j})}{\sum_{k'} \exp(a_{i, k'} + b_{k', j})} \\
= & \frac{\exp(a_{i, k} + b_{k, j})}{\exp(o_{i, j})} \\
= & \exp(a_{i, k} + b_{k, j} - o_{i, j})
\end{aligned}

This means that the gradient of the loss function with respect to a_{i, k} can be written as the accumulation of all of the gradients \frac{\delta L}{\delta o_{i, j}}:

\frac{\delta L}{\delta a_{i, k}} = \sum_j \exp(a_{i, k} + b_{k, j} - o_{i, j}) \frac{\delta L}{\delta o_{i, j}}

Similarly, the gradient with respect to b_{k, j} can be written as:

\frac{\delta L}{\delta b_{k, j}} = \sum_i \exp(a_{i, k} + b_{k, j} - o_{i, j}) \frac{\delta L}{\delta o_{i, j}}

For the CUDA implementation below, I’m using some of the constants defined in this post. Here are the relevant headers and aliases:
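The exact constants aren’t reproduced here; a plausible set of headers and aliases for a PyTorch CUDA extension is:

```cuda
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

// Log-space "zero": a large negative number standing in for -infinity.
#define NEG_INF -1e9

// Threads per block; a multiple of the warp size (32).
const int NUM_THREADS = 1024;

// Integer division rounding up, used when computing the number of blocks.
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))
```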

• In log space, zero is represented as negative infinity. For practical purposes we can just choose a large negative number.
• We first find the maximum value over the reduction dimension, then sum the exponentials of each element minus this maximum. This is mathematically identical to the formulation above (although, as we’ll see below, this can be improved on).
• We can fiddle with the number of threads. For the block size, we use the identity (x + y - 1) / y to do integer division x / y rounding up, to ensure that we are allocating a sufficient number of blocks.

Let’s see the code:
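The kernel itself isn’t reproduced here, so what follows is a sketch of what a two-pass (max, then sum) forward kernel might look like; the names and the thread layout are my own assumptions:

```cuda
template <typename scalar_t>
__global__ void log_bmm_forward_kernel(
    const scalar_t* __restrict__ a,  // (batch, n, k), row-major
    const scalar_t* __restrict__ b,  // (batch, k, m), row-major
    scalar_t* __restrict__ out,      // (batch, n, m)
    int n, int k, int m) {
  const int batch = blockIdx.z;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;  // row of a
  const int j = blockIdx.y * blockDim.y + threadIdx.y;  // column of b
  if (i >= n || j >= m) return;

  const scalar_t* a_row = a + ((long) batch * n + i) * k;
  const scalar_t* b_col = b + (long) batch * k * m + j;

  // Pass 1: find the max over the reduction dimension.
  scalar_t max_val = NEG_INF;
  for (int kk = 0; kk < k; ++kk) {
    scalar_t v = a_row[kk] + b_col[(long) kk * m];
    if (v > max_val) max_val = v;
  }

  // Pass 2: sum the shifted exponentials, then undo the shift in log space.
  scalar_t total = 0;
  for (int kk = 0; kk < k; ++kk)
    total += exp(a_row[kk] + b_col[(long) kk * m] - max_val);
  out[((long) batch * n + i) * m + j] = max_val + log(total);
}
```

Each thread computes one output element, with one block dimension per output axis and the batch index on `blockIdx.z`.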

Great! Let’s see what the backward pass looks like. Note that we have to backpropagate to both input tensors, so we need two kernels running in parallel streams. While there is more code, it largely uses the same general idea as the forward pass kernel.
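Here is a sketch of the kernel that accumulates the gradient with respect to a, using the gradient identity derived above (the kernel for b is symmetric, summing over i instead of j); again, the names are my own:

```cuda
template <typename scalar_t>
__global__ void log_bmm_backward_a_kernel(
    const scalar_t* __restrict__ a,         // (batch, n, k)
    const scalar_t* __restrict__ b,         // (batch, k, m)
    const scalar_t* __restrict__ out,       // (batch, n, m), saved forward output
    const scalar_t* __restrict__ grad_out,  // (batch, n, m)
    scalar_t* __restrict__ grad_a,          // (batch, n, k)
    int n, int k, int m) {
  const int batch = blockIdx.z;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  const int kk = blockIdx.y * blockDim.y + threadIdx.y;
  if (i >= n || kk >= k) return;

  const scalar_t a_val = a[((long) batch * n + i) * k + kk];
  scalar_t grad = 0;
  // dL/da[i,k] = sum_j exp(a[i,k] + b[k,j] - out[i,j]) * dL/dout[i,j]
  for (int j = 0; j < m; ++j) {
    const scalar_t o = out[((long) batch * n + i) * m + j];
    const scalar_t g = grad_out[((long) batch * n + i) * m + j];
    grad += exp(a_val + b[((long) batch * k + kk) * m + j] - o) * g;
  }
  grad_a[((long) batch * n + i) * k + kk] = grad;
}
```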

Lastly, we’ll add some boilerplate for pybind11 to be able to access our functions from the Python side. I think it usually makes sense to have the forward and backward passes in their own submodule, for coherence.
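Assuming the host-side wrappers are called `log_bmm_forward` and `log_bmm_backward` (the names are my assumption), the binding might look like:

```cpp
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Group the forward and backward passes in their own submodule.
  auto log_bmm = m.def_submodule("log_bmm", "log-space batch matrix multiply");
  log_bmm.def("forward", &log_bmm_forward, "log_bmm forward pass (CUDA)");
  log_bmm.def("backward", &log_bmm_backward, "log_bmm backward pass (CUDA)");
}
```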

There is some additional boilerplate that is missing from the above code. See here for a complete tutorial on how to compile this CUDA kernel for your extension. The general idea for incorporating this as a PyTorch function is:
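Here is a sketch of the autograd wrapper, where `log_bmm_cuda` is an assumed name for the compiled extension; a pure-PyTorch fallback is included so the sketch is self-contained:

```python
import torch

try:
    import log_bmm_cuda  # the compiled extension; the module name is an assumption
except ImportError:
    log_bmm_cuda = None

class LogBMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # a: (batch, n, k), b: (batch, k, m) -> out: (batch, n, m)
        if log_bmm_cuda is not None and a.is_cuda:
            out = log_bmm_cuda.log_bmm.forward(a, b)
        else:
            out = torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(1), dim=2)
        ctx.save_for_backward(a, b, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        a, b, out = ctx.saved_tensors
        if log_bmm_cuda is not None and a.is_cuda:
            grad_a, grad_b = log_bmm_cuda.log_bmm.backward(a, b, out, grad_out)
        else:
            # exp(a[i,k] + b[k,j] - out[i,j]), weighted by the incoming gradient.
            w = torch.exp(a.unsqueeze(-1) + b.unsqueeze(1) - out.unsqueeze(2))
            g = grad_out.unsqueeze(2) * w  # (batch, n, k, m)
            grad_a, grad_b = g.sum(dim=3), g.sum(dim=1)
        return grad_a, grad_b
```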

Plugging this function into our performance benchmarking script gives the following chart for memory usage:

Here is the corresponding chart for the runtime:

This is a quite significant improvement in memory usage and runtime! It turns out that we can save a substantial amount of memory by writing a pure CUDA implementation.

Improving the CUDA Implementation

There are a few things we can do to improve on this baseline CUDA implementation.

Reduce Memory Accesses

We can cut the number of memory accesses in the forward function roughly in half by maintaining a running logsumexp in a single pass, instead of first scanning for the global max. We can do that with the following function:
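One way to do this is with a numerically stable log-add-exp accumulator, sketched below; each input element is then read exactly once:

```cuda
template <typename scalar_t>
__device__ __forceinline__ scalar_t log_add_exp(scalar_t x, scalar_t y) {
  // Stable log(exp(x) + exp(y)) without a separate pass to find the max:
  // shift by the larger argument, so the exponent is never positive.
  const scalar_t hi = max(x, y);
  const scalar_t lo = min(x, y);
  return hi + log1p(exp(lo - hi));
}

// Inside the forward kernel, the two loops collapse into one:
//   scalar_t acc = NEG_INF;
//   for (int kk = 0; kk < k; ++kk)
//     acc = log_add_exp(acc, a_row[kk] + b_col[kk * m]);
//   out[...] = acc;
```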

This gives the following graph for memory usage:

Here is the corresponding chart for the runtime:

Tree Reduction

There has been a lot of work on how to optimize reduction operations on GPUs, including a great tutorial by NVIDIA. There are a lot of tricks involved in doing this efficiently (for more info, see that post), but the basic idea is that we want to avoid doing the reduction operations serially, like in the image below.

Instead, when the reduction operation is associative and commutative (which, fortunately, is the case for all semirings, not just the one in question), we can perform them with O(\log(N)) parallel steps, as in the tree below.
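A shared-memory tree reduction for logsumexp, following the pattern from NVIDIA's reduction tutorial, might be sketched like this (one block reduces one contiguous chunk; the per-block results would then be reduced again):

```cuda
__device__ float log_add_expf(float x, float y) {
  // Stable log(exp(x) + exp(y)).
  float hi = fmaxf(x, y), lo = fminf(x, y);
  return hi + log1pf(expf(lo - hi));
}

__global__ void logsumexp_reduce(const float* in, float* out, int n) {
  extern __shared__ float buf[];
  const int tid = threadIdx.x;
  const int idx = blockIdx.x * blockDim.x + tid;

  // Load one element per thread, padding with the log-space zero.
  buf[tid] = (idx < n) ? in[idx] : -1e9f;
  __syncthreads();

  // Tree reduction: halve the number of active threads each step,
  // giving O(log N) parallel steps instead of N serial ones.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride)
      buf[tid] = log_add_expf(buf[tid], buf[tid + stride]);
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = buf[0];
}
```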

These performance boosts don’t really apply to reduce operations which are relatively small (and by “small”, I mean fewer than ~1000 elements).