Add SYCL 1.2.1 nstream kernel

This commit is contained in:
Tom Deakin 2021-02-02 12:29:00 +00:00
parent b470e4466c
commit bda9525b95
2 changed files with 20 additions and 0 deletions

View File

@ -148,6 +148,23 @@ void SYCLStream<T>::triad()
queue->wait(); queue->wait();
} }
// nstream kernel: for each index i in [0, array_size),
//   a[i] += b[i] + startScalar * c[i]
// Submits one command group to the queue and waits on the queue before returning.
template <class T>
void SYCLStream<T>::nstream()
{
  // Local copy of the scalar; the kernel lambda below captures it by value.
  const T alpha = startScalar;

  queue->submit([&](handler &cgh)
  {
    // a is both read and written by the kernel; b and c are only read.
    auto acc_a = d_a->template get_access<access::mode::read_write>(cgh);
    auto acc_b = d_b->template get_access<access::mode::read>(cgh);
    auto acc_c = d_c->template get_access<access::mode::read>(cgh);

    // One work-item per array element.
    const range<1> extent{array_size};

    cgh.parallel_for<nstream_kernel>(extent, [=](id<1> i)
    {
      acc_a[i] += acc_b[i] + alpha * acc_c[i];
    });
  });

  queue->wait();
}
template <class T> template <class T>
T SYCLStream<T>::dot() T SYCLStream<T>::dot()
{ {

View File

@ -22,6 +22,7 @@ namespace sycl_kernels
template <class T> class mul; template <class T> class mul;
template <class T> class add; template <class T> class add;
template <class T> class triad; template <class T> class triad;
// Kernel name type for the nstream kernel; declared only, never defined here.
template <class T> class nstream;
template <class T> class dot; template <class T> class dot;
} }
@ -45,6 +46,7 @@ class SYCLStream : public Stream<T>
typedef sycl_kernels::mul<T> mul_kernel; typedef sycl_kernels::mul<T> mul_kernel;
typedef sycl_kernels::add<T> add_kernel; typedef sycl_kernels::add<T> add_kernel;
typedef sycl_kernels::triad<T> triad_kernel; typedef sycl_kernels::triad<T> triad_kernel;
// Name passed as the template argument of parallel_for in nstream().
typedef sycl_kernels::nstream<T> nstream_kernel;
typedef sycl_kernels::dot<T> dot_kernel; typedef sycl_kernels::dot<T> dot_kernel;
// NDRange configuration for the dot kernel // NDRange configuration for the dot kernel
@ -60,6 +62,7 @@ class SYCLStream : public Stream<T>
virtual void add() override; virtual void add() override;
virtual void mul() override; virtual void mul() override;
virtual void triad() override; virtual void triad() override;
// nstream kernel: a[i] += b[i] + scalar * c[i] for every array element.
virtual void nstream() override;
virtual T dot() override; virtual T dot() override;
virtual void init_arrays(T initA, T initB, T initC) override; virtual void init_arrays(T initA, T initB, T initC) override;