Add nstream to C++ STD version -- untested as compilers not ready

This commit is contained in:
Tom Deakin 2021-02-03 10:54:33 +00:00
parent 579247dc06
commit 490af52147
4 changed files with 27 additions and 0 deletions

View File

@ -94,6 +94,20 @@ void STD20Stream<T>::triad()
); );
} }
template <class T>
void STD20Stream<T>::nstream()
{
const T scalar = startScalar;
std::for_each_n(
std::execution::par_unseq,
std::views::iota(0).begin(), array_size,
[&] (int i) {
a[i] += b[i] + scalar * c[i];
}
);
}
template <class T> template <class T>
T STD20Stream<T>::dot() T STD20Stream<T>::dot()
{ {

View File

@ -33,6 +33,7 @@ class STD20Stream : public Stream<T>
virtual void add() override; virtual void add() override;
virtual void mul() override; virtual void mul() override;
virtual void triad() override; virtual void triad() override;
virtual void nstream() override;
virtual T dot() override; virtual T dot() override;
virtual void init_arrays(T initA, T initB, T initC) override; virtual void init_arrays(T initA, T initB, T initC) override;

View File

@ -72,6 +72,17 @@ void STDStream<T>::triad()
std::transform(exe_policy, b, b+array_size, c, a, [](T bi, T ci){ return bi+startScalar*ci; }); std::transform(exe_policy, b, b+array_size, c, a, [](T bi, T ci){ return bi+startScalar*ci; });
} }
template <class T>
void STDStream<T>::nstream()
{
// a[i] += b[i] + scalar * c[i];
// Need to do in two stages with C++11 STL.
// 1: a[i] += b[i]
// 2: a[i] += scalar * c[i];
std::transform(exe_policy, a, a+array_size, b, a, [](T ai, T bi){ return ai + bi; });
std::transform(exe_policy, a, a+array_size, c, a, [](T ai, T ci){ return ai + startScalar*ci; });
}
template <class T> template <class T>
T STDStream<T>::dot() T STDStream<T>::dot()
{ {

View File

@ -31,6 +31,7 @@ class STDStream : public Stream<T>
virtual void add() override; virtual void add() override;
virtual void mul() override; virtual void mul() override;
virtual void triad() override; virtual void triad() override;
virtual void nstream() override;
virtual T dot() override; virtual T dot() override;
virtual void init_arrays(T initA, T initB, T initC) override; virtual void init_arrays(T initA, T initB, T initC) override;