diff --git a/SYCLStream.cpp b/SYCLStream.cpp index 215f161..e5fd9c6 100644 --- a/SYCLStream.cpp +++ b/SYCLStream.cpp @@ -13,6 +13,9 @@ using namespace cl::sycl; #define WGSIZE 256 +#define DOT_WGSIZE 256 +#define DOT_NUM_GROUPS 256 + // Cache list of devices bool cached = false; std::vector devices; @@ -48,6 +51,7 @@ SYCLStream::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index) d_a = new buffer(array_size); d_b = new buffer(array_size); d_c = new buffer(array_size); + d_sum = new buffer(DOT_NUM_GROUPS); } template @@ -56,6 +60,7 @@ SYCLStream::~SYCLStream() delete d_a; delete d_b; delete d_c; + delete d_sum; delete queue; } @@ -124,6 +129,49 @@ void SYCLStream::triad() queue->wait(); } +template +T SYCLStream::dot() +{ + queue->submit([&](handler &cgh) + { + auto ka = d_a->template get_access(cgh); + auto kb = d_b->template get_access(cgh); + auto ksum = d_sum->template get_access(cgh); + + auto wg_sum = accessor(range<1>(DOT_WGSIZE), cgh); + + size_t N = array_size; + + cgh.parallel_for(nd_range<1>(DOT_NUM_GROUPS*DOT_WGSIZE, DOT_WGSIZE), [=](nd_item<1> item) + { + size_t i = item.get_global(0); + size_t li = item.get_local(0); + wg_sum[li] = 0.0; + for (; i < N; i += item.get_global_range()[0]) + wg_sum[li] += ka[i] * kb[i]; + + for (int offset = item.get_local_range()[0]; offset > 0; offset /= 2) + { + item.barrier(cl::sycl::access::fence_space::local_space); + if (li < offset) + wg_sum[li] += wg_sum[li + offset]; + } + + if (li == 0) + ksum[item.get_group(0)] = wg_sum[0]; + }); + }); + + T sum = 0.0; + auto h_sum = d_sum->template get_access(); + for (int i = 0; i < DOT_NUM_GROUPS; i++) + { + sum += h_sum[i]; + } + + return sum; +} + template void SYCLStream::write_arrays(const std::vector& a, const std::vector& b, const std::vector& c) { diff --git a/SYCLStream.h b/SYCLStream.h index 8bc515d..ce3225e 100644 --- a/SYCLStream.h +++ b/SYCLStream.h @@ -27,6 +27,7 @@ class SYCLStream : public Stream cl::sycl::buffer *d_a; cl::sycl::buffer *d_b; cl::sycl::buffer *d_c; + cl::sycl::buffer *d_sum; public: @@ -37,6 +38,7 @@ class SYCLStream : public Stream virtual void add() override; virtual void mul() override; virtual void triad() override; + virtual T dot() override; virtual void write_arrays(const std::vector& a, const std::vector& b, const std::vector& c) override; virtual void read_arrays(std::vector& a, std::vector& b, std::vector& c) override;