CUDA Pro Tip: Write Flexible Kernels with Grid-Stride Loops

One of the most common tasks in CUDA programming is to parallelize a loop using a kernel. As an example, let’s use our old friend SAXPY. Here’s the basic sequential implementation, which uses a for loop. To efficiently parallelize this, we need to launch enough threads to fully utilize the GPU.

void saxpy(int n, float a, float *x, float *y)
{
    for (int i = 0; i < n; ++i)
        y[i] = a * x[i] + y[i];
}

Common CUDA guidance is to launch one thread per data element, which means to parallelize the above SAXPY loop we write a kernel that assumes we have enough threads to more than cover the array size.

__global__
void saxpy(int n, float a, float *x, float *y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) 
        y[i] = a * x[i] + y[i];
}

I’ll refer to this style of kernel as a monolithic kernel, because it assumes a single large grid of threads to process the entire array in one pass. You might use the following code to launch the saxpy kernel to process one million elements.

// Perform SAXPY on 1M elements
saxpy<<<4096,256>>>(1<<20, 2.0, x, y);
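
The 4096 blocks in this launch come from dividing the 1<<20 (1,048,576) elements by the 256-thread block size. More generally, a common way to size a monolithic launch at run time is to round that division up, as in the following sketch (it assumes x and y are already device pointers, and the names blockSize and numBlocks are just placeholders).

int n = 1 << 20;
int blockSize = 256;
// Round up so that numBlocks * blockSize >= n; the kernel's if (i < n) check
// masks off any extra threads in the last block.
int numBlocks = (n + blockSize - 1) / blockSize;
saxpy<<<numBlocks, blockSize>>>(n, 2.0f, x, y);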

Instead of completely eliminating the loop when parallelizing the computation, I recommend using a grid-stride loop, as in the following kernel.

__global__
void saxpy(int n, float a, float *x, float *y)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; 
         i < n; 
         i += blockDim.x * gridDim.x) 
    {
        y[i] = a * x[i] + y[i];
    }
}

Rather than assume that the thread grid is large enough to cover the entire data array, this kernel loops over the data array one grid-size at a time.

Notice that the stride of the loop is blockDim.x * gridDim.x which is the total number of threads in the grid. So if there are 1280 threads in the grid, thread 0 will compute elements 0, 1280, 2560, etc. This is why I call this a grid-stride loop. By using a loop with stride equal to the grid size, we ensure that all addressing within warps is unit-stride, so we get maximum memory coalescing, just as in the monolithic version.
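
If you want to see this addressing pattern for yourself, a throwaway kernel along the following lines prints which elements each thread touches (a sketch: the kernel name, grid dimensions, and n are arbitrary, and device-side printf requires compute capability 2.0 or higher).

#include <stdio.h>

// Toy kernel: report which elements each thread of the grid-stride loop handles.
__global__ void show_stride(int n)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n;
         i += blockDim.x * gridDim.x)
        printf("block %d, thread %d -> element %d\n",
               blockIdx.x, threadIdx.x, i);
}

// Launching 5 blocks of 256 threads gives a 1280-thread grid, so with n = 3000,
// thread 0 of block 0 reports elements 0, 1280, and 2560:
// show_stride<<<5, 256>>>(3000);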

When launched with a grid large enough to cover all iterations of the loop, the grid-stride loop should have essentially the same instruction cost as the if statement in the monolithic kernel, because the loop increment will only be evaluated when the loop condition evaluates to true.

There are several benefits to using a grid-stride loop.

  1. Scalability and thread reuse. By using a loop, you can support any problem size even if it exceeds the largest grid size your CUDA device supports. Moreover, you can limit the number of blocks you use to tune performance. For example, it’s often useful to launch a number of blocks that is a multiple of the number of multiprocessors on the device, to balance utilization. As an example, we might launch the loop version of the kernel like this.
    int devId, numSMs;
    cudaGetDevice(&devId);
    cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
    // Perform SAXPY on 1M elements
    saxpy<<<32*numSMs, 256>>>(1 << 20, 2.0, x, y);

    When you limit the number of blocks in your grid, threads are reused for multiple computations. Thread reuse amortizes thread creation and destruction cost along with any other processing the kernel might do before or after the loop (such as thread-private or shared data initialization).

  2. Debugging. By using a loop instead of a monolithic kernel, you can easily switch to serial processing by launching one block with one thread.
    saxpy<<<1,1>>>(1<<20, 2.0, x, y);

    This makes it easier to emulate a serial host implementation to validate results, and it can make printf debugging easier by serializing the print order. Serializing the computation also lets you eliminate numerical variations caused by changes in the order of operations from run to run, helping you verify that your numerics are correct before tuning the parallel version. (A self-contained host-side sketch that exercises this appears after this list.)

  3. Portability and readability. The grid-stride loop code is more similar to the original sequential loop code than the monolithic kernel code is, making it clearer for other users. In fact, we can pretty easily write a version of the kernel that compiles and runs either as a parallel CUDA kernel on the GPU or as a sequential loop on the CPU. The Hemi library provides a `grid_stride_range()` helper that makes this trivial using C++11 range-based for loops.
    HEMI_LAUNCHABLE
    void saxpy(int n, float a, float *x, float *y)
    {
      for (auto i : hemi::grid_stride_range(0, n)) {
        y[i] = a * x[i] + y[i];
      }
    }

    We can launch the kernel using this code, which generates a kernel launch when compiled for CUDA, or a function call when compiled for the CPU.

    hemi::cudaLaunch(saxpy, 1<<20, 2.0, x, y);
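
Putting points 1 and 2 together, here is a minimal host-side sketch that drives the grid-stride saxpy kernel defined earlier and validates the result. It is illustrative only: unified memory (cudaMallocManaged) is used purely to keep it short, and the initial values are arbitrary.

#include <math.h>
#include <stdio.h>

// Assumes the grid-stride saxpy kernel defined above.
int main()
{
    const int n = 1 << 20;

    // Unified memory keeps the example short; explicit cudaMalloc/cudaMemcpy works just as well.
    float *x, *y;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&y, n * sizeof(float));
    for (int i = 0; i < n; ++i) { x[i] = 1.0f; y[i] = 2.0f; }

    // Launch a multiple of the SM count (point 1). To serialize for debugging,
    // replace the configuration with <<<1, 1>>> (point 2).
    int devId, numSMs;
    cudaGetDevice(&devId);
    cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
    saxpy<<<32 * numSMs, 256>>>(n, 2.0f, x, y);
    cudaDeviceSynchronize();

    // Every element should now equal 2.0f * 1.0f + 2.0f = 4.0f.
    float maxErr = 0.0f;
    for (int i = 0; i < n; ++i) {
        float err = fabsf(y[i] - 4.0f);
        if (err > maxErr) maxErr = err;
    }
    printf("max error: %f\n", maxErr);

    cudaFree(x);
    cudaFree(y);
    return 0;
}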

Grid-stride loops are a great way to make your CUDA kernels flexible, scalable, debuggable, and even portable. While the examples in this post have all used CUDA C/C++, the same concepts apply in other CUDA languages such as CUDA Fortran.

I’d like to thank Justin Luitjens from the NVIDIA Developer Technology group for the idea and many of the details in this CUDA Pro Tip.
