Add RAJA dot kernel
This commit is contained in:
parent
823e12708f
commit
7408ab0366
@ -109,6 +109,23 @@ void RAJAStream<T>::triad()
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T RAJAStream<T>::dot()
|
||||||
|
{
|
||||||
|
T* a = d_a;
|
||||||
|
T* b = d_b;
|
||||||
|
|
||||||
|
RAJA::ReduceSum<reduce_policy, T> sum(0.0);
|
||||||
|
|
||||||
|
forall<policy>(index_set, [=] RAJA_DEVICE (int index)
|
||||||
|
{
|
||||||
|
sum += a[index] * b[index];
|
||||||
|
});
|
||||||
|
|
||||||
|
return T(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void listDevices(void)
|
void listDevices(void)
|
||||||
{
|
{
|
||||||
std::cout << "This is not the device you are looking for.";
|
std::cout << "This is not the device you are looking for.";
|
||||||
|
|||||||
@ -18,11 +18,13 @@
|
|||||||
typedef RAJA::IndexSet::ExecPolicy<
|
typedef RAJA::IndexSet::ExecPolicy<
|
||||||
RAJA::seq_segit,
|
RAJA::seq_segit,
|
||||||
RAJA::omp_parallel_for_exec> policy;
|
RAJA::omp_parallel_for_exec> policy;
|
||||||
|
typedef RAJA::omp_reduce reduce_policy;
|
||||||
#else
|
#else
|
||||||
const size_t block_size = 128;
|
const size_t block_size = 128;
|
||||||
typedef RAJA::IndexSet::ExecPolicy<
|
typedef RAJA::IndexSet::ExecPolicy<
|
||||||
RAJA::seq_segit,
|
RAJA::seq_segit,
|
||||||
RAJA::cuda_exec<block_size>> policy;
|
RAJA::cuda_exec<block_size>> policy;
|
||||||
|
typedef RAJA::cuda_reduce<block_size> reduce_policy;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
@ -49,6 +51,7 @@ class RAJAStream : public Stream<T>
|
|||||||
virtual void add() override;
|
virtual void add() override;
|
||||||
virtual void mul() override;
|
virtual void mul() override;
|
||||||
virtual void triad() override;
|
virtual void triad() override;
|
||||||
|
virtual T dot() override;
|
||||||
|
|
||||||
virtual void write_arrays(
|
virtual void write_arrays(
|
||||||
const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& c) override;
|
const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& c) override;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user