Slide 35 of 59
kayvonf

Question: Can someone describe how this implementation of scan works?

russt17

Okay I think I've got it:

Step 1: The scan is done in chunks of 32 values (one warp) at ptr, and each thread gets back val, which is the cumulative scan value at that thread's index after the 32-element warp scan. At this point val and ptr[idx] store the same thing.

Step 2: Since each thread holds its result in the local variable val, we can now reuse ptr as scratch space. Here we store the final scan value of each warp (available in both val and ptr[idx] for threads with lane == 31). If there are n warps, the first n entries of ptr now hold the totals of each warp's worth of values. Note that this code assumes #warps < 32.

Step 3: Scan on the warp sums to accumulate them.

Step 4: Remember that for scans, the operation OP is an associative binary operator, so partial results can be combined in any grouping. Here, each thread with warp_id > 0 applies OP to the accumulated total of all previous warps (which we have from step 3) and the value it already holds (from step 1). Threads with warp_id == 0 have no previous warps, so they don't have to add any other warp sums.

Finally, each thread has the correct accumulation in val, and it is stored into ptr[idx].
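If it helps, here is a rough sequential C++ sketch of the four steps above. It is not the slide's CUDA code: it assumes OP is integer addition and an inclusive scan, and the names scan_warp_sim, scan_block_sim, and the val array (one entry per simulated thread) are made up for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t WARP = 32;  // warp width, as in the steps above

// Hypothetical stand-in for the per-warp scan: an in-place inclusive
// prefix sum over `count` elements of ptr starting at `base`.
void scan_warp_sim(std::vector<int>& ptr, std::size_t base, std::size_t count) {
    for (std::size_t i = 1; i < count; ++i) {
        ptr[base + i] += ptr[base + i - 1];
    }
}

// Sequential simulation of the block-level scan (assumes OP is +).
void scan_block_sim(std::vector<int>& ptr) {
    const std::size_t n = ptr.size();
    assert(n % WARP == 0);    // whole warps only, to keep the sketch short
    const std::size_t warps = n / WARP;
    assert(warps < WARP);     // the "#warps < 32" assumption from step 2

    // Step 1: scan each 32-element chunk; val[idx] plays the role of each
    // thread's local variable and matches ptr[idx] at this point.
    for (std::size_t w = 0; w < warps; ++w) {
        scan_warp_sim(ptr, w * WARP, WARP);
    }
    std::vector<int> val = ptr;

    // Step 2: the thread with lane == 31 in each warp stores that warp's
    // total into ptr[warp_id], reusing ptr as scratch space.
    for (std::size_t w = 0; w < warps; ++w) {
        ptr[w] = val[w * WARP + (WARP - 1)];
    }

    // Step 3: scan the warp totals themselves.
    scan_warp_sim(ptr, 0, warps);

    // Step 4: threads with warp_id > 0 add the total of all previous warps.
    for (std::size_t idx = WARP; idx < n; ++idx) {
        val[idx] += ptr[idx / WARP - 1];
    }

    // Finally, every thread writes its result back to ptr[idx].
    ptr = val;
}
```

Calling scan_block_sim on a vector of 128 ones should leave the values 1 through 128 in it. On the GPU the loops over idx run as parallel threads with barriers between the steps, which this sketch does not model.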

jcarchi

I don't quite get step 2 in the explanation. Could you explain more about how the scratch space works? Thanks!

russt17

@jcarchi

What I meant by calling the space allocated to ptr scratch space was that the space serves several purposes during the function call and is used to hold intermediate values that the caller never sees. At the start of the call, the space at ptr holds the input array, and when the call returns, the space at ptr holds the output array, but in steps 2 and 3 the space at ptr is repurposed to hold intermediate values.
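To make the reuse concrete, here is a small sequential C++ sketch that records what the space at ptr holds after each stage. It is only an illustration: it assumes OP is integer addition, takes the warp width as a parameter so a toy width of 4 can be used instead of 32, and the names Snapshot and scratch_snapshots are made up.

```cpp
#include <cstddef>
#include <vector>

using Snapshot = std::vector<int>;

// Returns the contents of ptr after step 1, step 2, step 3, and the final
// store. Assumes ptr.size() is a multiple of `warp` and #warps < warp.
std::vector<Snapshot> scratch_snapshots(Snapshot ptr, std::size_t warp) {
    const std::size_t n = ptr.size();
    const std::size_t warps = n / warp;
    std::vector<Snapshot> snaps;

    // Step 1: inclusive scan inside each warp-sized chunk.
    for (std::size_t w = 0; w < warps; ++w) {
        for (std::size_t lane = 1; lane < warp; ++lane) {
            ptr[w * warp + lane] += ptr[w * warp + lane - 1];
        }
    }
    Snapshot val = ptr;  // each simulated thread's local copy of its result
    snaps.push_back(ptr);

    // Step 2: the front of ptr is overwritten with one total per warp.
    for (std::size_t w = 0; w < warps; ++w) {
        ptr[w] = val[w * warp + (warp - 1)];
    }
    snaps.push_back(ptr);

    // Step 3: those totals are scanned in place.
    for (std::size_t w = 1; w < warps; ++w) {
        ptr[w] += ptr[w - 1];
    }
    snaps.push_back(ptr);

    // Step 4 and final store: combine with val, then write the output
    // over the scratch values.
    for (std::size_t idx = warp; idx < n; ++idx) {
        val[idx] += ptr[idx / warp - 1];
    }
    ptr = val;
    snaps.push_back(ptr);
    return snaps;
}
```

For the input 1 2 3 4 5 6 7 8 with a toy warp width of 4, the four snapshots are 1 3 6 10 5 11 18 26, then 10 26 6 10 5 11 18 26, then 10 36 6 10 5 11 18 26, and finally 1 3 6 10 15 21 28 36. The first two entries in the middle two snapshots are the intermediate values: they are neither input nor output, and they are gone by the time the call returns.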