Add init/read timing for Java
Upgrade to TornadoVM 0.15 API
This commit is contained in:
parent
971d1e8ac7
commit
3de019c156
@ -12,7 +12,7 @@
|
|||||||
<properties>
|
<properties>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
|
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
|
||||||
<junit.version>5.7.2</junit.version>
|
<junit.version>5.9.2</junit.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<repositories>
|
<repositories>
|
||||||
@ -27,19 +27,19 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.beust</groupId>
|
<groupId>com.beust</groupId>
|
||||||
<artifactId>jcommander</artifactId>
|
<artifactId>jcommander</artifactId>
|
||||||
<version>1.81</version>
|
<version>1.82</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>tornado</groupId>
|
<groupId>tornado</groupId>
|
||||||
<artifactId>tornado-api</artifactId>
|
<artifactId>tornado-api</artifactId>
|
||||||
<version>0.9</version>
|
<version>0.15.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.aparapi</groupId>
|
<groupId>com.aparapi</groupId>
|
||||||
<artifactId>aparapi</artifactId>
|
<artifactId>aparapi</artifactId>
|
||||||
<version>2.0.0</version>
|
<version>3.0.0</version>
|
||||||
<exclusions>
|
<exclusions>
|
||||||
<!-- don't pull in the entire Scala ecosystem! -->
|
<!-- don't pull in the entire Scala ecosystem! -->
|
||||||
<exclusion>
|
<exclusion>
|
||||||
|
|||||||
@ -56,7 +56,7 @@ public abstract class JavaStream<T> {
|
|||||||
|
|
||||||
protected abstract T dot();
|
protected abstract T dot();
|
||||||
|
|
||||||
protected abstract Data<T> data();
|
protected abstract Data<T> readArrays();
|
||||||
|
|
||||||
public static class EnumeratedStream<T> extends JavaStream<T> {
|
public static class EnumeratedStream<T> extends JavaStream<T> {
|
||||||
|
|
||||||
@ -113,8 +113,8 @@ public abstract class JavaStream<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<T> data() {
|
public Data<T> readArrays() {
|
||||||
return actual.data();
|
return actual.readArrays();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,6 +140,14 @@ public abstract class JavaStream<T> {
|
|||||||
return Duration.ofNanos(end - start);
|
return Duration.ofNanos(end - start);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
final Duration runInitArrays() {
|
||||||
|
return timed(this::initArrays);
|
||||||
|
}
|
||||||
|
|
||||||
|
final SimpleImmutableEntry<Duration, Data<T>> runReadArrays() {
|
||||||
|
return timed(this::readArrays);
|
||||||
|
}
|
||||||
|
|
||||||
final SimpleImmutableEntry<Timings<Duration>, T> runAll(int times) {
|
final SimpleImmutableEntry<Timings<Duration>, T> runAll(int times) {
|
||||||
Timings<Duration> timings = new Timings<>();
|
Timings<Duration> timings = new Timings<>();
|
||||||
T lastSum = null;
|
T lastSum = null;
|
||||||
|
|||||||
@ -128,6 +128,40 @@ public class Main {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
static void showInit(
|
||||||
|
int totalBytes, double megaScale, Options opt, Duration init, Duration read) {
|
||||||
|
List<Entry<String, Double>> setup =
|
||||||
|
Arrays.asList(
|
||||||
|
new SimpleImmutableEntry<>("Init", durationToSeconds(init)),
|
||||||
|
new SimpleImmutableEntry<>("Read", durationToSeconds(read)));
|
||||||
|
if (opt.csv) {
|
||||||
|
tabulateCsv(
|
||||||
|
true,
|
||||||
|
setup.stream()
|
||||||
|
.map(
|
||||||
|
x ->
|
||||||
|
Arrays.asList(
|
||||||
|
new SimpleImmutableEntry<>("function", x.getKey()),
|
||||||
|
new SimpleImmutableEntry<>("n_elements", opt.arraysize + ""),
|
||||||
|
new SimpleImmutableEntry<>("sizeof", totalBytes + ""),
|
||||||
|
new SimpleImmutableEntry<>(
|
||||||
|
"max_m" + (opt.mibibytes ? "i" : "") + "bytes_per_sec",
|
||||||
|
((megaScale * (double) totalBytes / x.getValue())) + ""),
|
||||||
|
new SimpleImmutableEntry<>("runtime", x.getValue() + "")))
|
||||||
|
.toArray(List[]::new));
|
||||||
|
} else {
|
||||||
|
for (Entry<String, Double> e : setup) {
|
||||||
|
System.out.printf(
|
||||||
|
"%s: %.5f s (%.5f M%sBytes/sec)%n",
|
||||||
|
e.getKey(),
|
||||||
|
e.getValue(),
|
||||||
|
megaScale * (double) totalBytes / e.getValue(),
|
||||||
|
opt.mibibytes ? "i" : "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static <T extends Number> boolean run(
|
static <T extends Number> boolean run(
|
||||||
String name, Config<T> config, Function<Config<T>, JavaStream<T>> mkStream) {
|
String name, Config<T> config, Function<Config<T>, JavaStream<T>> mkStream) {
|
||||||
|
|
||||||
@ -183,35 +217,46 @@ public class Main {
|
|||||||
|
|
||||||
JavaStream<T> stream = mkStream.apply(config);
|
JavaStream<T> stream = mkStream.apply(config);
|
||||||
|
|
||||||
stream.initArrays();
|
Duration init = stream.runInitArrays();
|
||||||
|
|
||||||
final boolean ok;
|
final boolean ok;
|
||||||
switch (config.benchmark) {
|
switch (config.benchmark) {
|
||||||
case ALL:
|
case ALL:
|
||||||
Entry<Timings<Duration>, T> results = stream.runAll(opt.numtimes);
|
{
|
||||||
ok = checkSolutions(stream.data(), config, Optional.of(results.getValue()));
|
Entry<Timings<Duration>, T> results = stream.runAll(opt.numtimes);
|
||||||
Timings<Duration> timings = results.getKey();
|
SimpleImmutableEntry<Duration, Data<T>> read = stream.runReadArrays();
|
||||||
tabulateCsv(
|
showInit(totalBytes, megaScale, opt, init, read.getKey());
|
||||||
opt.csv,
|
ok = checkSolutions(read.getValue(), config, Optional.of(results.getValue()));
|
||||||
mkCsvRow(timings.copy, "Copy", 2 * arrayBytes, megaScale, opt),
|
Timings<Duration> timings = results.getKey();
|
||||||
mkCsvRow(timings.mul, "Mul", 2 * arrayBytes, megaScale, opt),
|
tabulateCsv(
|
||||||
mkCsvRow(timings.add, "Add", 3 * arrayBytes, megaScale, opt),
|
opt.csv,
|
||||||
mkCsvRow(timings.triad, "Triad", 3 * arrayBytes, megaScale, opt),
|
mkCsvRow(timings.copy, "Copy", 2 * arrayBytes, megaScale, opt),
|
||||||
mkCsvRow(timings.dot, "Dot", 2 * arrayBytes, megaScale, opt));
|
mkCsvRow(timings.mul, "Mul", 2 * arrayBytes, megaScale, opt),
|
||||||
break;
|
mkCsvRow(timings.add, "Add", 3 * arrayBytes, megaScale, opt),
|
||||||
|
mkCsvRow(timings.triad, "Triad", 3 * arrayBytes, megaScale, opt),
|
||||||
|
mkCsvRow(timings.dot, "Dot", 2 * arrayBytes, megaScale, opt));
|
||||||
|
break;
|
||||||
|
}
|
||||||
case NSTREAM:
|
case NSTREAM:
|
||||||
List<Duration> nstreamResults = stream.runNStream(opt.numtimes);
|
{
|
||||||
ok = checkSolutions(stream.data(), config, Optional.empty());
|
List<Duration> nstreamResults = stream.runNStream(opt.numtimes);
|
||||||
tabulateCsv(opt.csv, mkCsvRow(nstreamResults, "Nstream", 4 * arrayBytes, megaScale, opt));
|
SimpleImmutableEntry<Duration, Data<T>> read = stream.runReadArrays();
|
||||||
break;
|
showInit(totalBytes, megaScale, opt, init, read.getKey());
|
||||||
|
ok = checkSolutions(read.getValue(), config, Optional.empty());
|
||||||
|
tabulateCsv(opt.csv, mkCsvRow(nstreamResults, "Nstream", 4 * arrayBytes, megaScale, opt));
|
||||||
|
break;
|
||||||
|
}
|
||||||
case TRIAD:
|
case TRIAD:
|
||||||
Duration triadResult = stream.runTriad(opt.numtimes);
|
{
|
||||||
ok = checkSolutions(stream.data(), config, Optional.empty());
|
Duration triadResult = stream.runTriad(opt.numtimes);
|
||||||
int triadTotalBytes = 3 * arrayBytes * opt.numtimes;
|
SimpleImmutableEntry<Duration, Data<T>> read = stream.runReadArrays();
|
||||||
double bandwidth = megaScale * (triadTotalBytes / durationToSeconds(triadResult));
|
showInit(totalBytes, megaScale, opt, init, read.getKey());
|
||||||
System.out.printf("Runtime (seconds): %.5f", durationToSeconds(triadResult));
|
ok = checkSolutions(read.getValue(), config, Optional.empty());
|
||||||
System.out.printf("Bandwidth (%s/s): %.3f ", gigaSuffix, bandwidth);
|
int triadTotalBytes = 3 * arrayBytes * opt.numtimes;
|
||||||
break;
|
double bandwidth = megaScale * (triadTotalBytes / durationToSeconds(triadResult));
|
||||||
|
System.out.printf("Runtime (seconds): %.5f", durationToSeconds(triadResult));
|
||||||
|
System.out.printf("Bandwidth (%s/s): %.3f ", gigaSuffix, bandwidth);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
throw new AssertionError();
|
throw new AssertionError();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -122,7 +122,7 @@ public final class AparapiStreams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<T> data() {
|
public Data<T> readArrays() {
|
||||||
return kernels.syncAndDispose();
|
return kernels.syncAndDispose();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -86,7 +86,7 @@ final class GenericPlainStream<T extends Number> extends JavaStream<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<T> data() {
|
public Data<T> readArrays() {
|
||||||
return new Data<>(a, b, c);
|
return new Data<>(a, b, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -80,7 +80,7 @@ final class GenericStream<T extends Number> extends JavaStream<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<T> data() {
|
public Data<T> readArrays() {
|
||||||
return new Data<>(a, b, c);
|
return new Data<>(a, b, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -78,7 +78,7 @@ final class SpecialisedDoubleStream extends JavaStream<Double> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<Double> data() {
|
public Data<Double> readArrays() {
|
||||||
return new Data<>(boxed(a), boxed(b), boxed(c));
|
return new Data<>(boxed(a), boxed(b), boxed(c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -78,7 +78,7 @@ final class SpecialisedFloatStream extends JavaStream<Float> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<Float> data() {
|
public Data<Float> readArrays() {
|
||||||
return new Data<>(boxed(a), boxed(b), boxed(c));
|
return new Data<>(boxed(a), boxed(b), boxed(c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -78,7 +78,7 @@ final class SpecialisedPlainDoubleStream extends JavaStream<Double> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<Double> data() {
|
public Data<Double> readArrays() {
|
||||||
return new Data<>(boxed(a), boxed(b), boxed(c));
|
return new Data<>(boxed(a), boxed(b), boxed(c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -78,7 +78,7 @@ final class SpecialisedPlainFloatStream extends JavaStream<Float> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<Float> data() {
|
public Data<Float> readArrays() {
|
||||||
return new Data<>(boxed(a), boxed(b), boxed(c));
|
return new Data<>(boxed(a), boxed(b), boxed(c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,8 +4,8 @@ import java.util.List;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import javastream.JavaStream;
|
import javastream.JavaStream;
|
||||||
import javastream.Main.Config;
|
import javastream.Main.Config;
|
||||||
import uk.ac.manchester.tornado.api.TaskSchedule;
|
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
|
||||||
import uk.ac.manchester.tornado.api.TornadoRuntimeCI;
|
import uk.ac.manchester.tornado.api.TornadoRuntimeInterface;
|
||||||
import uk.ac.manchester.tornado.api.common.TornadoDevice;
|
import uk.ac.manchester.tornado.api.common.TornadoDevice;
|
||||||
import uk.ac.manchester.tornado.api.runtime.TornadoRuntime;
|
import uk.ac.manchester.tornado.api.runtime.TornadoRuntime;
|
||||||
|
|
||||||
@ -13,18 +13,18 @@ abstract class GenericTornadoVMStream<T> extends JavaStream<T> {
|
|||||||
|
|
||||||
protected final TornadoDevice device;
|
protected final TornadoDevice device;
|
||||||
|
|
||||||
protected TaskSchedule copyTask;
|
protected TornadoExecutionPlan copyTask;
|
||||||
protected TaskSchedule mulTask;
|
protected TornadoExecutionPlan mulTask;
|
||||||
protected TaskSchedule addTask;
|
protected TornadoExecutionPlan addTask;
|
||||||
protected TaskSchedule triadTask;
|
protected TornadoExecutionPlan triadTask;
|
||||||
protected TaskSchedule nstreamTask;
|
protected TornadoExecutionPlan nstreamTask;
|
||||||
protected TaskSchedule dotTask;
|
protected TornadoExecutionPlan dotTask;
|
||||||
|
|
||||||
GenericTornadoVMStream(Config<T> config) {
|
GenericTornadoVMStream(Config<T> config) {
|
||||||
super(config);
|
super(config);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
TornadoRuntimeCI runtime = TornadoRuntime.getTornadoRuntime();
|
TornadoRuntimeInterface runtime = TornadoRuntime.getTornadoRuntime();
|
||||||
List<TornadoDevice> devices = TornadoVMStreams.enumerateDevices(runtime);
|
List<TornadoDevice> devices = TornadoVMStreams.enumerateDevices(runtime);
|
||||||
device = devices.get(config.options.device);
|
device = devices.get(config.options.device);
|
||||||
|
|
||||||
@ -42,10 +42,6 @@ abstract class GenericTornadoVMStream<T> extends JavaStream<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static TaskSchedule mkSchedule() {
|
|
||||||
return new TaskSchedule("");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> listDevices() {
|
public List<String> listDevices() {
|
||||||
return TornadoVMStreams.enumerateDevices(TornadoRuntime.getTornadoRuntime()).stream()
|
return TornadoVMStreams.enumerateDevices(TornadoRuntime.getTornadoRuntime()).stream()
|
||||||
@ -55,12 +51,12 @@ abstract class GenericTornadoVMStream<T> extends JavaStream<T> {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initArrays() {
|
public void initArrays() {
|
||||||
this.copyTask.warmup();
|
this.copyTask.withWarmUp();
|
||||||
this.mulTask.warmup();
|
this.mulTask.withWarmUp();
|
||||||
this.addTask.warmup();
|
this.addTask.withWarmUp();
|
||||||
this.triadTask.warmup();
|
this.triadTask.withWarmUp();
|
||||||
this.nstreamTask.warmup();
|
this.nstreamTask.withWarmUp();
|
||||||
this.dotTask.warmup();
|
this.dotTask.withWarmUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@ -2,8 +2,11 @@ package javastream.tornadovm;
|
|||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import javastream.Main.Config;
|
import javastream.Main.Config;
|
||||||
|
import uk.ac.manchester.tornado.api.TaskGraph;
|
||||||
|
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
|
||||||
import uk.ac.manchester.tornado.api.annotations.Parallel;
|
import uk.ac.manchester.tornado.api.annotations.Parallel;
|
||||||
import uk.ac.manchester.tornado.api.annotations.Reduce;
|
import uk.ac.manchester.tornado.api.annotations.Reduce;
|
||||||
|
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
|
||||||
|
|
||||||
final class SpecialisedDouble extends GenericTornadoVMStream<Double> {
|
final class SpecialisedDouble extends GenericTornadoVMStream<Double> {
|
||||||
|
|
||||||
@ -49,7 +52,7 @@ final class SpecialisedDouble extends GenericTornadoVMStream<Double> {
|
|||||||
private final double[] a, b, c;
|
private final double[] a, b, c;
|
||||||
private final double[] dotSum;
|
private final double[] dotSum;
|
||||||
|
|
||||||
@SuppressWarnings({"PrimitiveArrayArgumentToVarargsMethod", "DuplicatedCode"})
|
@SuppressWarnings({"DuplicatedCode"})
|
||||||
SpecialisedDouble(Config<Double> config) {
|
SpecialisedDouble(Config<Double> config) {
|
||||||
super(config);
|
super(config);
|
||||||
final int size = config.options.arraysize;
|
final int size = config.options.arraysize;
|
||||||
@ -58,12 +61,43 @@ final class SpecialisedDouble extends GenericTornadoVMStream<Double> {
|
|||||||
b = new double[size];
|
b = new double[size];
|
||||||
c = new double[size];
|
c = new double[size];
|
||||||
dotSum = new double[1];
|
dotSum = new double[1];
|
||||||
this.copyTask = mkSchedule().task("", SpecialisedDouble::copy, size, a, c);
|
this.copyTask =
|
||||||
this.mulTask = mkSchedule().task("", SpecialisedDouble::mul, size, b, c, scalar);
|
new TornadoExecutionPlan(
|
||||||
this.addTask = mkSchedule().task("", SpecialisedDouble::add, size, a, b, c);
|
new TaskGraph("copy")
|
||||||
this.triadTask = mkSchedule().task("", SpecialisedDouble::triad, size, a, b, c, scalar);
|
.task("copy", SpecialisedDouble::copy, size, a, c)
|
||||||
this.nstreamTask = mkSchedule().task("", SpecialisedDouble::nstream, size, a, b, c, scalar);
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, c)
|
||||||
this.dotTask = mkSchedule().task("", SpecialisedDouble::dot_, a, b, dotSum).streamOut(dotSum);
|
.snapshot());
|
||||||
|
this.mulTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("mul")
|
||||||
|
.task("mul", SpecialisedDouble::mul, size, b, c, scalar)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.addTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("add")
|
||||||
|
.task("add", SpecialisedDouble::add, size, a, b, c)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.triadTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("triad")
|
||||||
|
.task("triad", SpecialisedDouble::triad, size, a, b, c, scalar)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.nstreamTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("nstream")
|
||||||
|
.task("nstream", SpecialisedDouble::nstream, size, a, b, c, scalar)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.dotTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("dot")
|
||||||
|
.task("dot", SpecialisedDouble::dot_, a, b, dotSum)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b)
|
||||||
|
.transferToHost(DataTransferMode.EVERY_EXECUTION, new Object[] {dotSum})
|
||||||
|
.snapshot());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -72,7 +106,7 @@ final class SpecialisedDouble extends GenericTornadoVMStream<Double> {
|
|||||||
Arrays.fill(a, config.initA);
|
Arrays.fill(a, config.initA);
|
||||||
Arrays.fill(b, config.initB);
|
Arrays.fill(b, config.initB);
|
||||||
Arrays.fill(c, config.initC);
|
Arrays.fill(c, config.initC);
|
||||||
TornadoVMStreams.xferToDevice(device, a, b, c);
|
TornadoVMStreams.allocAndXferToDevice(device, a, b, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -81,7 +115,7 @@ final class SpecialisedDouble extends GenericTornadoVMStream<Double> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<Double> data() {
|
public Data<Double> readArrays() {
|
||||||
TornadoVMStreams.xferFromDevice(device, a, b, c);
|
TornadoVMStreams.xferFromDevice(device, a, b, c);
|
||||||
return new Data<>(boxed(a), boxed(b), boxed(c));
|
return new Data<>(boxed(a), boxed(b), boxed(c));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,8 +2,11 @@ package javastream.tornadovm;
|
|||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import javastream.Main.Config;
|
import javastream.Main.Config;
|
||||||
|
import uk.ac.manchester.tornado.api.TaskGraph;
|
||||||
|
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
|
||||||
import uk.ac.manchester.tornado.api.annotations.Parallel;
|
import uk.ac.manchester.tornado.api.annotations.Parallel;
|
||||||
import uk.ac.manchester.tornado.api.annotations.Reduce;
|
import uk.ac.manchester.tornado.api.annotations.Reduce;
|
||||||
|
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
|
||||||
|
|
||||||
final class SpecialisedFloat extends GenericTornadoVMStream<Float> {
|
final class SpecialisedFloat extends GenericTornadoVMStream<Float> {
|
||||||
|
|
||||||
@ -49,7 +52,7 @@ final class SpecialisedFloat extends GenericTornadoVMStream<Float> {
|
|||||||
private final float[] a, b, c;
|
private final float[] a, b, c;
|
||||||
private final float[] dotSum;
|
private final float[] dotSum;
|
||||||
|
|
||||||
@SuppressWarnings({"PrimitiveArrayArgumentToVarargsMethod", "DuplicatedCode"})
|
@SuppressWarnings({"DuplicatedCode"})
|
||||||
SpecialisedFloat(Config<Float> config) {
|
SpecialisedFloat(Config<Float> config) {
|
||||||
super(config);
|
super(config);
|
||||||
final int size = config.options.arraysize;
|
final int size = config.options.arraysize;
|
||||||
@ -58,12 +61,43 @@ final class SpecialisedFloat extends GenericTornadoVMStream<Float> {
|
|||||||
b = new float[size];
|
b = new float[size];
|
||||||
c = new float[size];
|
c = new float[size];
|
||||||
dotSum = new float[1];
|
dotSum = new float[1];
|
||||||
this.copyTask = mkSchedule().task("", SpecialisedFloat::copy, size, a, c);
|
this.copyTask =
|
||||||
this.mulTask = mkSchedule().task("", SpecialisedFloat::mul, size, b, c, scalar);
|
new TornadoExecutionPlan(
|
||||||
this.addTask = mkSchedule().task("", SpecialisedFloat::add, size, a, b, c);
|
new TaskGraph("copy")
|
||||||
this.triadTask = mkSchedule().task("", SpecialisedFloat::triad, size, a, b, c, scalar);
|
.task("copy", SpecialisedFloat::copy, size, a, c)
|
||||||
this.nstreamTask = mkSchedule().task("", SpecialisedFloat::nstream, size, a, b, c, scalar);
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, c)
|
||||||
this.dotTask = mkSchedule().task("", SpecialisedFloat::dot_, a, b, dotSum).streamOut(dotSum);
|
.snapshot());
|
||||||
|
this.mulTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("mul")
|
||||||
|
.task("mul", SpecialisedFloat::mul, size, b, c, scalar)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.addTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("add")
|
||||||
|
.task("add", SpecialisedFloat::add, size, a, b, c)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.triadTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("triad")
|
||||||
|
.task("triad", SpecialisedFloat::triad, size, a, b, c, scalar)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.nstreamTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("nstream")
|
||||||
|
.task("nstream", SpecialisedFloat::nstream, size, a, b, c, scalar)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b, c)
|
||||||
|
.snapshot());
|
||||||
|
this.dotTask =
|
||||||
|
new TornadoExecutionPlan(
|
||||||
|
new TaskGraph("dot")
|
||||||
|
.task("dot", SpecialisedFloat::dot_, a, b, dotSum)
|
||||||
|
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b)
|
||||||
|
.transferToHost(DataTransferMode.EVERY_EXECUTION, new Object[] {dotSum})
|
||||||
|
.snapshot());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -72,7 +106,7 @@ final class SpecialisedFloat extends GenericTornadoVMStream<Float> {
|
|||||||
Arrays.fill(a, config.initA);
|
Arrays.fill(a, config.initA);
|
||||||
Arrays.fill(b, config.initB);
|
Arrays.fill(b, config.initB);
|
||||||
Arrays.fill(c, config.initC);
|
Arrays.fill(c, config.initC);
|
||||||
TornadoVMStreams.xferToDevice(device, a, b, c);
|
TornadoVMStreams.allocAndXferToDevice(device, a, b, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -81,7 +115,7 @@ final class SpecialisedFloat extends GenericTornadoVMStream<Float> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Data<Float> data() {
|
public Data<Float> readArrays() {
|
||||||
TornadoVMStreams.xferFromDevice(device, a, b, c);
|
TornadoVMStreams.xferFromDevice(device, a, b, c);
|
||||||
return new Data<>(boxed(a), boxed(b), boxed(c));
|
return new Data<>(boxed(a), boxed(b), boxed(c));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,36 +1,46 @@
|
|||||||
package javastream.tornadovm;
|
package javastream.tornadovm;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
import javastream.JavaStream;
|
import javastream.JavaStream;
|
||||||
import javastream.Main.Config;
|
import javastream.Main.Config;
|
||||||
import uk.ac.manchester.tornado.api.TornadoRuntimeCI;
|
import uk.ac.manchester.tornado.api.TornadoRuntimeInterface;
|
||||||
|
import uk.ac.manchester.tornado.api.common.Event;
|
||||||
import uk.ac.manchester.tornado.api.common.TornadoDevice;
|
import uk.ac.manchester.tornado.api.common.TornadoDevice;
|
||||||
import uk.ac.manchester.tornado.api.mm.TornadoGlobalObjectState;
|
import uk.ac.manchester.tornado.api.memory.TornadoDeviceObjectState;
|
||||||
|
import uk.ac.manchester.tornado.api.memory.TornadoGlobalObjectState;
|
||||||
import uk.ac.manchester.tornado.api.runtime.TornadoRuntime;
|
import uk.ac.manchester.tornado.api.runtime.TornadoRuntime;
|
||||||
|
|
||||||
public final class TornadoVMStreams {
|
public final class TornadoVMStreams {
|
||||||
|
|
||||||
private TornadoVMStreams() {}
|
private TornadoVMStreams() {}
|
||||||
|
|
||||||
static void xferToDevice(TornadoDevice device, Object... xs) {
|
static void allocAndXferToDevice(TornadoDevice device, Object... xs) {
|
||||||
for (Object x : xs) {
|
for (Object x : xs) {
|
||||||
TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x);
|
TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x);
|
||||||
|
device.allocateObjects(
|
||||||
|
new Object[] {x}, 0, new TornadoDeviceObjectState[] {state.getDeviceState(device)});
|
||||||
List<Integer> writeEvent = device.ensurePresent(x, state.getDeviceState(device), null, 0, 0);
|
List<Integer> writeEvent = device.ensurePresent(x, state.getDeviceState(device), null, 0, 0);
|
||||||
if (writeEvent != null) writeEvent.forEach(e -> device.resolveEvent(e).waitOn());
|
if (writeEvent != null) writeEvent.forEach(e -> device.resolveEvent(e).waitOn());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void xferFromDevice(TornadoDevice device, Object... xs) {
|
static void xferFromDevice(TornadoDevice device, Object... xs) {
|
||||||
for (Object x : xs) {
|
Arrays.stream(xs)
|
||||||
TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x);
|
.map(
|
||||||
device.resolveEvent(device.streamOut(x, 0, state.getDeviceState(device), null)).waitOn();
|
x -> {
|
||||||
}
|
TornadoGlobalObjectState state = TornadoRuntime.getTornadoRuntime().resolveObject(x);
|
||||||
|
return device.resolveEvent(
|
||||||
|
device.streamOut(x, 0, state.getDeviceState(device), null));
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
.forEach(Event::waitOn);
|
||||||
}
|
}
|
||||||
|
|
||||||
static List<TornadoDevice> enumerateDevices(TornadoRuntimeCI runtime) {
|
static List<TornadoDevice> enumerateDevices(TornadoRuntimeInterface runtime) {
|
||||||
return IntStream.range(0, runtime.getNumDrivers())
|
return IntStream.range(0, runtime.getNumDrivers())
|
||||||
.mapToObj(runtime::getDriver)
|
.mapToObj(runtime::getDriver)
|
||||||
.flatMap(d -> IntStream.range(0, d.getDeviceCount()).mapToObj(d::getDevice))
|
.flatMap(d -> IntStream.range(0, d.getDeviceCount()).mapToObj(d::getDevice))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user