From 3de019c156a803a6b5abd1b5865828a80965087c Mon Sep 17 00:00:00 2001 From: Tom Lin Date: Sat, 7 Oct 2023 13:50:58 +0100 Subject: [PATCH] Add init/read timing for Java Upgrade to TornadoVM 0.15 API --- src/java/java-stream/pom.xml | 8 +- .../src/main/java/javastream/JavaStream.java | 14 ++- .../src/main/java/javastream/Main.java | 93 ++++++++++++++----- .../javastream/aparapi/AparapiStreams.java | 2 +- .../javastream/jdk/GenericPlainStream.java | 2 +- .../java/javastream/jdk/GenericStream.java | 2 +- .../jdk/SpecialisedDoubleStream.java | 2 +- .../jdk/SpecialisedFloatStream.java | 2 +- .../jdk/SpecialisedPlainDoubleStream.java | 2 +- .../jdk/SpecialisedPlainFloatStream.java | 2 +- .../tornadovm/GenericTornadoVMStream.java | 34 +++---- .../tornadovm/SpecialisedDouble.java | 52 +++++++++-- .../tornadovm/SpecialisedFloat.java | 52 +++++++++-- .../tornadovm/TornadoVMStreams.java | 26 ++++-- 14 files changed, 210 insertions(+), 83 deletions(-) diff --git a/src/java/java-stream/pom.xml b/src/java/java-stream/pom.xml index d28a3d5..78d26b3 100644 --- a/src/java/java-stream/pom.xml +++ b/src/java/java-stream/pom.xml @@ -12,7 +12,7 @@ UTF-8 UTF-8 - 5.7.2 + 5.9.2 @@ -27,19 +27,19 @@ com.beust jcommander - 1.81 + 1.82 tornado tornado-api - 0.9 + 0.15.1 com.aparapi aparapi - 2.0.0 + 3.0.0 diff --git a/src/java/java-stream/src/main/java/javastream/JavaStream.java b/src/java/java-stream/src/main/java/javastream/JavaStream.java index 7ab96cb..4fdb229 100644 --- a/src/java/java-stream/src/main/java/javastream/JavaStream.java +++ b/src/java/java-stream/src/main/java/javastream/JavaStream.java @@ -56,7 +56,7 @@ public abstract class JavaStream { protected abstract T dot(); - protected abstract Data data(); + protected abstract Data readArrays(); public static class EnumeratedStream extends JavaStream { @@ -113,8 +113,8 @@ public abstract class JavaStream { } @Override - public Data data() { - return actual.data(); + public Data readArrays() { + return actual.readArrays(); } } @@ -140,6 +140,14 @@ public abstract class JavaStream { return Duration.ofNanos(end - start); } + final Duration runInitArrays() { + return timed(this::initArrays); + } + + final SimpleImmutableEntry> runReadArrays() { + return timed(this::readArrays); + } + final SimpleImmutableEntry, T> runAll(int times) { Timings timings = new Timings<>(); T lastSum = null; diff --git a/src/java/java-stream/src/main/java/javastream/Main.java b/src/java/java-stream/src/main/java/javastream/Main.java index 2442128..3732a24 100644 --- a/src/java/java-stream/src/main/java/javastream/Main.java +++ b/src/java/java-stream/src/main/java/javastream/Main.java @@ -128,6 +128,40 @@ public class Main { } } + @SuppressWarnings("unchecked") + static void showInit( + int totalBytes, double megaScale, Options opt, Duration init, Duration read) { + List> setup = + Arrays.asList( + new SimpleImmutableEntry<>("Init", durationToSeconds(init)), + new SimpleImmutableEntry<>("Read", durationToSeconds(read))); + if (opt.csv) { + tabulateCsv( + true, + setup.stream() + .map( + x -> + Arrays.asList( + new SimpleImmutableEntry<>("function", x.getKey()), + new SimpleImmutableEntry<>("n_elements", opt.arraysize + ""), + new SimpleImmutableEntry<>("sizeof", totalBytes + ""), + new SimpleImmutableEntry<>( + "max_m" + (opt.mibibytes ? "i" : "") + "bytes_per_sec", + ((megaScale * (double) totalBytes / x.getValue())) + ""), + new SimpleImmutableEntry<>("runtime", x.getValue() + ""))) + .toArray(List[]::new)); + } else { + for (Entry e : setup) { + System.out.printf( + "%s: %.5f s (%.5f M%sBytes/sec)%n", + e.getKey(), + e.getValue(), + megaScale * (double) totalBytes / e.getValue(), + opt.mibibytes ? "i" : ""); + } + } + } + static boolean run( String name, Config config, Function, JavaStream> mkStream) { @@ -183,35 +217,46 @@ public class Main { JavaStream stream = mkStream.apply(config); - stream.initArrays(); - + Duration init = stream.runInitArrays(); final boolean ok; switch (config.benchmark) { case ALL: - Entry, T> results = stream.runAll(opt.numtimes); - ok = checkSolutions(stream.data(), config, Optional.of(results.getValue())); - Timings timings = results.getKey(); - tabulateCsv( - opt.csv, - mkCsvRow(timings.copy, "Copy", 2 * arrayBytes, megaScale, opt), - mkCsvRow(timings.mul, "Mul", 2 * arrayBytes, megaScale, opt), - mkCsvRow(timings.add, "Add", 3 * arrayBytes, megaScale, opt), - mkCsvRow(timings.triad, "Triad", 3 * arrayBytes, megaScale, opt), - mkCsvRow(timings.dot, "Dot", 2 * arrayBytes, megaScale, opt)); - break; + { + Entry, T> results = stream.runAll(opt.numtimes); + SimpleImmutableEntry> read = stream.runReadArrays(); + showInit(totalBytes, megaScale, opt, init, read.getKey()); + ok = checkSolutions(read.getValue(), config, Optional.of(results.getValue())); + Timings timings = results.getKey(); + tabulateCsv( + opt.csv, + mkCsvRow(timings.copy, "Copy", 2 * arrayBytes, megaScale, opt), + mkCsvRow(timings.mul, "Mul", 2 * arrayBytes, megaScale, opt), + mkCsvRow(timings.add, "Add", 3 * arrayBytes, megaScale, opt), + mkCsvRow(timings.triad, "Triad", 3 * arrayBytes, megaScale, opt), + mkCsvRow(timings.dot, "Dot", 2 * arrayBytes, megaScale, opt)); + break; + } case NSTREAM: - List nstreamResults = stream.runNStream(opt.numtimes); - ok = checkSolutions(stream.data(), config, Optional.empty()); - tabulateCsv(opt.csv, mkCsvRow(nstreamResults, "Nstream", 4 * arrayBytes, megaScale, opt)); - break; + { + List nstreamResults = stream.runNStream(opt.numtimes); + SimpleImmutableEntry> read = stream.runReadArrays(); + showInit(totalBytes, megaScale, opt, init, read.getKey()); + ok = checkSolutions(read.getValue(), config, Optional.empty()); + tabulateCsv(opt.csv, mkCsvRow(nstreamResults, "Nstream", 4 * arrayBytes, megaScale, opt)); + break; + } case TRIAD: - Duration triadResult = stream.runTriad(opt.numtimes); - ok = checkSolutions(stream.data(), config, Optional.empty()); - int triadTotalBytes = 3 * arrayBytes * opt.numtimes; - double bandwidth = megaScale * (triadTotalBytes / durationToSeconds(triadResult)); - System.out.printf("Runtime (seconds): %.5f", durationToSeconds(triadResult)); - System.out.printf("Bandwidth (%s/s): %.3f ", gigaSuffix, bandwidth); - break; + { + Duration triadResult = stream.runTriad(opt.numtimes); + SimpleImmutableEntry> read = stream.runReadArrays(); + showInit(totalBytes, megaScale, opt, init, read.getKey()); + ok = checkSolutions(read.getValue(), config, Optional.empty()); + int triadTotalBytes = 3 * arrayBytes * opt.numtimes; + double bandwidth = megaScale * (triadTotalBytes / durationToSeconds(triadResult)); + System.out.printf("Runtime (seconds): %.5f", durationToSeconds(triadResult)); + System.out.printf("Bandwidth (%s/s): %.3f ", gigaSuffix, bandwidth); + break; + } default: throw new AssertionError(); } diff --git a/src/java/java-stream/src/main/java/javastream/aparapi/AparapiStreams.java b/src/java/java-stream/src/main/java/javastream/aparapi/AparapiStreams.java index ab2de52..052c807 100644 --- a/src/java/java-stream/src/main/java/javastream/aparapi/AparapiStreams.java +++ b/src/java/java-stream/src/main/java/javastream/aparapi/AparapiStreams.java @@ -122,7 +122,7 @@ public final class AparapiStreams { } @Override - public Data data() { + public Data readArrays() { return kernels.syncAndDispose(); } } diff --git a/src/java/java-stream/src/main/java/javastream/jdk/GenericPlainStream.java b/src/java/java-stream/src/main/java/javastream/jdk/GenericPlainStream.java index 7f210fa..8075603 100644 --- a/src/java/java-stream/src/main/java/javastream/jdk/GenericPlainStream.java +++ b/src/java/java-stream/src/main/java/javastream/jdk/GenericPlainStream.java @@ -86,7 +86,7 @@ final class GenericPlainStream extends JavaStream { } @Override - public Data data() { + public Data readArrays() { return new Data<>(a, b, c); } } diff --git a/src/java/java-stream/src/main/java/javastream/jdk/GenericStream.java b/src/java/java-stream/src/main/java/javastream/jdk/GenericStream.java index 1e65b8f..3cacf3a 100644 --- a/src/java/java-stream/src/main/java/javastream/jdk/GenericStream.java +++ b/src/java/java-stream/src/main/java/javastream/jdk/GenericStream.java @@ -80,7 +80,7 @@ final class GenericStream extends JavaStream { } @Override - public Data data() { + public Data readArrays() { return new Data<>(a, b, c); } } diff --git a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedDoubleStream.java b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedDoubleStream.java index 26406a6..1b54bc3 100644 --- a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedDoubleStream.java +++ b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedDoubleStream.java @@ -78,7 +78,7 @@ final class SpecialisedDoubleStream extends JavaStream { } @Override - public Data data() { + public Data readArrays() { return new Data<>(boxed(a), boxed(b), boxed(c)); } } diff --git a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedFloatStream.java b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedFloatStream.java index 6c414c1..4d8c137 100644 --- a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedFloatStream.java +++ b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedFloatStream.java @@ -78,7 +78,7 @@ final class SpecialisedFloatStream extends JavaStream { } @Override - public Data data() { + public Data readArrays() { return new Data<>(boxed(a), boxed(b), boxed(c)); } } diff --git a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainDoubleStream.java b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainDoubleStream.java index afda2ef..c4f38d0 100644 --- a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainDoubleStream.java +++ b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainDoubleStream.java @@ -78,7 +78,7 @@ final class SpecialisedPlainDoubleStream extends JavaStream { } @Override - public Data data() { + public Data readArrays() { return new Data<>(boxed(a), boxed(b), boxed(c)); } } diff --git a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainFloatStream.java b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainFloatStream.java index 9ccee53..5178ed2 100644 --- a/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainFloatStream.java +++ b/src/java/java-stream/src/main/java/javastream/jdk/SpecialisedPlainFloatStream.java @@ -78,7 +78,7 @@ final class SpecialisedPlainFloatStream extends JavaStream { } @Override - public Data data() { + public Data readArrays() { return new Data<>(boxed(a), boxed(b), boxed(c)); } } diff --git a/src/java/java-stream/src/main/java/javastream/tornadovm/GenericTornadoVMStream.java b/src/java/java-stream/src/main/java/javastream/tornadovm/GenericTornadoVMStream.java index d936df6..a65c32a 100644 --- a/src/java/java-stream/src/main/java/javastream/tornadovm/GenericTornadoVMStream.java +++ b/src/java/java-stream/src/main/java/javastream/tornadovm/GenericTornadoVMStream.java @@ -4,8 +4,8 @@ import java.util.List; import java.util.stream.Collectors; import javastream.JavaStream; import javastream.Main.Config; -import uk.ac.manchester.tornado.api.TaskSchedule; -import uk.ac.manchester.tornado.api.TornadoRuntimeCI; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.TornadoRuntimeInterface; import uk.ac.manchester.tornado.api.common.TornadoDevice; import uk.ac.manchester.tornado.api.runtime.TornadoRuntime; @@ -13,18 +13,18 @@ abstract class GenericTornadoVMStream extends JavaStream { protected final TornadoDevice device; - protected TaskSchedule copyTask; - protected TaskSchedule mulTask; - protected TaskSchedule addTask; - protected TaskSchedule triadTask; - protected TaskSchedule nstreamTask; - protected TaskSchedule dotTask; + protected TornadoExecutionPlan copyTask; + protected TornadoExecutionPlan mulTask; + protected TornadoExecutionPlan addTask; + protected TornadoExecutionPlan triadTask; + protected TornadoExecutionPlan nstreamTask; + protected TornadoExecutionPlan dotTask; GenericTornadoVMStream(Config config) { super(config); try { - TornadoRuntimeCI runtime = TornadoRuntime.getTornadoRuntime(); + TornadoRuntimeInterface runtime = TornadoRuntime.getTornadoRuntime(); List devices = TornadoVMStreams.enumerateDevices(runtime); device = devices.get(config.options.device); @@ -42,10 +42,6 @@ abstract class GenericTornadoVMStream extends JavaStream { } } - protected static TaskSchedule mkSchedule() { - return new TaskSchedule(""); - } - @Override public List listDevices() { return TornadoVMStreams.enumerateDevices(TornadoRuntime.getTornadoRuntime()).stream() @@ -55,12 +51,12 @@ abstract class GenericTornadoVMStream extends JavaStream { @Override public void initArrays() { - this.copyTask.warmup(); - this.mulTask.warmup(); - this.addTask.warmup(); - this.triadTask.warmup(); - this.nstreamTask.warmup(); - this.dotTask.warmup(); + this.copyTask.withWarmUp(); + this.mulTask.withWarmUp(); + this.addTask.withWarmUp(); + this.triadTask.withWarmUp(); + this.nstreamTask.withWarmUp(); + this.dotTask.withWarmUp(); } @Override diff --git a/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedDouble.java b/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedDouble.java index 7712e31..c10153e 100644 --- a/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedDouble.java +++ b/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedDouble.java @@ -2,8 +2,11 @@ package javastream.tornadovm; import java.util.Arrays; import javastream.Main.Config; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.annotations.Reduce; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; final class SpecialisedDouble extends GenericTornadoVMStream { @@ -49,7 +52,7 @@ final class SpecialisedDouble extends GenericTornadoVMStream { private final double[] a, b, c; private final double[] dotSum; - @SuppressWarnings({"PrimitiveArrayArgumentToVarargsMethod", "DuplicatedCode"}) + @SuppressWarnings({"DuplicatedCode"}) SpecialisedDouble(Config config) { super(config); final int size = config.options.arraysize; @@ -58,12 +61,43 @@ final class SpecialisedDouble extends GenericTornadoVMStream { b = new double[size]; c = new double[size]; dotSum = new double[1]; - this.copyTask = mkSchedule().task("", SpecialisedDouble::copy, size, a, c); - this.mulTask = mkSchedule().task("", SpecialisedDouble::mul, size, b, c, scalar); - this.addTask = mkSchedule().task("", SpecialisedDouble::add, size, a, b, c); - this.triadTask = mkSchedule().task("", SpecialisedDouble::triad, size, a, b, c, scalar); - this.nstreamTask = mkSchedule().task("", SpecialisedDouble::nstream, size, a, b, c, scalar); - this.dotTask = mkSchedule().task("", SpecialisedDouble::dot_, a, b, dotSum).streamOut(dotSum); + this.copyTask = + new TornadoExecutionPlan( + new TaskGraph("copy") + .task("copy", SpecialisedDouble::copy, size, a, c) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, c) + .snapshot()); + this.mulTask = + new TornadoExecutionPlan( + new TaskGraph("mul") + .task("mul", SpecialisedDouble::mul, size, b, c, scalar) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, b, c) + .snapshot()); + this.addTask = + new TornadoExecutionPlan( + new TaskGraph("add") + .task("add", SpecialisedDouble::add, size, a, b, c) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c) + .snapshot()); + this.triadTask = + new TornadoExecutionPlan( + new TaskGraph("triad") + .task("triad", SpecialisedDouble::triad, size, a, b, c, scalar) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c) + .snapshot()); + this.nstreamTask = + new TornadoExecutionPlan( + new TaskGraph("nstream") + .task("nstream", SpecialisedDouble::nstream, size, a, b, c, scalar) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c) + .snapshot()); + this.dotTask = + new TornadoExecutionPlan( + new TaskGraph("dot") + .task("dot", SpecialisedDouble::dot_, a, b, dotSum) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b) + .transferToHost(DataTransferMode.EVERY_EXECUTION, new Object[] {dotSum}) + .snapshot()); } @Override @@ -72,7 +106,7 @@ final class SpecialisedDouble extends GenericTornadoVMStream { Arrays.fill(a, config.initA); Arrays.fill(b, config.initB); Arrays.fill(c, config.initC); - TornadoVMStreams.xferToDevice(device, a, b, c); + TornadoVMStreams.allocAndXferToDevice(device, a, b, c); } @Override @@ -81,7 +115,7 @@ final class SpecialisedDouble extends GenericTornadoVMStream { } @Override - public Data data() { + public Data readArrays() { TornadoVMStreams.xferFromDevice(device, a, b, c); return new Data<>(boxed(a), boxed(b), boxed(c)); } diff --git a/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedFloat.java b/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedFloat.java index e61cfe9..0f3fffa 100644 --- a/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedFloat.java +++ b/src/java/java-stream/src/main/java/javastream/tornadovm/SpecialisedFloat.java @@ -2,8 +2,11 @@ package javastream.tornadovm; import java.util.Arrays; import javastream.Main.Config; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.annotations.Reduce; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; final class SpecialisedFloat extends GenericTornadoVMStream { @@ -49,7 +52,7 @@ final class SpecialisedFloat extends GenericTornadoVMStream { private final float[] a, b, c; private final float[] dotSum; - @SuppressWarnings({"PrimitiveArrayArgumentToVarargsMethod", "DuplicatedCode"}) + @SuppressWarnings({"DuplicatedCode"}) SpecialisedFloat(Config config) { super(config); final int size = config.options.arraysize; @@ -58,12 +61,43 @@ final class SpecialisedFloat extends GenericTornadoVMStream { b = new float[size]; c = new float[size]; dotSum = new float[1]; - this.copyTask = mkSchedule().task("", SpecialisedFloat::copy, size, a, c); - this.mulTask = mkSchedule().task("", SpecialisedFloat::mul, size, b, c, scalar); - this.addTask = mkSchedule().task("", SpecialisedFloat::add, size, a, b, c); - this.triadTask = mkSchedule().task("", SpecialisedFloat::triad, size, a, b, c, scalar); - this.nstreamTask = mkSchedule().task("", SpecialisedFloat::nstream, size, a, b, c, scalar); - this.dotTask = mkSchedule().task("", SpecialisedFloat::dot_, a, b, dotSum).streamOut(dotSum); + this.copyTask = + new TornadoExecutionPlan( + new TaskGraph("copy") + .task("copy", SpecialisedFloat::copy, size, a, c) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, c) + .snapshot()); + this.mulTask = + new TornadoExecutionPlan( + new TaskGraph("mul") + .task("mul", SpecialisedFloat::mul, size, b, c, scalar) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, b, c) + .snapshot()); + this.addTask = + new TornadoExecutionPlan( + new TaskGraph("add") + .task("add", SpecialisedFloat::add, size, a, b, c) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c) + .snapshot()); + this.triadTask = + new TornadoExecutionPlan( + new TaskGraph("triad") + .task("triad", SpecialisedFloat::triad, size, a, b, c, scalar) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c) + .snapshot()); + this.nstreamTask = + new TornadoExecutionPlan( + new TaskGraph("nstream") + .task("nstream", SpecialisedFloat::nstream, size, a, b, c, scalar) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c) + .snapshot()); + this.dotTask = + new TornadoExecutionPlan( + new TaskGraph("dot") + .task("dot", SpecialisedFloat::dot_, a, b, dotSum) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b) + .transferToHost(DataTransferMode.EVERY_EXECUTION, new Object[] {dotSum}) + .snapshot()); } @Override @@ -72,7 +106,7 @@ final class SpecialisedFloat extends GenericTornadoVMStream { Arrays.fill(a, config.initA); Arrays.fill(b, config.initB); Arrays.fill(c, config.initC); - TornadoVMStreams.xferToDevice(device, a, b, c); + TornadoVMStreams.allocAndXferToDevice(device, a, b, c); } @Override @@ -81,7 +115,7 @@ final class SpecialisedFloat extends GenericTornadoVMStream { } @Override - public Data data() { + public Data readArrays() { TornadoVMStreams.xferFromDevice(device, a, b, c); return new Data<>(boxed(a), boxed(b), boxed(c)); } diff --git a/src/java/java-stream/src/main/java/javastream/tornadovm/TornadoVMStreams.java b/src/java/java-stream/src/main/java/javastream/tornadovm/TornadoVMStreams.java index 68eecad..a43c7c8 100644 --- a/src/java/java-stream/src/main/java/javastream/tornadovm/TornadoVMStreams.java +++ b/src/java/java-stream/src/main/java/javastream/tornadovm/TornadoVMStreams.java @@ -1,36 +1,46 @@ package javastream.tornadovm; +import java.util.Arrays; import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; import javastream.JavaStream; import javastream.Main.Config; -import uk.ac.manchester.tornado.api.TornadoRuntimeCI; +import uk.ac.manchester.tornado.api.TornadoRuntimeInterface; +import uk.ac.manchester.tornado.api.common.Event; import uk.ac.manchester.tornado.api.common.TornadoDevice; -import uk.ac.manchester.tornado.api.mm.TornadoGlobalObjectState; +import uk.ac.manchester.tornado.api.memory.TornadoDeviceObjectState; +import uk.ac.manchester.tornado.api.memory.TornadoGlobalObjectState; import uk.ac.manchester.tornado.api.runtime.TornadoRuntime; public final class TornadoVMStreams { private TornadoVMStreams() {} - static void xferToDevice(TornadoDevice device, Object... xs) { + static void allocAndXferToDevice(TornadoDevice device, Object... xs) { for (Object x : xs) { TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x); + device.allocateObjects( + new Object[] {x}, 0, new TornadoDeviceObjectState[] {state.getDeviceState(device)}); List writeEvent = device.ensurePresent(x, state.getDeviceState(device), null, 0, 0); if (writeEvent != null) writeEvent.forEach(e -> device.resolveEvent(e).waitOn()); } } static void xferFromDevice(TornadoDevice device, Object... xs) { - for (Object x : xs) { - TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x); - device.resolveEvent(device.streamOut(x, 0, state.getDeviceState(device), null)).waitOn(); - } + Arrays.stream(xs) + .map( + x -> { + TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x); + return device.resolveEvent( + device.streamOut(x, 0, state.getDeviceState(device), null)); + }) + .collect(Collectors.toList()) + .forEach(Event::waitOn); } - static List enumerateDevices(TornadoRuntimeCI runtime) { + static List enumerateDevices(TornadoRuntimeInterface runtime) { return IntStream.range(0, runtime.getNumDrivers()) .mapToObj(runtime::getDriver) .flatMap(d -> IntStream.range(0, d.getDeviceCount()).mapToObj(d::getDevice))