Set CUDA dot kernel to use number of blocks relative to device property

This aligns with the approach implemented in other models (SYCL 1.2.1 and HIP)

Cherry-picks the CUDA updates from lmeadows in #122
This commit is contained in:
Tom Deakin 2023-10-06 17:56:42 +01:00
parent 92fed7082b
commit 9954b7d38c
2 changed files with 18 additions and 11 deletions

View File

@ -51,16 +51,22 @@ CUDAStream<T>::CUDAStream(const int ARRAY_SIZE, const int device_index)
#endif #endif
array_size = ARRAY_SIZE; array_size = ARRAY_SIZE;
// Query device for sensible dot kernel block count
cudaDeviceProp props;
cudaGetDeviceProperties(&props, device_index);
check_error();
dot_num_blocks = props.multiProcessorCount * 4;
// Allocate the host array for partial sums for dot kernels // Allocate the host array for partial sums for dot kernels
sums = (T*)malloc(sizeof(T) * DOT_NUM_BLOCKS); sums = (T*)malloc(sizeof(T) * dot_num_blocks);
size_t array_bytes = sizeof(T); size_t array_bytes = sizeof(T);
array_bytes *= ARRAY_SIZE; array_bytes *= ARRAY_SIZE;
size_t total_bytes = array_bytes * 3; size_t total_bytes = array_bytes * 4;
std::cout << "Reduction kernel config: " << dot_num_blocks << " groups of (fixed) size " << TBSIZE << std::endl;
// Check buffers fit on the device // Check buffers fit on the device
cudaDeviceProp props;
cudaGetDeviceProperties(&props, 0);
if (props.totalGlobalMem < total_bytes) if (props.totalGlobalMem < total_bytes)
throw std::runtime_error("Device does not have enough memory for all 3 buffers"); throw std::runtime_error("Device does not have enough memory for all 3 buffers");
@ -72,13 +78,13 @@ CUDAStream<T>::CUDAStream(const int ARRAY_SIZE, const int device_index)
check_error(); check_error();
cudaMallocManaged(&d_c, array_bytes); cudaMallocManaged(&d_c, array_bytes);
check_error(); check_error();
cudaMallocManaged(&d_sum, DOT_NUM_BLOCKS*sizeof(T)); cudaMallocManaged(&d_sum, dot_num_blocks*sizeof(T));
check_error(); check_error();
#elif defined(PAGEFAULT) #elif defined(PAGEFAULT)
d_a = (T*)malloc(array_bytes); d_a = (T*)malloc(array_bytes);
d_b = (T*)malloc(array_bytes); d_b = (T*)malloc(array_bytes);
d_c = (T*)malloc(array_bytes); d_c = (T*)malloc(array_bytes);
d_sum = (T*)malloc(sizeof(T)*DOT_NUM_BLOCKS); d_sum = (T*)malloc(sizeof(T)*dot_num_blocks);
#else #else
cudaMalloc(&d_a, array_bytes); cudaMalloc(&d_a, array_bytes);
check_error(); check_error();
@ -86,7 +92,7 @@ CUDAStream<T>::CUDAStream(const int ARRAY_SIZE, const int device_index)
check_error(); check_error();
cudaMalloc(&d_c, array_bytes); cudaMalloc(&d_c, array_bytes);
check_error(); check_error();
cudaMalloc(&d_sum, DOT_NUM_BLOCKS*sizeof(T)); cudaMalloc(&d_sum, dot_num_blocks*sizeof(T));
check_error(); check_error();
#endif #endif
} }
@ -267,19 +273,19 @@ __global__ void dot_kernel(const T * a, const T * b, T * sum, int array_size)
template <class T> template <class T>
T CUDAStream<T>::dot() T CUDAStream<T>::dot()
{ {
dot_kernel<<<DOT_NUM_BLOCKS, TBSIZE>>>(d_a, d_b, d_sum, array_size); dot_kernel<<<dot_num_blocks, TBSIZE>>>(d_a, d_b, d_sum, array_size);
check_error(); check_error();
#if defined(MANAGED) || defined(PAGEFAULT) #if defined(MANAGED) || defined(PAGEFAULT)
cudaDeviceSynchronize(); cudaDeviceSynchronize();
check_error(); check_error();
#else #else
cudaMemcpy(sums, d_sum, DOT_NUM_BLOCKS*sizeof(T), cudaMemcpyDeviceToHost); cudaMemcpy(sums, d_sum, dot_num_blocks*sizeof(T), cudaMemcpyDeviceToHost);
check_error(); check_error();
#endif #endif
T sum = 0.0; T sum = 0.0;
for (int i = 0; i < DOT_NUM_BLOCKS; i++) for (int i = 0; i < dot_num_blocks; i++)
{ {
#if defined(MANAGED) || defined(PAGEFAULT) #if defined(MANAGED) || defined(PAGEFAULT)
sum += d_sum[i]; sum += d_sum[i];

View File

@ -16,7 +16,6 @@
#define IMPLEMENTATION_STRING "CUDA" #define IMPLEMENTATION_STRING "CUDA"
#define TBSIZE 1024 #define TBSIZE 1024
#define DOT_NUM_BLOCKS 1024
template <class T> template <class T>
class CUDAStream : public Stream<T> class CUDAStream : public Stream<T>
@ -34,6 +33,8 @@ class CUDAStream : public Stream<T>
T *d_c; T *d_c;
T *d_sum; T *d_sum;
// Number of blocks for dot kernel
int dot_num_blocks;
public: public: