From 7408ab0366eebce6ee917743a17dacde36f3c507 Mon Sep 17 00:00:00 2001
From: Tom Deakin
Date: Mon, 24 Oct 2016 11:34:40 +0100
Subject: [PATCH] Add RAJA dot kernel

---
NOTE(review): this patch was recovered from a copy in which all line breaks
and every angle-bracketed span had been stripped. The template argument
lists (<class T>, <T>, <reduce_policy, T>, <policy>, <block_size>) and the
indentation of all lines were restored from context and are assumptions;
confirm against the upstream commit before applying. The author e-mail
address on the From: line could not be recovered and is left out.

Summary of the change: adds RAJAStream<T>::dot(), which accumulates
a[i] * b[i] over index_set into a RAJA::ReduceSum and returns the result as
T, plus a reduce_policy typedef for each of the two build targets.

 RAJAStream.cpp | 17 +++++++++++++++++
 RAJAStream.hpp |  3 +++
 2 files changed, 20 insertions(+)

diff --git a/RAJAStream.cpp b/RAJAStream.cpp
index e418f09..21c1843 100644
--- a/RAJAStream.cpp
+++ b/RAJAStream.cpp
@@ -109,6 +109,23 @@ void RAJAStream<T>::triad()
   });
 }
 
+template <class T>
+T RAJAStream<T>::dot()
+{
+  T* a = d_a;
+  T* b = d_b;
+
+  RAJA::ReduceSum<reduce_policy, T> sum(0.0);
+
+  forall<policy>(index_set, [=] RAJA_DEVICE (int index)
+  {
+    sum += a[index] * b[index];
+  });
+
+  return T(sum);
+}
+
+
 void listDevices(void)
 {
   std::cout << "This is not the device you are looking for.";
diff --git a/RAJAStream.hpp b/RAJAStream.hpp
index 454e20e..768314a 100644
--- a/RAJAStream.hpp
+++ b/RAJAStream.hpp
@@ -18,11 +18,13 @@
 typedef RAJA::IndexSet::ExecPolicy<
         RAJA::seq_segit,
         RAJA::omp_parallel_for_exec> policy;
+typedef RAJA::omp_reduce reduce_policy;
 #else
 const size_t block_size = 128;
 typedef RAJA::IndexSet::ExecPolicy<
         RAJA::seq_segit,
         RAJA::cuda_exec<block_size>> policy;
+typedef RAJA::cuda_reduce<block_size> reduce_policy;
 #endif
 
 template <class T>
@@ -49,6 +51,7 @@ class RAJAStream : public Stream<T>
     virtual void add() override;
     virtual void mul() override;
     virtual void triad() override;
+    virtual T dot() override;
 
     virtual void write_arrays(
             const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& c) override;