From fd121c2467ef0c3554f3b8e292d2b06559905b00 Mon Sep 17 00:00:00 2001 From: Tom Deakin Date: Tue, 3 May 2016 11:15:38 +0100 Subject: [PATCH] Use device info to select CUDA device --- src/CUDAStream.cu | 16 +++++++++++++++- src/CUDAStream.h | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/CUDAStream.cu b/src/CUDAStream.cu index e7ce539..956be7d 100644 --- a/src/CUDAStream.cu +++ b/src/CUDAStream.cu @@ -12,8 +12,22 @@ void check_error(void) } template -CUDAStream::CUDAStream(const unsigned int ARRAY_SIZE) +CUDAStream::CUDAStream(const unsigned int ARRAY_SIZE, const int device_index) { + + // Set device + int count; + cudaGetDeviceCount(&count); + check_error(); + if (device_index >= count) + throw std::runtime_error("Invalid device index"); + cudaSetDevice(device_index); + check_error(); + + // Print out device information + std::cout << "Using OpenCL device " << getDeviceName(device_index) << std::endl; + std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl; + array_size = ARRAY_SIZE; // Check buffers fit on the device diff --git a/src/CUDAStream.h b/src/CUDAStream.h index 34e0303..61e4882 100644 --- a/src/CUDAStream.h +++ b/src/CUDAStream.h @@ -20,7 +20,7 @@ class CUDAStream : public Stream public: - CUDAStream(const unsigned int); + CUDAStream(const unsigned int, const int); ~CUDAStream(); virtual void copy() override;