Commit df7df5b0b1d125f0882bf012cab9c7ce619da096

Authored by David Mayerich
2 parents db598823 f86a38d3

Merge branch 'JACK_netmets' into 'master'

add device choice and cpu implementation

See merge request !1
Showing 1 changed file with 38 additions and 13 deletions   Show diff stats
... ... @@ -14,6 +14,11 @@
14 14 #include <stim/parser/arguments.h>
15 15 #include <stim/visualization/camera.h>
16 16  
  17 +#ifdef __CUDACC__
  18 +//CUDA includes
  19 +#include <cuda.h>
  20 +#endif
  21 +
17 22 //ANN includes
18 23 //#include <ANN/ANN.h>
19 24  
... ... @@ -29,9 +34,9 @@ stim::gl_network&lt;float&gt; GT; //ground truth network
29 34 stim::gl_network<float> T; //test network
30 35  
31 36 //hard-coded parameters
32   -float resample_rate = 0.5; //sample rate for the network (fraction of sigma used as the maximum sample rate)
33   -float camera_factor = 1.2; //start point of the camera as a function of X and Y size
34   -float orbit_factor = 0.01; //degrees per pixel used to orbit the camera
  37 +float resample_rate = 0.5f; //sample rate for the network (fraction of sigma used as the maximum sample rate)
  38 +float camera_factor = 1.2f; //start point of the camera as a function of X and Y size
  39 +float orbit_factor = 0.01f; //degrees per pixel used to orbit the camera
35 40  
36 41 //mouse position tracking
37 42 int mouse_x;
... ... @@ -195,11 +200,26 @@ void glut_initialize(){
195 200 cam.LookAt(c[0], c[1], c[2]); //look at the center of the network
196 201 }
197 202  
  203 +#ifdef __CUDACC__
  204 +void setdevice(int &device){
  205 + int count;
  206 + cudaGetDeviceCount(&count); // numbers of device that are available
  207 + if(count < device + 1){
  208 + std::cout<<"No such device available, please set another device"<<std::endl;
  209 + exit(1);
  210 + }
  211 +}
  212 +#else
  213 +void setdevice(int &device){
  214 + device = -1;
  215 +}
  216 +#endif
  217 +
198 218 //compare both networks and fill the networks with error information
199   -void compare(float sigma){
  219 +void compare(float sigma, int device){
200 220  
201   - GT = GT.compare(T, sigma); //compare the ground truth to the test case - store errors in GT
202   - T = T.compare(GT, sigma); //compare the test case to the ground truth - store errors in T
  221 + GT = GT.compare(T, sigma, device); //compare the ground truth to the test case - store errors in GT
  222 + T = T.compare(GT, sigma, device); //compare the test case to the ground truth - store errors in T
203 223  
204 224 //calculate the metrics
205 225 float FPR = GT.average(1); //calculate the metrics
... ... @@ -212,7 +232,7 @@ void compare(float sigma){
212 232 // writes features of the networks i.e average segment length, tortuosity, branching index, contraction, fractal dimension, number of end and branch points to a csv file
213 233 // Pranathi wrote this - saves network features to a CSV file
214 234 void features(std::string filename){
215   - double avgL_t, avgL_gt, avgT_t, avgT_gt, avgB_t, avgB_gt, avgC_t, avgC_gt, avgFD_t, avgFD_gt;
  235 + double avgL_t, avgL_gt, avgT_t, avgT_gt, avgB_t, avgB_gt, avgC_t, avgC_gt, avgFD_t, avgFD_gt;
216 236 unsigned int e_t, e_gt, b_gt, b_t;
217 237 avgL_gt = GT.Lengths();
218 238 avgT_gt = GT.Tortuosities();
... ... @@ -246,10 +266,12 @@ void advertise(){
246 266 std::cout<<"Source: https://git.stim.ee.uh.edu/segmentation/netmets"<<std::endl;
247 267 std::cout<<"========================================================================="<<std::endl<<std::endl;
248 268  
249   - std::cout<<"usage: netmets file1 file2 --sigma 10"<<std::endl;
250   - std::cout<<" compare two files with a tolerance of 10 (units defined by the network)"<<std::endl;
251   - std::cout<<" netmets file1 --gui"<<std::endl<<std::endl;
252   - std::cout<<" load a file and display it using OpenGL"<<std::endl;
  269 + std::cout<<"usage: netmets file1 file2 --sigma 3"<<std::endl;
  270 + std::cout<<" compare two files with a tolerance of 3 (units defined by the network)"<<std::endl<<std::endl;
  271 + std::cout<<" netmets file1 --gui"<<std::endl;
  272 + std::cout<<" load a file and display it using OpenGL"<<std::endl<<std::endl;
  273 + std::cout<<" netmets file1 file2 --device 0"<<std::endl;
  274 + std::cout<<" compare two files using device 0 (if there isn't a gpu, use cpu)"<<std::endl<<std::endl;
253 275 }
254 276  
255 277 int main(int argc, char* argv[])
... ... @@ -258,8 +280,9 @@ int main(int argc, char* argv[])
258 280  
259 281 //add arguments
260 282 args.add("help", "prints this help");
261   - args.add("sigma", "force a sigma value to specify the tolerance of the network comparison", "10");
  283 + args.add("sigma", "force a sigma value to specify the tolerance of the network comparison", "3");
262 284 args.add("gui", "display the network or network comparison using OpenGL");
  285 + args.add("device", "choose specific device to run", "0");
263 286 args.add("features", "save features to a CSV file, specify file name");
264 287  
265 288 args.parse(argc, argv); //parse the user arguments
... ... @@ -277,6 +300,7 @@ int main(int argc, char* argv[])
277 300 }
278 301  
279 302 if(args.nargs() == 2){ //if two files are specified, they will be displayed in neighboring viewports and compared
  303 + int device = args["device"].as_int(); //get the device value from the user
280 304 num_nets = 2; //set the number of networks to two
281 305 float sigma = args["sigma"].as_float(); //get the sigma value from the user
282 306 T.load_obj(args.arg(1)); //load the second (test) network
... ... @@ -284,7 +308,8 @@ int main(int argc, char* argv[])
284 308 features(args["features"].as_string());
285 309 GT = GT.resample(resample_rate * sigma); //resample both networks based on the sigma value
286 310 T = T.resample(resample_rate * sigma);
287   - compare(sigma); //run the comparison algorithm
  311 + setdevice(device);
  312 + compare(sigma, device); //run the comparison algorithm
288 313 }
289 314  
290 315 //if a GUI is requested, display the network using OpenGL
... ...