Torch CUDA Extension Tricks


Some tricks I found useful for writing CUDA extensions for PyTorch.

May 06, 2020

This is a tracking document for some things I’ve found useful when writing CUDA extensions for PyTorch.

Python

I found it useful to put these at the top of my Python file. manual_seed makes random initialization reproducible across runs. set_printoptions prints six decimal places without scientific notation, which makes it easier to tell at a glance whether two numbers match.

import torch
seed = 0  # any fixed integer works
torch.manual_seed(seed)
torch.set_printoptions(precision=6, sci_mode=False)

CUDA Debugging

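This answer suggests that the first step in debugging CUDA code is to enable CUDA launch blocking. This makes every kernel launch synchronous, so an error is reported at the launch that caused it rather than at some later call. Put this at the very top of the Python file, before any CUDA work happens: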

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

However, this didn’t help with a weird memory access issue I was having. This guide was more helpful. Below is my minor improvement on the helper from that guide:

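#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <cuda_runtime.h>

// Wraps a CUDA runtime call and exits with file/line info if it fails.
// No semicolon after while (0), so CUDA_CHECK(...); is a single statement.
#define CUDA_CHECK(X)                                                          \
  do {                                                                         \
    cudaError_t err = (X);                                                     \
    if (err != cudaSuccess) {                                                  \
      std::cerr << "CUDA error in " << __FILE__ << "(" << __LINE__             \
                << "): " << cudaGetErrorString(err) << std::endl;              \
      std::exit(EXIT_FAILURE);                                                 \
    }                                                                          \
  } while (0)

// Round-trips a small buffer through the GPU. Place a call after a suspect
// kernel: an earlier asynchronous error (such as an illegal memory access)
// will typically surface in one of these calls. Remove it when done.
#define cudaMemoryTestREMOVE_WHEN_DONE()                                       \
  do {                                                                         \
    const int N = 1337;                                                        \
    const size_t bytes = N * sizeof(float);                                    \
    std::vector<float> cpuvec(N);                                              \
    for (int i = 0; i < N; i++)                                                \
      cpuvec[i] = (float)i;                                                    \
    float *gpuvec = NULL;                                                      \
    CUDA_CHECK(cudaMalloc(&gpuvec, bytes));                                    \
    assert(gpuvec != NULL);                                                    \
    CUDA_CHECK(                                                                \
        cudaMemcpy(gpuvec, cpuvec.data(), bytes, cudaMemcpyHostToDevice));     \
    CUDA_CHECK(                                                                \
        cudaMemcpy(cpuvec.data(), gpuvec, bytes, cudaMemcpyDeviceToHost));     \
    CUDA_CHECK(cudaFree(gpuvec));                                              \
  } while (0)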
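The do { ... } while (0) wrapper exists so that a macro call behaves like a single statement. That only works if the macro itself does not end in a semicolon. Here is a CPU-only sketch of the same pattern. fake_call and CHECK_CODE are hypothetical stand-ins for a CUDA runtime call and CUDA_CHECK, so it compiles without a GPU.

```cpp
#include <cassert>
#include <cstdlib>
#include <iostream>

// Hypothetical stand-in for a CUDA runtime call: returns 0 on success and
// a nonzero error code on failure, playing the role of cudaError_t.
inline int fake_call(int code) { return code; }

// Same shape as CUDA_CHECK. X is evaluated exactly once, and there is no
// semicolon after while (0), so CHECK_CODE(...); forms a single statement.
#define CHECK_CODE(X)                                                          \
  do {                                                                         \
    int err_ = (X);                                                            \
    if (err_ != 0) {                                                           \
      std::cerr << "error " << err_ << " in " << __FILE__ << "(" << __LINE__   \
                << ")" << std::endl;                                           \
      std::exit(EXIT_FAILURE);                                                 \
    }                                                                          \
  } while (0)

// Compiles only because CHECK_CODE(...) is one statement: the unbraced
// if/else below would break if the macro added its own semicolon.
inline int run_checked(bool run) {
  int calls = 0;
  if (run)
    CHECK_CODE(fake_call(calls++)); // passes 0 (success), then calls == 1
  else
    calls = -1;
  return calls;
}
```

If CHECK_CODE ended in while (0); the expansion inside run_checked would be followed by an extra empty statement, and the else would no longer have a matching if.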

Math Tricks

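I found it very important to know how to do integer division that rounds up, for example when computing how many blocks a kernel launch needs. Normally, integer division in C and C++ truncates toward zero, which rounds down for non-negative operands. For example, the code below will print 0, 0, 0, 0, 1, 1, 1, 1, 1, 2.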

int denominator = 5;
for (int numerator = 1; numerator <= 10; numerator++)
  printf("%d\n", numerator / denominator);  // truncating division

To round up, use the formula (numerator + denominator - 1) / denominator, which is valid for a non-negative numerator and a positive denominator. The code below will print 1, 1, 1, 1, 1, 2, 2, 2, 2, 2:

int denominator = 5;
for (int numerator = 1; numerator <= 10; numerator++)
  printf("%d\n", (numerator + denominator - 1) / denominator);  // rounds up
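This mostly shows up in CUDA code when sizing the grid. Here is a small sketch that packages the formula as a helper. ceil_div, ceil_div_safe and num_blocks are hypothetical names chosen for illustration. The helpers assume a non-negative numerator and a positive denominator.

```cpp
#include <cassert>
#include <cstdint>

// Integer division rounding up, assuming numerator >= 0 and denominator > 0.
// numerator + denominator - 1 can overflow when numerator is near the
// maximum value of T.
template <typename T> constexpr T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Same result under the same assumptions, without the overflow risk.
template <typename T> constexpr T ceil_div_safe(T numerator, T denominator) {
  return numerator / denominator + (numerator % denominator != 0 ? 1 : 0);
}

// Typical use: the number of blocks of `threads` threads needed to cover
// n elements, i.e. the grid size for a 1-D kernel launch.
constexpr std::int64_t num_blocks(std::int64_t n, std::int64_t threads) {
  return ceil_div(n, threads);
}

// Compile-time spot checks.
static_assert(ceil_div(10, 5) == 2, "exact multiple stays the same");
static_assert(ceil_div(11, 5) == 3, "any remainder rounds up");
static_assert(num_blocks(1000, 256) == 4, "4 * 256 = 1024 >= 1000");
```

With 256 threads per block, the grid size for n elements would then be num_blocks(n, 256), and the kernel still needs an index guard because the last block can run past n.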

C++ Definitions

These are some definitions which I found useful.

// Short-hand for getting a packed accessor of a particular type.
#define ACCESSOR(x, n, type)                                                   \
  x.packed_accessor32<type, n, torch::RestrictPtrTraits>()

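// Short-hand for the global CUDA thread index along one dimension. The
// argument replaces the member name, so CUDA_IDX(y) expands to
// (blockIdx.y * blockDim.y + threadIdx.y).
#define CUDA_IDX(x) (blockIdx.x * blockDim.x + threadIdx.x)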

// Checks that a number lies in the half-open range [a, b).
#define CHECK_BETWEEN(v, a, b, n)                                              \
  TORCH_CHECK((v) >= (a) && (v) < (b), n, " should be in [", a, ", ", b,       \
              "), got ", v)

// Checks that a tensor has size v along dimension d.
#define CHECK_DIM(x, v, n, d)                                                  \
  TORCH_CHECK(x.size(d) == v, n, " should have size ", v, " in dimension ", d, \
              ", got ", x.size(d))

// Checks that a tensor has the right number of dimensions.
#define CHECK_DIMS(x, d, n)                                                    \
  TORCH_CHECK(x.ndimension() == d, n, " should have ", d, " dimensions, got ", \
              x.ndimension())

// Checks that the tensor is a CUDA tensor.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

// Checks that the tensor is contiguous.
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

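// Combines the CUDA tensor and contiguity checks. The do/while (0) wrapper
// keeps it a single statement, e.g. inside an unbraced if.
#define CHECK_INPUT(x)                                                         \
  do {                                                                         \
    CHECK_CUDA(x);                                                             \
    CHECK_CONTIGUOUS(x);                                                       \
  } while (0)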
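Here is how CUDA_IDX and a rounded-up block count fit together. This CPU-only sketch emulates the CUDA built-ins with local variables. fake_dim3 and simulate_launch are hypothetical names for illustration, not CUDA API. The sketch relies on the preprocessor replacing the macro parameter x inside blockIdx.x and friends, and on the usual if (idx < n) guard for the last, partially filled block.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for CUDA's dim3/uint3 built-ins (CPU testing only).
struct fake_dim3 {
  unsigned x, y, z;
};

// Same macro as in the definitions: the argument replaces the member name,
// so CUDA_IDX(x) expands to (blockIdx.x * blockDim.x + threadIdx.x).
#define CUDA_IDX(x) (blockIdx.x * blockDim.x + threadIdx.x)

// Emulates a 1-D launch of ceil(n / threads) blocks of `threads` threads
// and counts how many times each element index is visited.
// Assumes threads > 0.
inline std::vector<int> simulate_launch(unsigned n, unsigned threads) {
  std::vector<int> hits(n, 0);
  fake_dim3 blockDim = {threads, 1, 1};
  fake_dim3 blockIdx = {0, 0, 0};
  fake_dim3 threadIdx = {0, 0, 0};
  const unsigned blocks = (n + threads - 1) / threads;
  for (blockIdx.x = 0; blockIdx.x < blocks; blockIdx.x++) {
    for (threadIdx.x = 0; threadIdx.x < blockDim.x; threadIdx.x++) {
      const unsigned idx = CUDA_IDX(x);
      if (idx < n) // threads past the end exist only in the last block
        hits[idx]++;
    }
  }
  return hits;
}
```

Every element is visited exactly once whether or not n is a multiple of the block size. Without the guard, the last block would write past the end.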

Additionally, here’s a useful struct that creates N CUDA streams and synchronizes and destroys them when it goes out of scope.

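#include <cassert>

#include <cuda_runtime.h>

// Owns N CUDA streams; they are synchronized and destroyed when the object
// goes out of scope.
template <int N> struct streams_t {
  cudaStream_t streams[N];

  streams_t() {
    for (int i = 0; i < N; i++)
      cudaStreamCreate(&streams[i]);
  }

  // Copying would destroy the same streams twice, so forbid it.
  streams_t(const streams_t &) = delete;
  streams_t &operator=(const streams_t &) = delete;

  ~streams_t() {
    sync();
    destroy();
  }

  void sync() {
    for (int i = 0; i < N; i++)
      cudaStreamSynchronize(streams[i]);
  }

  void sync(int i) {
    assert(i >= 0 && i < N);
    cudaStreamSynchronize(streams[i]);
  }

  cudaStream_t get(int i) const {
    assert(i >= 0 && i < N);
    return streams[i];
  }

private:
  // Only called from the destructor, so each stream is destroyed once.
  void destroy() {
    for (int i = 0; i < N; i++)
      cudaStreamDestroy(streams[i]);
  }
};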

Here’s a handy operator<< overload for printing dim3 objects, such as grid and block sizes.

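#include <iostream>

#include <cuda_runtime.h>  // dim3

// Prints a dim3 as "(x, y, z)". Marked inline so it can live in a header
// that is included from several files.
inline std::ostream &operator<<(std::ostream &os, const dim3 &d) {
  return os << "(" << d.x << ", " << d.y << ", " << d.z << ")";
}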
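CUDA's other three-component vector types, such as int3 and float3, also have x, y and z members. If you want to print those too, one option is a constrained template. This is a sketch. xyz_stub is a hypothetical stand-in for dim3 so the example compiles without the CUDA headers.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>

// Prints any object with x, y and z members as "(x, y, z)". The trailing
// return type drops this overload (SFINAE) for types lacking those members;
// the (void) casts keep an overloaded comma operator out of the picture.
template <typename T>
auto operator<<(std::ostream &os, const T &v)
    -> decltype((void)v.x, (void)v.y, (void)v.z, os) {
  return os << "(" << v.x << ", " << v.y << ", " << v.z << ")";
}

// Hypothetical stand-in for dim3 so this compiles without the CUDA headers.
struct xyz_stub {
  unsigned x, y, z;
};
```

For a dim3 argument, the non-template dim3 overload is still preferred over this template during overload resolution, so the two can coexist.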

Additional Resources

Below are some of the resources that I found useful.