diff --git a/src/CUDAStream.cu b/src/CUDAStream.cu index e7ce539..956be7d 100644 --- a/src/CUDAStream.cu +++ b/src/CUDAStream.cu @@ -12,8 +12,22 @@ void check_error(void) } template -CUDAStream::CUDAStream(const unsigned int ARRAY_SIZE) +CUDAStream::CUDAStream(const unsigned int ARRAY_SIZE, const int device_index) { + + // Set device + int count; + cudaGetDeviceCount(&count); + check_error(); + if (device_index >= count) + throw std::runtime_error("Invalid device index"); + cudaSetDevice(device_index); + check_error(); + + // Print out device information + std::cout << "Using OpenCL device " << getDeviceName(device_index) << std::endl; + std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl; + array_size = ARRAY_SIZE; // Check buffers fit on the device diff --git a/src/CUDAStream.h b/src/CUDAStream.h index 34e0303..61e4882 100644 --- a/src/CUDAStream.h +++ b/src/CUDAStream.h @@ -20,7 +20,7 @@ class CUDAStream : public Stream public: - CUDAStream(const unsigned int); + CUDAStream(const unsigned int, const int); ~CUDAStream(); virtual void copy() override;