From 13cb8ffced2033d86875aa4e56de91c7c5da9926 Mon Sep 17 00:00:00 2001 From: Tom Lin Date: Sat, 28 Aug 2021 11:10:49 +0100 Subject: [PATCH] Use custom static reduction for CPU --- JuliaStream.jl/src/ThreadedStream.jl | 40 +++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/JuliaStream.jl/src/ThreadedStream.jl b/JuliaStream.jl/src/ThreadedStream.jl index 4422e66..0faabeb 100644 --- a/JuliaStream.jl/src/ThreadedStream.jl +++ b/JuliaStream.jl/src/ThreadedStream.jl @@ -63,12 +63,46 @@ function nstream!(data::VectorData{T}, _) where {T} end end +# Threads.@threads/Threads.@spawn doesn't support OpenMP's firstprivate, etc +function static_par_ranged(f::Function, range::Int, n::Int) + stride = range ÷ n + rem = range % n + strides = map(0:n) do i + width = stride + (i < rem ? 1 : 0) + offset = i < rem ? (stride + 1) * i : ((stride + 1) * rem) + (stride * (i - rem)) + (offset, width) + end + ccall(:jl_enter_threaded_region, Cvoid, ()) + try + foreach(wait, map(1:n) do group + (offset, size) = strides[group] + task = Task(() -> f(group, offset+1, offset+size)) + task.sticky = true + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, group-1) # ccall, so 0-based for group + schedule(task) + end) + finally + ccall(:jl_exit_threaded_region, Cvoid, ()) + end +end + function dot(data::VectorData{T}, _) where {T} - partial = zeros(T, Threads.nthreads()) - Threads.@threads for i = 1:data.size - @inbounds partial[Threads.threadid()] += data.a[i] * data.b[i] + partial = Vector{T}(undef, Threads.nthreads()) + static_par_ranged(data.size, Threads.nthreads()) do group, startidx, endidx + acc = zero(T) + @fastmath for i = startidx:endidx + @inbounds acc += data.a[i] * data.b[i] + end + @inbounds partial[group] = acc end return sum(partial) + # This doesn't do well on aarch64 because of the excessive Threads.threadid() ccall + # and inhibited vectorisation from the lack of @fastmath + # partial = zeros(T, Threads.nthreads()) + # Threads.@threads for 
i = 1:data.size + # @inbounds partial[Threads.threadid()] += (data.a[i] * data.b[i]) + # end + # return sum(partial) end function read_data(data::VectorData{T}, _)::VectorData{T} where {T}