Commit df7df5b0b1d125f0882bf012cab9c7ce619da096
Merge branch 'JACK_netmets' into 'master'
add device choice and cpu implementation See merge request !1
Showing
1 changed file
with
38 additions
and
13 deletions
Show diff stats
main.cu
... | ... | @@ -14,6 +14,11 @@ |
14 | 14 | #include <stim/parser/arguments.h> |
15 | 15 | #include <stim/visualization/camera.h> |
16 | 16 | |
17 | +#ifdef __CUDACC__ | |
18 | +//CUDA includes | |
19 | +#include <cuda.h> | |
20 | +#endif | |
21 | + | |
17 | 22 | //ANN includes |
18 | 23 | //#include <ANN/ANN.h> |
19 | 24 | |
... | ... | @@ -29,9 +34,9 @@ stim::gl_network<float> GT; //ground truth network |
29 | 34 | stim::gl_network<float> T; //test network |
30 | 35 | |
31 | 36 | //hard-coded parameters |
32 | -float resample_rate = 0.5; //sample rate for the network (fraction of sigma used as the maximum sample rate) | |
33 | -float camera_factor = 1.2; //start point of the camera as a function of X and Y size | |
34 | -float orbit_factor = 0.01; //degrees per pixel used to orbit the camera | |
37 | +float resample_rate = 0.5f; //sample rate for the network (fraction of sigma used as the maximum sample rate) | |
38 | +float camera_factor = 1.2f; //start point of the camera as a function of X and Y size | |
39 | +float orbit_factor = 0.01f; //degrees per pixel used to orbit the camera | |
35 | 40 | |
36 | 41 | //mouse position tracking |
37 | 42 | int mouse_x; |
... | ... | @@ -195,11 +200,26 @@ void glut_initialize(){ |
195 | 200 | cam.LookAt(c[0], c[1], c[2]); //look at the center of the network |
196 | 201 | } |
197 | 202 | |
203 | +#ifdef __CUDACC__ | |
204 | +void setdevice(int &device){ | |
205 | + int count; | |
206 | + cudaGetDeviceCount(&count); // numbers of device that are available | |
207 | + if(count < device + 1){ | |
208 | + std::cout<<"No such device available, please set another device"<<std::endl; | |
209 | + exit(1); | |
210 | + } | |
211 | +} | |
212 | +#else | |
213 | +void setdevice(int &device){ | |
214 | + device = -1; | |
215 | +} | |
216 | +#endif | |
217 | + | |
198 | 218 | //compare both networks and fill the networks with error information |
199 | -void compare(float sigma){ | |
219 | +void compare(float sigma, int device){ | |
200 | 220 | |
201 | - GT = GT.compare(T, sigma); //compare the ground truth to the test case - store errors in GT | |
202 | - T = T.compare(GT, sigma); //compare the test case to the ground truth - store errors in T | |
221 | + GT = GT.compare(T, sigma, device); //compare the ground truth to the test case - store errors in GT | |
222 | + T = T.compare(GT, sigma, device); //compare the test case to the ground truth - store errors in T | |
203 | 223 | |
204 | 224 | //calculate the metrics |
205 | 225 | float FPR = GT.average(1); //calculate the metrics |
... | ... | @@ -212,7 +232,7 @@ void compare(float sigma){ |
212 | 232 | // writes features of the networks i.e average segment length, tortuosity, branching index, contraction, fractal dimension, number of end and branch points to a csv file |
213 | 233 | // Pranathi wrote this - saves network features to a CSV file |
214 | 234 | void features(std::string filename){ |
215 | - double avgL_t, avgL_gt, avgT_t, avgT_gt, avgB_t, avgB_gt, avgC_t, avgC_gt, avgFD_t, avgFD_gt; | |
235 | + double avgL_t, avgL_gt, avgT_t, avgT_gt, avgB_t, avgB_gt, avgC_t, avgC_gt, avgFD_t, avgFD_gt; | |
216 | 236 | unsigned int e_t, e_gt, b_gt, b_t; |
217 | 237 | avgL_gt = GT.Lengths(); |
218 | 238 | avgT_gt = GT.Tortuosities(); |
... | ... | @@ -246,10 +266,12 @@ void advertise(){ |
246 | 266 | std::cout<<"Source: https://git.stim.ee.uh.edu/segmentation/netmets"<<std::endl; |
247 | 267 | std::cout<<"========================================================================="<<std::endl<<std::endl; |
248 | 268 | |
249 | - std::cout<<"usage: netmets file1 file2 --sigma 10"<<std::endl; | |
250 | - std::cout<<" compare two files with a tolerance of 10 (units defined by the network)"<<std::endl; | |
251 | - std::cout<<" netmets file1 --gui"<<std::endl<<std::endl; | |
252 | - std::cout<<" load a file and display it using OpenGL"<<std::endl; | |
269 | + std::cout<<"usage: netmets file1 file2 --sigma 3"<<std::endl; | |
270 | + std::cout<<" compare two files with a tolerance of 3 (units defined by the network)"<<std::endl<<std::endl; | |
271 | + std::cout<<" netmets file1 --gui"<<std::endl; | |
272 | + std::cout<<" load a file and display it using OpenGL"<<std::endl<<std::endl; | |
273 | + std::cout<<" netmets file1 file2 --device 0"<<std::endl; | |
274 | + std::cout<<" compare two files using device 0 (if there isn't a gpu, use cpu)"<<std::endl<<std::endl; | |
253 | 275 | } |
254 | 276 | |
255 | 277 | int main(int argc, char* argv[]) |
... | ... | @@ -258,8 +280,9 @@ int main(int argc, char* argv[]) |
258 | 280 | |
259 | 281 | //add arguments |
260 | 282 | args.add("help", "prints this help"); |
261 | - args.add("sigma", "force a sigma value to specify the tolerance of the network comparison", "10"); | |
283 | + args.add("sigma", "force a sigma value to specify the tolerance of the network comparison", "3"); | |
262 | 284 | args.add("gui", "display the network or network comparison using OpenGL"); |
285 | + args.add("device", "choose specific device to run", "0"); | |
263 | 286 | args.add("features", "save features to a CSV file, specify file name"); |
264 | 287 | |
265 | 288 | args.parse(argc, argv); //parse the user arguments |
... | ... | @@ -277,6 +300,7 @@ int main(int argc, char* argv[]) |
277 | 300 | } |
278 | 301 | |
279 | 302 | if(args.nargs() == 2){ //if two files are specified, they will be displayed in neighboring viewports and compared |
303 | + int device = args["device"].as_int(); //get the device value from the user | |
280 | 304 | num_nets = 2; //set the number of networks to two |
281 | 305 | float sigma = args["sigma"].as_float(); //get the sigma value from the user |
282 | 306 | T.load_obj(args.arg(1)); //load the second (test) network |
... | ... | @@ -284,7 +308,8 @@ int main(int argc, char* argv[]) |
284 | 308 | features(args["features"].as_string()); |
285 | 309 | GT = GT.resample(resample_rate * sigma); //resample both networks based on the sigma value |
286 | 310 | T = T.resample(resample_rate * sigma); |
287 | - compare(sigma); //run the comparison algorithm | |
311 | + setdevice(device); | |
312 | + compare(sigma, device); //run the comparison algorithm | |
288 | 313 | } |
289 | 314 | |
290 | 315 | //if a GUI is requested, display the network using OpenGL | ... | ... |