[SYCL] Automatically determine dot NDRange config
This commit is contained in:
parent
21556af500
commit
cbf97dc7d9
@ -13,9 +13,6 @@ using namespace cl::sycl;
|
|||||||
|
|
||||||
#define WGSIZE 256
|
#define WGSIZE 256
|
||||||
|
|
||||||
#define DOT_WGSIZE 256
|
|
||||||
#define DOT_NUM_GROUPS 256
|
|
||||||
|
|
||||||
// Cache list of devices
|
// Cache list of devices
|
||||||
bool cached = false;
|
bool cached = false;
|
||||||
std::vector<device> devices;
|
std::vector<device> devices;
|
||||||
@ -41,9 +38,22 @@ SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
|||||||
throw std::runtime_error("Invalid device index");
|
throw std::runtime_error("Invalid device index");
|
||||||
device dev = devices[device_index];
|
device dev = devices[device_index];
|
||||||
|
|
||||||
|
// Determine sensible dot kernel NDRange configuration
|
||||||
|
if (dev.is_cpu())
|
||||||
|
{
|
||||||
|
dot_num_groups = dev.get_info<info::device::max_compute_units>();
|
||||||
|
dot_wgsize = dev.get_info<info::device::native_vector_width_double>() * 2;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dot_num_groups = dev.get_info<info::device::max_compute_units>() * 4;
|
||||||
|
dot_wgsize = dev.get_info<info::device::max_work_group_size>();
|
||||||
|
}
|
||||||
|
|
||||||
// Print out device information
|
// Print out device information
|
||||||
std::cout << "Using SYCL device " << getDeviceName(device_index) << std::endl;
|
std::cout << "Using SYCL device " << getDeviceName(device_index) << std::endl;
|
||||||
std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl;
|
std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl;
|
||||||
|
std::cout << "Dot kernel config: " << dot_num_groups << " groups of size " << dot_wgsize << std::endl;
|
||||||
|
|
||||||
queue = new cl::sycl::queue(dev);
|
queue = new cl::sycl::queue(dev);
|
||||||
|
|
||||||
@ -51,7 +61,7 @@ SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
|||||||
d_a = new buffer<T>(array_size);
|
d_a = new buffer<T>(array_size);
|
||||||
d_b = new buffer<T>(array_size);
|
d_b = new buffer<T>(array_size);
|
||||||
d_c = new buffer<T>(array_size);
|
d_c = new buffer<T>(array_size);
|
||||||
d_sum = new buffer<T>(DOT_NUM_GROUPS);
|
d_sum = new buffer<T>(dot_num_groups);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
@ -138,11 +148,11 @@ T SYCLStream<T>::dot()
|
|||||||
auto kb = d_b->template get_access<access::mode::read>(cgh);
|
auto kb = d_b->template get_access<access::mode::read>(cgh);
|
||||||
auto ksum = d_sum->template get_access<access::mode::write>(cgh);
|
auto ksum = d_sum->template get_access<access::mode::write>(cgh);
|
||||||
|
|
||||||
auto wg_sum = accessor<T, 1, access::mode::read_write, access::target::local>(range<1>(DOT_WGSIZE), cgh);
|
auto wg_sum = accessor<T, 1, access::mode::read_write, access::target::local>(range<1>(dot_wgsize), cgh);
|
||||||
|
|
||||||
size_t N = array_size;
|
size_t N = array_size;
|
||||||
|
|
||||||
cgh.parallel_for<class dot>(nd_range<1>(DOT_NUM_GROUPS*DOT_WGSIZE, DOT_WGSIZE), [=](nd_item<1> item)
|
cgh.parallel_for<class dot>(nd_range<1>(dot_num_groups*dot_wgsize, dot_wgsize), [=](nd_item<1> item)
|
||||||
{
|
{
|
||||||
size_t i = item.get_global(0);
|
size_t i = item.get_global(0);
|
||||||
size_t li = item.get_local(0);
|
size_t li = item.get_local(0);
|
||||||
@ -164,7 +174,7 @@ T SYCLStream<T>::dot()
|
|||||||
|
|
||||||
T sum = 0.0;
|
T sum = 0.0;
|
||||||
auto h_sum = d_sum->template get_access<access::mode::read, access::target::host_buffer>();
|
auto h_sum = d_sum->template get_access<access::mode::read, access::target::host_buffer>();
|
||||||
for (int i = 0; i < DOT_NUM_GROUPS; i++)
|
for (int i = 0; i < dot_num_groups; i++)
|
||||||
{
|
{
|
||||||
sum += h_sum[i];
|
sum += h_sum[i];
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,6 +29,10 @@ class SYCLStream : public Stream<T>
|
|||||||
cl::sycl::buffer<T> *d_c;
|
cl::sycl::buffer<T> *d_c;
|
||||||
cl::sycl::buffer<T> *d_sum;
|
cl::sycl::buffer<T> *d_sum;
|
||||||
|
|
||||||
|
// NDRange configuration for the dot kernel
|
||||||
|
size_t dot_num_groups;
|
||||||
|
size_t dot_wgsize;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
SYCLStream(const unsigned int, const int);
|
SYCLStream(const unsigned int, const int);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user