Use device info to select CUDA device
This commit is contained in:
parent
3462e61c16
commit
fd121c2467
@ -12,8 +12,22 @@ void check_error(void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
CUDAStream<T>::CUDAStream(const unsigned int ARRAY_SIZE)
|
CUDAStream<T>::CUDAStream(const unsigned int ARRAY_SIZE, const int device_index)
|
||||||
{
|
{
|
||||||
|
|
||||||
|
// Set device
|
||||||
|
int count;
|
||||||
|
cudaGetDeviceCount(&count);
|
||||||
|
check_error();
|
||||||
|
if (device_index >= count)
|
||||||
|
throw std::runtime_error("Invalid device index");
|
||||||
|
cudaSetDevice(device_index);
|
||||||
|
check_error();
|
||||||
|
|
||||||
|
// Print out device information
|
||||||
|
std::cout << "Using OpenCL device " << getDeviceName(device_index) << std::endl;
|
||||||
|
std::cout << "Driver: " << getDeviceDriver(device_index) << std::endl;
|
||||||
|
|
||||||
array_size = ARRAY_SIZE;
|
array_size = ARRAY_SIZE;
|
||||||
|
|
||||||
// Check buffers fit on the device
|
// Check buffers fit on the device
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class CUDAStream : public Stream<T>
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
CUDAStream(const unsigned int);
|
CUDAStream(const unsigned int, const int);
|
||||||
~CUDAStream();
|
~CUDAStream();
|
||||||
|
|
||||||
virtual void copy() override;
|
virtual void copy() override;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user