[SYCL] Implement dot kernel

This commit is contained in:
James Price 2016-10-25 16:39:23 +01:00
parent e5b67ac969
commit ed630e7dbc
2 changed files with 50 additions and 0 deletions

View File

@ -13,6 +13,9 @@ using namespace cl::sycl;
// NOTE(review): presumably the work-group size used by the non-reduction
// kernels; its uses are not visible in this part of the file - confirm.
#define WGSIZE 256
// Work-group size of the dot kernel; also the number of elements in the
// local-memory scratch array that kernel reduces over.
#define DOT_WGSIZE 256
// Number of work-groups launched by the dot kernel; also the number of
// partial sums held in d_sum and added up on the host.
#define DOT_NUM_GROUPS 256
// Cache list of devices
// NOTE(review): presumably `cached` records whether `devices` has already been
// populated; the code that sets it is not visible here - confirm.
bool cached = false;
std::vector<device> devices;
@ -48,6 +51,7 @@ SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
d_a = new buffer<T>(array_size);
d_b = new buffer<T>(array_size);
d_c = new buffer<T>(array_size);
d_sum = new buffer<T>(DOT_NUM_GROUPS);
}
template <class T>
@ -56,6 +60,7 @@ SYCLStream<T>::~SYCLStream()
delete d_a;
delete d_b;
delete d_c;
delete d_sum;
delete queue;
}
@ -124,6 +129,49 @@ void SYCLStream<T>::triad()
queue->wait();
}
// Computes the dot product of the contents of d_a and d_b.
//
// The kernel runs DOT_NUM_GROUPS work-groups of DOT_WGSIZE work-items each:
//   1. every work-item accumulates a[i] * b[i] into its own slot of a
//      local-memory scratch array, starting at its global id and stepping by
//      the global range until array_size is reached;
//   2. each work-group combines its slots with a pairwise (tree) reduction;
//   3. work-item 0 of each group stores the group's total in d_sum.
// The DOT_NUM_GROUPS partial sums are then added together on the host.
//
// The tree reduction halves the stride every step, so it only covers every
// slot when the work-group size is a power of two (DOT_WGSIZE is 256).
//
// Returns the sum over i in [0, array_size) of a[i] * b[i].
template <class T>
T SYCLStream<T>::dot()
{
  queue->submit([&](handler &cgh)
  {
    auto ka   = d_a->template get_access<access::mode::read>(cgh);
    auto kb   = d_b->template get_access<access::mode::read>(cgh);
    auto ksum = d_sum->template get_access<access::mode::write>(cgh);

    // One scratch slot per work-item of a work-group.
    auto wg_sum = accessor<T, 1, access::mode::read_write, access::target::local>(range<1>(DOT_WGSIZE), cgh);

    size_t N = array_size;
    cgh.parallel_for<class dot>(nd_range<1>(DOT_NUM_GROUPS*DOT_WGSIZE, DOT_WGSIZE), [=](nd_item<1> item)
    {
      size_t i = item.get_global(0);
      size_t li = item.get_local(0);

      // Per-work-item partial sum over a strided subset of the arrays.
      wg_sum[li] = 0.0;
      for (; i < N; i += item.get_global_range()[0])
        wg_sum[li] += ka[i] * kb[i];

      // Tree reduction within the work-group. The stride must start at HALF
      // the local range: starting at the full range would make every
      // work-item read wg_sum[li + local_range], which lies past the end of
      // the scratch array.
      for (size_t offset = item.get_local_range()[0] / 2; offset > 0; offset /= 2)
      {
        // Make the previous step's writes visible before reading them.
        item.barrier(cl::sycl::access::fence_space::local_space);
        if (li < offset)
          wg_sum[li] += wg_sum[li + offset];
      }

      // Slot 0 now holds the work-group total.
      if (li == 0)
        ksum[item.get_group(0)] = wg_sum[0];
    });
  });

  // NOTE(review): there is no explicit queue->wait() here; this relies on the
  // host accessor request blocking until the kernel has finished writing
  // d_sum - confirm against the SYCL implementation in use.
  T sum = 0.0;
  auto h_sum = d_sum->template get_access<access::mode::read, access::target::host_buffer>();
  for (int i = 0; i < DOT_NUM_GROUPS; i++)
  {
    sum += h_sum[i];
  }

  return sum;
}
template <class T>
void SYCLStream<T>::write_arrays(const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& c)
{

View File

@ -27,6 +27,7 @@ class SYCLStream : public Stream<T>
// SYCL buffers for the three arrays; each is created with array_size
// elements in the constructor and deleted in the destructor.
cl::sycl::buffer<T> *d_a;
cl::sycl::buffer<T> *d_b;
cl::sycl::buffer<T> *d_c;
// Partial sums produced by the dot kernel, one per work-group
// (DOT_NUM_GROUPS elements); dot() adds them up on the host.
cl::sycl::buffer<T> *d_sum;
public:
@ -37,6 +38,7 @@ class SYCLStream : public Stream<T>
virtual void add() override;
virtual void mul() override;
virtual void triad() override;
virtual T dot() override;
virtual void write_arrays(const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& c) override;
virtual void read_arrays(std::vector<T>& a, std::vector<T>& b, std::vector<T>& c) override;