diff --git a/cuda-stream.cu b/cuda-stream.cu index c6ea9c6..70a0402 100644 --- a/cuda-stream.cu +++ b/cuda-stream.cu @@ -29,6 +29,13 @@ struct badtype : public std::exception } }; +struct invaliddevice : public std::exception +{ + virtual const char * what () const throw () + { + return "Chosen device index is invalid"; + } +}; size_t sizes[4] = { 2 * sizeof(DATATYPE) * ARRAY_SIZE, @@ -128,6 +135,12 @@ int main(int argc, char *argv[]) { parseArguments(argc, argv); + // Check device index is in range + int count; + cudaGetDeviceCount(&count); + if (deviceIndex >= count) throw invaliddevice(); + cudaSetDevice(deviceIndex); + // Print out device name std::cout << "Using CUDA device " << getDeviceName() << std::endl; @@ -254,21 +267,6 @@ int main(int argc, char *argv[]) } } -unsigned getDeviceList() -{ - - // // Enumerate devices - // for (unsigned int i = 0; i < platforms.size(); i++) - // { - // std::vector plat_devices; - // platforms[i].getDevices(CL_DEVICE_TYPE_ALL, &plat_devices); - // devices.insert(devices.end(), plat_devices.begin(), plat_devices.end()); - // } - - // return devices.size(); - return 0; -} - std::string getDeviceName() { int device; @@ -299,12 +297,12 @@ void parseArguments(int argc, char *argv[]) { if (!strcmp(argv[i], "--list")) { - // Get list of devices - /*std::vector devices; - getDeviceList(devices); + // Get number of devices + int count; + cudaGetDeviceCount(&count); // Print device names - if (devices.size() == 0) + if (count == 0) { std::cout << "No devices found." << std::endl; } @@ -312,12 +310,12 @@ void parseArguments(int argc, char *argv[]) { std::cout << std::endl; std::cout << "Devices:" << std::endl; - for (unsigned i = 0; i < devices.size(); i++) + for (int i = 0; i < count; i++) { - std::cout << i << ": " << getDeviceName(devices[i]) << std::endl; + std::cout << i << ": " << getDeviceName() << std::endl; } std::cout << std::endl; - }*/ + } exit(0); } else if (!strcmp(argv[i], "--device"))