[SYCL] Fix multiple template specializations
This commit is contained in:
parent
66776d5839
commit
1e976ff150
@ -18,16 +18,6 @@ std::vector<device> devices;
|
|||||||
void getDeviceList(void);
|
void getDeviceList(void);
|
||||||
program * p;
|
program * p;
|
||||||
|
|
||||||
/* Forward declaration of SYCL kernels */
|
|
||||||
namespace kernels {
|
|
||||||
class init;
|
|
||||||
class copy;
|
|
||||||
class mul;
|
|
||||||
class add;
|
|
||||||
class triad;
|
|
||||||
class dot;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
||||||
{
|
{
|
||||||
@ -61,12 +51,12 @@ SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
|
|||||||
|
|
||||||
/* Pre-build the kernels */
|
/* Pre-build the kernels */
|
||||||
p = new program(queue->get_context());
|
p = new program(queue->get_context());
|
||||||
p->build_from_kernel_name<kernels::init>();
|
p->build_from_kernel_name<init_kernel>();
|
||||||
p->build_from_kernel_name<kernels::copy>();
|
p->build_from_kernel_name<copy_kernel>();
|
||||||
p->build_from_kernel_name<kernels::mul>();
|
p->build_from_kernel_name<mul_kernel>();
|
||||||
p->build_from_kernel_name<kernels::add>();
|
p->build_from_kernel_name<add_kernel>();
|
||||||
p->build_from_kernel_name<kernels::triad>();
|
p->build_from_kernel_name<triad_kernel>();
|
||||||
p->build_from_kernel_name<kernels::dot>();
|
p->build_from_kernel_name<dot_kernel>();
|
||||||
|
|
||||||
// Create buffers
|
// Create buffers
|
||||||
d_a = new buffer<T>(array_size);
|
d_a = new buffer<T>(array_size);
|
||||||
@ -94,7 +84,7 @@ void SYCLStream<T>::copy()
|
|||||||
{
|
{
|
||||||
auto ka = d_a->template get_access<access::mode::read>(cgh);
|
auto ka = d_a->template get_access<access::mode::read>(cgh);
|
||||||
auto kc = d_c->template get_access<access::mode::write>(cgh);
|
auto kc = d_c->template get_access<access::mode::write>(cgh);
|
||||||
cgh.parallel_for<kernels::copy>(p->get_kernel<kernels::copy>(),
|
cgh.parallel_for<copy_kernel>(p->get_kernel<copy_kernel>(),
|
||||||
range<1>{array_size}, [=](item<1> item)
|
range<1>{array_size}, [=](item<1> item)
|
||||||
{
|
{
|
||||||
auto id = item.get();
|
auto id = item.get();
|
||||||
@ -112,7 +102,7 @@ void SYCLStream<T>::mul()
|
|||||||
{
|
{
|
||||||
auto kb = d_b->template get_access<access::mode::write>(cgh);
|
auto kb = d_b->template get_access<access::mode::write>(cgh);
|
||||||
auto kc = d_c->template get_access<access::mode::read>(cgh);
|
auto kc = d_c->template get_access<access::mode::read>(cgh);
|
||||||
cgh.parallel_for<kernels::mul>(p->get_kernel<kernels::mul>(),
|
cgh.parallel_for<mul_kernel>(p->get_kernel<mul_kernel>(),
|
||||||
range<1>{array_size}, [=](item<1> item)
|
range<1>{array_size}, [=](item<1> item)
|
||||||
{
|
{
|
||||||
auto id = item.get();
|
auto id = item.get();
|
||||||
@ -130,7 +120,7 @@ void SYCLStream<T>::add()
|
|||||||
auto ka = d_a->template get_access<access::mode::read>(cgh);
|
auto ka = d_a->template get_access<access::mode::read>(cgh);
|
||||||
auto kb = d_b->template get_access<access::mode::read>(cgh);
|
auto kb = d_b->template get_access<access::mode::read>(cgh);
|
||||||
auto kc = d_c->template get_access<access::mode::write>(cgh);
|
auto kc = d_c->template get_access<access::mode::write>(cgh);
|
||||||
cgh.parallel_for<kernels::add>(p->get_kernel<kernels::add>(),
|
cgh.parallel_for<add_kernel>(p->get_kernel<add_kernel>(),
|
||||||
range<1>{array_size}, [=](item<1> item)
|
range<1>{array_size}, [=](item<1> item)
|
||||||
{
|
{
|
||||||
auto id = item.get();
|
auto id = item.get();
|
||||||
@ -149,7 +139,7 @@ void SYCLStream<T>::triad()
|
|||||||
auto ka = d_a->template get_access<access::mode::write>(cgh);
|
auto ka = d_a->template get_access<access::mode::write>(cgh);
|
||||||
auto kb = d_b->template get_access<access::mode::read>(cgh);
|
auto kb = d_b->template get_access<access::mode::read>(cgh);
|
||||||
auto kc = d_c->template get_access<access::mode::read>(cgh);
|
auto kc = d_c->template get_access<access::mode::read>(cgh);
|
||||||
cgh.parallel_for<kernels::triad>(p->get_kernel<kernels::triad>(),
|
cgh.parallel_for<triad_kernel>(p->get_kernel<triad_kernel>(),
|
||||||
range<1>{array_size}, [=](item<1> item)
|
range<1>{array_size}, [=](item<1> item)
|
||||||
{
|
{
|
||||||
auto id = item.get();
|
auto id = item.get();
|
||||||
@ -172,7 +162,8 @@ T SYCLStream<T>::dot()
|
|||||||
|
|
||||||
size_t N = array_size;
|
size_t N = array_size;
|
||||||
|
|
||||||
cgh.parallel_for<kernels::dot>(nd_range<1>(dot_num_groups*dot_wgsize, dot_wgsize), [=](nd_item<1> item)
|
cgh.parallel_for<dot_kernel>(p->get_kernel<dot_kernel>(),
|
||||||
|
nd_range<1>(dot_num_groups*dot_wgsize, dot_wgsize), [=](nd_item<1> item)
|
||||||
{
|
{
|
||||||
size_t i = item.get_global(0);
|
size_t i = item.get_global(0);
|
||||||
size_t li = item.get_local(0);
|
size_t li = item.get_local(0);
|
||||||
@ -210,8 +201,8 @@ void SYCLStream<T>::init_arrays(T initA, T initB, T initC)
|
|||||||
auto ka = d_a->template get_access<access::mode::write>(cgh);
|
auto ka = d_a->template get_access<access::mode::write>(cgh);
|
||||||
auto kb = d_b->template get_access<access::mode::write>(cgh);
|
auto kb = d_b->template get_access<access::mode::write>(cgh);
|
||||||
auto kc = d_c->template get_access<access::mode::write>(cgh);
|
auto kc = d_c->template get_access<access::mode::write>(cgh);
|
||||||
cgh.parallel_for<kernels::init>(p->get_kernel<kernels::init>(),
|
cgh.parallel_for<init_kernel>(p->get_kernel<init_kernel>(),
|
||||||
range<1>{array_size}, [=](item<1> item)
|
range<1>{array_size}, [=](item<1> item)
|
||||||
{
|
{
|
||||||
auto id = item.get();
|
auto id = item.get();
|
||||||
ka[id[0]] = initA;
|
ka[id[0]] = initA;
|
||||||
@ -311,5 +302,5 @@ std::string getDeviceDriver(const int device)
|
|||||||
|
|
||||||
|
|
||||||
// TODO: Fix kernel names to allow multiple template specializations
|
// Explicit instantiations; kernel names are templated on T, so both specializations can coexist
|
||||||
//template class SYCLStream<float>;
|
template class SYCLStream<float>;
|
||||||
template class SYCLStream<double>;
|
template class SYCLStream<double>;
|
||||||
|
|||||||
18
SYCLStream.h
18
SYCLStream.h
@ -15,6 +15,16 @@
|
|||||||
|
|
||||||
#define IMPLEMENTATION_STRING "SYCL"
|
#define IMPLEMENTATION_STRING "SYCL"
|
||||||
|
|
||||||
|
namespace sycl_kernels
|
||||||
|
{
|
||||||
|
template <class T> class init;
|
||||||
|
template <class T> class copy;
|
||||||
|
template <class T> class mul;
|
||||||
|
template <class T> class add;
|
||||||
|
template <class T> class triad;
|
||||||
|
template <class T> class dot;
|
||||||
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
class SYCLStream : public Stream<T>
|
class SYCLStream : public Stream<T>
|
||||||
{
|
{
|
||||||
@ -29,6 +39,14 @@ class SYCLStream : public Stream<T>
|
|||||||
cl::sycl::buffer<T> *d_c;
|
cl::sycl::buffer<T> *d_c;
|
||||||
cl::sycl::buffer<T> *d_sum;
|
cl::sycl::buffer<T> *d_sum;
|
||||||
|
|
||||||
|
// SYCL kernel names
|
||||||
|
typedef sycl_kernels::init<T> init_kernel;
|
||||||
|
typedef sycl_kernels::copy<T> copy_kernel;
|
||||||
|
typedef sycl_kernels::mul<T> mul_kernel;
|
||||||
|
typedef sycl_kernels::add<T> add_kernel;
|
||||||
|
typedef sycl_kernels::triad<T> triad_kernel;
|
||||||
|
typedef sycl_kernels::dot<T> dot_kernel;
|
||||||
|
|
||||||
// NDRange configuration for the dot kernel
|
// NDRange configuration for the dot kernel
|
||||||
size_t dot_num_groups;
|
size_t dot_num_groups;
|
||||||
size_t dot_wgsize;
|
size_t dot_wgsize;
|
||||||
|
|||||||
4
main.cpp
4
main.cpp
@ -61,13 +61,11 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
parseArguments(argc, argv);
|
parseArguments(argc, argv);
|
||||||
|
|
||||||
// TODO: Fix SYCL to allow multiple template specializations
|
// TODO: Fix Kokkos to allow multiple template specializations
|
||||||
#ifndef SYCL
|
|
||||||
#ifndef KOKKOS
|
#ifndef KOKKOS
|
||||||
if (use_float)
|
if (use_float)
|
||||||
run<float>();
|
run<float>();
|
||||||
else
|
else
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
run<double>();
|
run<double>();
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user