[SYCL] Actually use device_index to select device

This commit is contained in:
James Price 2016-05-08 21:35:24 +01:00
parent 3b3f6dfc26
commit 6d913bab4b
2 changed files with 20 additions and 9 deletions

View File

@ -21,12 +21,21 @@ void getDeviceList(void);
template <class T> template <class T>
SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index) SYCLStream<T>::SYCLStream(const unsigned int ARRAY_SIZE, const int device_index)
{ {
if (!cached)
getDeviceList();
array_size = ARRAY_SIZE; array_size = ARRAY_SIZE;
if (device_index >= devices.size())
throw std::runtime_error("Invalid device index");
device dev = devices[device_index];
// Print out device information // Print out device information
std::cout << "Using SYCL device " << getDeviceName(device_index) << std::endl; std::cout << "Using SYCL device " << getDeviceName(device_index) << std::endl;
std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl; std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl;
queue = new cl::sycl::queue(dev);
// Create buffers // Create buffers
d_a = new buffer<T>(array_size); d_a = new buffer<T>(array_size);
d_b = new buffer<T>(array_size); d_b = new buffer<T>(array_size);
@ -39,12 +48,14 @@ SYCLStream<T>::~SYCLStream()
delete d_a; delete d_a;
delete d_b; delete d_b;
delete d_c; delete d_c;
delete queue;
} }
template <class T> template <class T>
void SYCLStream<T>::copy() void SYCLStream<T>::copy()
{ {
queue.submit([&](handler &cgh) queue->submit([&](handler &cgh)
{ {
auto ka = d_a->template get_access<access::mode::read>(cgh); auto ka = d_a->template get_access<access::mode::read>(cgh);
auto kc = d_c->template get_access<access::mode::write>(cgh); auto kc = d_c->template get_access<access::mode::write>(cgh);
@ -53,14 +64,14 @@ void SYCLStream<T>::copy()
kc[item.get_global()] = ka[item.get_global()]; kc[item.get_global()] = ka[item.get_global()];
}); });
}); });
queue.wait(); queue->wait();
} }
template <class T> template <class T>
void SYCLStream<T>::mul() void SYCLStream<T>::mul()
{ {
const T scalar = 3.0; const T scalar = 3.0;
queue.submit([&](handler &cgh) queue->submit([&](handler &cgh)
{ {
auto kb = d_b->template get_access<access::mode::write>(cgh); auto kb = d_b->template get_access<access::mode::write>(cgh);
auto kc = d_c->template get_access<access::mode::read>(cgh); auto kc = d_c->template get_access<access::mode::read>(cgh);
@ -69,13 +80,13 @@ void SYCLStream<T>::mul()
kb[item.get_global()] = scalar * kc[item.get_global()]; kb[item.get_global()] = scalar * kc[item.get_global()];
}); });
}); });
queue.wait(); queue->wait();
} }
template <class T> template <class T>
void SYCLStream<T>::add() void SYCLStream<T>::add()
{ {
queue.submit([&](handler &cgh) queue->submit([&](handler &cgh)
{ {
auto ka = d_a->template get_access<access::mode::read>(cgh); auto ka = d_a->template get_access<access::mode::read>(cgh);
auto kb = d_b->template get_access<access::mode::read>(cgh); auto kb = d_b->template get_access<access::mode::read>(cgh);
@ -85,14 +96,14 @@ void SYCLStream<T>::add()
kc[item.get_global()] = ka[item.get_global()] + kb[item.get_global()]; kc[item.get_global()] = ka[item.get_global()] + kb[item.get_global()];
}); });
}); });
queue.wait(); queue->wait();
} }
template <class T> template <class T>
void SYCLStream<T>::triad() void SYCLStream<T>::triad()
{ {
const T scalar = 3.0; const T scalar = 3.0;
queue.submit([&](handler &cgh) queue->submit([&](handler &cgh)
{ {
auto ka = d_a->template get_access<access::mode::write>(cgh); auto ka = d_a->template get_access<access::mode::write>(cgh);
auto kb = d_b->template get_access<access::mode::read>(cgh); auto kb = d_b->template get_access<access::mode::read>(cgh);
@ -102,7 +113,7 @@ void SYCLStream<T>::triad()
ka[item.get_global()] = kb[item.get_global()] + scalar * kc[item.get_global()]; ka[item.get_global()] = kb[item.get_global()] + scalar * kc[item.get_global()];
}); });
}); });
queue.wait(); queue->wait();
} }
template <class T> template <class T>

View File

@ -21,7 +21,7 @@ class SYCLStream : public Stream<T>
unsigned int array_size; unsigned int array_size;
// SYCL objects // SYCL objects
cl::sycl::queue queue; cl::sycl::queue *queue;
cl::sycl::buffer<T> *d_a; cl::sycl::buffer<T> *d_a;
cl::sycl::buffer<T> *d_b; cl::sycl::buffer<T> *d_b;
cl::sycl::buffer<T> *d_c; cl::sycl::buffer<T> *d_c;