Commit df7df5b0b1d125f0882bf012cab9c7ce619da096
Merge branch 'JACK_netmets' into 'master'
add device choice and cpu implementation See merge request !1
Showing
1 changed file
with
38 additions
and
13 deletions
Show diff stats
main.cu
@@ -14,6 +14,11 @@ | @@ -14,6 +14,11 @@ | ||
14 | #include <stim/parser/arguments.h> | 14 | #include <stim/parser/arguments.h> |
15 | #include <stim/visualization/camera.h> | 15 | #include <stim/visualization/camera.h> |
16 | 16 | ||
17 | +#ifdef __CUDACC__ | ||
18 | +//CUDA includes | ||
19 | +#include <cuda.h> | ||
20 | +#endif | ||
21 | + | ||
17 | //ANN includes | 22 | //ANN includes |
18 | //#include <ANN/ANN.h> | 23 | //#include <ANN/ANN.h> |
19 | 24 | ||
@@ -29,9 +34,9 @@ stim::gl_network<float> GT; //ground truth network | @@ -29,9 +34,9 @@ stim::gl_network<float> GT; //ground truth network | ||
29 | stim::gl_network<float> T; //test network | 34 | stim::gl_network<float> T; //test network |
30 | 35 | ||
31 | //hard-coded parameters | 36 | //hard-coded parameters |
32 | -float resample_rate = 0.5; //sample rate for the network (fraction of sigma used as the maximum sample rate) | ||
33 | -float camera_factor = 1.2; //start point of the camera as a function of X and Y size | ||
34 | -float orbit_factor = 0.01; //degrees per pixel used to orbit the camera | 37 | +float resample_rate = 0.5f; //sample rate for the network (fraction of sigma used as the maximum sample rate) |
38 | +float camera_factor = 1.2f; //start point of the camera as a function of X and Y size | ||
39 | +float orbit_factor = 0.01f; //degrees per pixel used to orbit the camera | ||
35 | 40 | ||
36 | //mouse position tracking | 41 | //mouse position tracking |
37 | int mouse_x; | 42 | int mouse_x; |
@@ -195,11 +200,26 @@ void glut_initialize(){ | @@ -195,11 +200,26 @@ void glut_initialize(){ | ||
195 | cam.LookAt(c[0], c[1], c[2]); //look at the center of the network | 200 | cam.LookAt(c[0], c[1], c[2]); //look at the center of the network |
196 | } | 201 | } |
197 | 202 | ||
203 | +#ifdef __CUDACC__ | ||
204 | +void setdevice(int &device){ | ||
205 | + int count; | ||
206 | + cudaGetDeviceCount(&count); // numbers of device that are available | ||
207 | + if(count < device + 1){ | ||
208 | + std::cout<<"No such device available, please set another device"<<std::endl; | ||
209 | + exit(1); | ||
210 | + } | ||
211 | +} | ||
212 | +#else | ||
213 | +void setdevice(int &device){ | ||
214 | + device = -1; | ||
215 | +} | ||
216 | +#endif | ||
217 | + | ||
198 | //compare both networks and fill the networks with error information | 218 | //compare both networks and fill the networks with error information |
199 | -void compare(float sigma){ | 219 | +void compare(float sigma, int device){ |
200 | 220 | ||
201 | - GT = GT.compare(T, sigma); //compare the ground truth to the test case - store errors in GT | ||
202 | - T = T.compare(GT, sigma); //compare the test case to the ground truth - store errors in T | 221 | + GT = GT.compare(T, sigma, device); //compare the ground truth to the test case - store errors in GT |
222 | + T = T.compare(GT, sigma, device); //compare the test case to the ground truth - store errors in T | ||
203 | 223 | ||
204 | //calculate the metrics | 224 | //calculate the metrics |
205 | float FPR = GT.average(1); //calculate the metrics | 225 | float FPR = GT.average(1); //calculate the metrics |
@@ -212,7 +232,7 @@ void compare(float sigma){ | @@ -212,7 +232,7 @@ void compare(float sigma){ | ||
212 | // writes features of the networks i.e average segment length, tortuosity, branching index, contraction, fractal dimension, number of end and branch points to a csv file | 232 | // writes features of the networks i.e average segment length, tortuosity, branching index, contraction, fractal dimension, number of end and branch points to a csv file |
213 | // Pranathi wrote this - saves network features to a CSV file | 233 | // Pranathi wrote this - saves network features to a CSV file |
214 | void features(std::string filename){ | 234 | void features(std::string filename){ |
215 | - double avgL_t, avgL_gt, avgT_t, avgT_gt, avgB_t, avgB_gt, avgC_t, avgC_gt, avgFD_t, avgFD_gt; | 235 | + double avgL_t, avgL_gt, avgT_t, avgT_gt, avgB_t, avgB_gt, avgC_t, avgC_gt, avgFD_t, avgFD_gt; |
216 | unsigned int e_t, e_gt, b_gt, b_t; | 236 | unsigned int e_t, e_gt, b_gt, b_t; |
217 | avgL_gt = GT.Lengths(); | 237 | avgL_gt = GT.Lengths(); |
218 | avgT_gt = GT.Tortuosities(); | 238 | avgT_gt = GT.Tortuosities(); |
@@ -246,10 +266,12 @@ void advertise(){ | @@ -246,10 +266,12 @@ void advertise(){ | ||
246 | std::cout<<"Source: https://git.stim.ee.uh.edu/segmentation/netmets"<<std::endl; | 266 | std::cout<<"Source: https://git.stim.ee.uh.edu/segmentation/netmets"<<std::endl; |
247 | std::cout<<"========================================================================="<<std::endl<<std::endl; | 267 | std::cout<<"========================================================================="<<std::endl<<std::endl; |
248 | 268 | ||
249 | - std::cout<<"usage: netmets file1 file2 --sigma 10"<<std::endl; | ||
250 | - std::cout<<" compare two files with a tolerance of 10 (units defined by the network)"<<std::endl; | ||
251 | - std::cout<<" netmets file1 --gui"<<std::endl<<std::endl; | ||
252 | - std::cout<<" load a file and display it using OpenGL"<<std::endl; | 269 | + std::cout<<"usage: netmets file1 file2 --sigma 3"<<std::endl; |
270 | + std::cout<<" compare two files with a tolerance of 3 (units defined by the network)"<<std::endl<<std::endl; | ||
271 | + std::cout<<" netmets file1 --gui"<<std::endl; | ||
272 | + std::cout<<" load a file and display it using OpenGL"<<std::endl<<std::endl; | ||
273 | + std::cout<<" netmets file1 file2 --device 0"<<std::endl; | ||
274 | + std::cout<<" compare two files using device 0 (if there isn't a gpu, use cpu)"<<std::endl<<std::endl; | ||
253 | } | 275 | } |
254 | 276 | ||
255 | int main(int argc, char* argv[]) | 277 | int main(int argc, char* argv[]) |
@@ -258,8 +280,9 @@ int main(int argc, char* argv[]) | @@ -258,8 +280,9 @@ int main(int argc, char* argv[]) | ||
258 | 280 | ||
259 | //add arguments | 281 | //add arguments |
260 | args.add("help", "prints this help"); | 282 | args.add("help", "prints this help"); |
261 | - args.add("sigma", "force a sigma value to specify the tolerance of the network comparison", "10"); | 283 | + args.add("sigma", "force a sigma value to specify the tolerance of the network comparison", "3"); |
262 | args.add("gui", "display the network or network comparison using OpenGL"); | 284 | args.add("gui", "display the network or network comparison using OpenGL"); |
285 | + args.add("device", "choose specific device to run", "0"); | ||
263 | args.add("features", "save features to a CSV file, specify file name"); | 286 | args.add("features", "save features to a CSV file, specify file name"); |
264 | 287 | ||
265 | args.parse(argc, argv); //parse the user arguments | 288 | args.parse(argc, argv); //parse the user arguments |
@@ -277,6 +300,7 @@ int main(int argc, char* argv[]) | @@ -277,6 +300,7 @@ int main(int argc, char* argv[]) | ||
277 | } | 300 | } |
278 | 301 | ||
279 | if(args.nargs() == 2){ //if two files are specified, they will be displayed in neighboring viewports and compared | 302 | if(args.nargs() == 2){ //if two files are specified, they will be displayed in neighboring viewports and compared |
303 | + int device = args["device"].as_int(); //get the device value from the user | ||
280 | num_nets = 2; //set the number of networks to two | 304 | num_nets = 2; //set the number of networks to two |
281 | float sigma = args["sigma"].as_float(); //get the sigma value from the user | 305 | float sigma = args["sigma"].as_float(); //get the sigma value from the user |
282 | T.load_obj(args.arg(1)); //load the second (test) network | 306 | T.load_obj(args.arg(1)); //load the second (test) network |
@@ -284,7 +308,8 @@ int main(int argc, char* argv[]) | @@ -284,7 +308,8 @@ int main(int argc, char* argv[]) | ||
284 | features(args["features"].as_string()); | 308 | features(args["features"].as_string()); |
285 | GT = GT.resample(resample_rate * sigma); //resample both networks based on the sigma value | 309 | GT = GT.resample(resample_rate * sigma); //resample both networks based on the sigma value |
286 | T = T.resample(resample_rate * sigma); | 310 | T = T.resample(resample_rate * sigma); |
287 | - compare(sigma); //run the comparison algorithm | 311 | + setdevice(device); |
312 | + compare(sigma, device); //run the comparison algorithm | ||
288 | } | 313 | } |
289 | 314 | ||
290 | //if a GUI is requested, display the network using OpenGL | 315 | //if a GUI is requested, display the network using OpenGL |