[OCL] Automatically determine dot NDRange config
This commit is contained in:
parent
bd0ba4dee9
commit
21556af500
@ -90,9 +90,22 @@ OCLStream<T>::OCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
|||||||
throw std::runtime_error("Invalid device index");
|
throw std::runtime_error("Invalid device index");
|
||||||
device = devices[device_index];
|
device = devices[device_index];
|
||||||
|
|
||||||
|
// Determine sensible dot kernel NDRange configuration
|
||||||
|
if (device.getInfo<CL_DEVICE_TYPE>() & CL_DEVICE_TYPE_CPU)
|
||||||
|
{
|
||||||
|
dot_num_groups = device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
|
||||||
|
dot_wgsize = device.getInfo<CL_DEVICE_NATIVE_VECTOR_WIDTH_DOUBLE>() * 2;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dot_num_groups = device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>() * 4;
|
||||||
|
dot_wgsize = device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
|
||||||
|
}
|
||||||
|
|
||||||
// Print out device information
|
// Print out device information
|
||||||
std::cout << "Using OpenCL device " << getDeviceName(device_index) << std::endl;
|
std::cout << "Using OpenCL device " << getDeviceName(device_index) << std::endl;
|
||||||
std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl;
|
std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl;
|
||||||
|
std::cout << "Dot kernel config: " << dot_num_groups << " groups of size " << dot_wgsize << std::endl;
|
||||||
|
|
||||||
context = cl::Context(device);
|
context = cl::Context(device);
|
||||||
queue = cl::CommandQueue(context);
|
queue = cl::CommandQueue(context);
|
||||||
@ -147,9 +160,9 @@ OCLStream<T>::OCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
|||||||
d_a = cl::Buffer(context, CL_MEM_READ_WRITE, sizeof(T) * ARRAY_SIZE);
|
d_a = cl::Buffer(context, CL_MEM_READ_WRITE, sizeof(T) * ARRAY_SIZE);
|
||||||
d_b = cl::Buffer(context, CL_MEM_READ_WRITE, sizeof(T) * ARRAY_SIZE);
|
d_b = cl::Buffer(context, CL_MEM_READ_WRITE, sizeof(T) * ARRAY_SIZE);
|
||||||
d_c = cl::Buffer(context, CL_MEM_READ_WRITE, sizeof(T) * ARRAY_SIZE);
|
d_c = cl::Buffer(context, CL_MEM_READ_WRITE, sizeof(T) * ARRAY_SIZE);
|
||||||
d_sum = cl::Buffer(context, CL_MEM_WRITE_ONLY, sizeof(T) * DOT_NUM_GROUPS);
|
d_sum = cl::Buffer(context, CL_MEM_WRITE_ONLY, sizeof(T) * dot_num_groups);
|
||||||
|
|
||||||
sums = std::vector<T>(DOT_NUM_GROUPS);
|
sums = std::vector<T>(dot_num_groups);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
@ -205,8 +218,8 @@ template <class T>
|
|||||||
T OCLStream<T>::dot()
|
T OCLStream<T>::dot()
|
||||||
{
|
{
|
||||||
(*dot_kernel)(
|
(*dot_kernel)(
|
||||||
cl::EnqueueArgs(queue, cl::NDRange(DOT_NUM_GROUPS*DOT_WGSIZE), cl::NDRange(DOT_WGSIZE)),
|
cl::EnqueueArgs(queue, cl::NDRange(dot_num_groups*dot_wgsize), cl::NDRange(dot_wgsize)),
|
||||||
d_a, d_b, d_sum, cl::Local(sizeof(T) * DOT_WGSIZE), array_size
|
d_a, d_b, d_sum, cl::Local(sizeof(T) * dot_wgsize), array_size
|
||||||
);
|
);
|
||||||
cl::copy(queue, d_sum, sums.begin(), sums.end());
|
cl::copy(queue, d_sum, sums.begin(), sums.end());
|
||||||
|
|
||||||
|
|||||||
@ -21,10 +21,6 @@
|
|||||||
|
|
||||||
#define IMPLEMENTATION_STRING "OpenCL"
|
#define IMPLEMENTATION_STRING "OpenCL"
|
||||||
|
|
||||||
// NDRange configuration for the dot kernel
|
|
||||||
#define DOT_WGSIZE 256
|
|
||||||
#define DOT_NUM_GROUPS 256
|
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
class OCLStream : public Stream<T>
|
class OCLStream : public Stream<T>
|
||||||
{
|
{
|
||||||
@ -52,6 +48,10 @@ class OCLStream : public Stream<T>
|
|||||||
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer> *triad_kernel;
|
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer> *triad_kernel;
|
||||||
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer, cl::LocalSpaceArg, cl_int> *dot_kernel;
|
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer, cl::LocalSpaceArg, cl_int> *dot_kernel;
|
||||||
|
|
||||||
|
// NDRange configuration for the dot kernel
|
||||||
|
size_t dot_num_groups;
|
||||||
|
size_t dot_wgsize;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
OCLStream(const unsigned int, const int);
|
OCLStream(const unsigned int, const int);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user