Commit fbbc07be5d338b7854f754d1e00c7cb37212780d

Authored by Jiaming Guo
1 parent 47263b74

replace ANN by kdtree.

stim/biomodels/network.h
... ... @@ -11,7 +11,7 @@
11 11 #include <stim/math/vec3.h>
12 12 #include <stim/visualization/obj.h>
13 13 #include <stim/visualization/cylinder.h>
14   -#include <ANN/ANN.h>
  14 +#include <stim/structures/kdtree.cuh>
15 15 #include <boost/tuple/tuple.hpp>
16 16  
17 17  
... ... @@ -57,7 +57,7 @@ class network{
57 57 /// Output the edge information as a string
58 58 std::string str(){
59 59 std::stringstream ss;
60   - ss<<"("<<cylinder<T>::size()<<")\tl = "<<this.length()<<"\t"<<v[0]<<"----"<<v[1];
  60 + ss<<"("<<cylinder<T>::size()<<")\tl = "<<this->length()<<"\t"<<v[0]<<"----"<<v[1];
61 61 return ss.str();
62 62 }
63 63  
... ... @@ -390,8 +390,8 @@ public:
390 390 // gaussian function
391 391 float gaussianFunction(float x, float std=25){ return exp(-x/(2*std*std));} // by default std = 25
392 392  
393   - // stim 3d vector to annpoint of 3 dimensions
394   - void stim2ann(ANNpoint &a, stim::vec3<T> b){
  393 + // convert vec3 to array
  394 + void stim2array(float *a, stim::vec3<T> b){
395 395 a[0] = b[0];
396 396 a[1] = b[1];
397 397 a[2] = b[2];
... ... @@ -422,44 +422,43 @@ public:
422 422  
423 423 //generate a KD-tree for network A
424 424 float metric = 0.0; // initialize metric to be returned after comparing the networks
425   - ANNkd_tree* kdt; // initialize a pointer to a kd tree
426   - double **c; // centerline (array of double pointers) - points on kdtree must be double
  425 + stim::cuda_kdtree<T, 3> kdt; // initialize a pointer to a kd tree
  426 + size_t MaxTreeLevels = 3; // max tree level
  427 +
  428 + float *c; // centerline (array of double pointers) - points on kdtree must be double
427 429 unsigned int n_data = A.total_points(); // set the number of points
428   - c = (double**) malloc(sizeof(double*) * n_data); // allocate the array pointer
429   - for(unsigned int i = 0; i < n_data; i++) // allocate space for each point of 3 dimensions
430   - c[i] = (double*) malloc(sizeof(double) * 3);
  430 + c = (float*) malloc(sizeof(float) * n_data * 3);
431 431  
432 432 unsigned t = 0;
433 433 for(unsigned e = 0; e < A.E.size(); e++){ //for each edge in the network
434 434 for(unsigned p = 0; p < A.E[e].size(); p++){ //for each point in the edge
435 435 for(unsigned d = 0; d < 3; d++){ //for each coordinate
436 436  
437   - c[t][d] = A.E[e][p][d];
  437 + c[t * 3 + d] = A.E[e][p][d];
438 438 }
439 439 t++;
440 440 }
441 441 }
442 442  
443 443 //compare each point in the current network to the field produced by A
444   - ANNpointArray pts = (ANNpointArray)c; // create an array of data points of type double
445   - kdt = new ANNkd_tree(pts, n_data, 3); // build a KD tree using the annpointarray
446   - double eps = 0; // error bound
447   - ANNdistArray dists = new ANNdist[1]; // near neighbor distances
448   - ANNidxArray nnIdx = new ANNidx[1]; // near neighbor indices // allocate near neigh indices
  444 + kdt.CreateKDTree(c, n_data, 3, MaxTreeLevels); // build a KD tree
  445 + float *dists = new float[1]; // near neighbor distances
  446 + size_t *nnIdx = new size_t[1]; // near neighbor indices // allocate near neigh indices
449 447  
450 448 stim::vec3<T> p0, p1;
451 449 float m1;
452 450 float M = 0; //stores the total metric value
453 451 float L = 0; //stores the total network length
454   - ANNpoint queryPt = annAllocPt(3);
  452 + float* queryPt = new float[3];
455 453 for(unsigned e = 0; e < R.E.size(); e++){ //for each edge in A
456 454 R.E[e].add_mag(0); //add a new magnitude for the metric
457 455  
458 456 for(unsigned p = 0; p < R.E[e].size(); p++){ //for each point in the edge
459 457  
460 458 p1 = R.E[e][p]; //get the next point in the edge
461   - stim2ann(queryPt, p1);
462   - kdt->annkSearch( queryPt, 1, nnIdx, dists, eps); //find the distance between A and the current network
  459 + stim2array(queryPt, p1);
  460 + kdt.Search(queryPt, 1, 3, dists, nnIdx); //find the distance between A and the current network
  461 +
463 462 m1 = 1.0f - gaussianFunction((float)dists[0], sigma); //calculate the metric value based on the distance
464 463 R.E[e].set_mag(m1, p, 1); //set the error for the second point in the segment
465 464  
... ...
stim/biomodels/network_dep.h
... ... @@ -4,7 +4,7 @@
4 4 #include <stim/math/vector.h>
5 5 #include <stim/visualization/obj.h>
6 6 #include <list>
7   -#include <ANN/ANN.h>
  7 +//#include <ANN/ANN.h>
8 8  
9 9 namespace stim{
10 10  
... ...
stim/structures/kdtree.cuh
1   -// right now the size of CUDA STACK is set to 1000, increase it if you mean to make deeper tree
  1 +// right now the size of CUDA STACK is set to 50, increase it if you mean to make deeper tree
2 2 // data should be stored in row-major
3 3 // x1,x2,x3,x4,x5......
4 4 // y1,y2,y3,y4,y5......
... ... @@ -207,13 +207,13 @@ namespace stim {
207 207 size_t best_index = 0;
208 208  
209 209 int next_nodes_pos = 0; // initialize pop out order index
210   - next_nodes[id * 1000 + next_nodes_pos] = cur; // find data that belongs to the very specific thread
  210 + next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread
211 211 next_nodes_pos++;
212 212  
213 213 while (next_nodes_pos) {
214 214 int next_search_nodes_pos = 0; // record push back order index
215 215 while (next_nodes_pos) {
216   - cur = next_nodes[id * 1000 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out
  216 + cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out
217 217 next_nodes_pos--;
218 218 int split_axis = nodes[cur].level % D;
219 219  
... ... @@ -232,20 +232,20 @@ namespace stim {
232 232  
233 233 if (fabs(d) > range) {
234 234 if (d < 0) {
235   - next_search_nodes[id * 1000 + next_search_nodes_pos] = nodes[cur].left;
  235 + next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left;
236 236 next_search_nodes_pos++;
237 237 }
238 238 else {
239   - next_search_nodes[id * 1000 + next_search_nodes_pos] = nodes[cur].right;
  239 + next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right;
240 240 next_search_nodes_pos++;
241 241 }
242 242 }
243 243 else {
244   - next_search_nodes[id * 1000 + next_search_nodes_pos] = nodes[cur].right;
  244 + next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right;
245 245 next_search_nodes_pos++;
246   - next_search_nodes[id * 1000 + next_search_nodes_pos] = nodes[cur].left;
  246 + next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left;
247 247 next_search_nodes_pos++;
248   - if (next_search_nodes_pos > 1000) {
  248 + if (next_search_nodes_pos > 50) {
249 249 printf("Thread conflict might be caused by thread %d, so please try smaller input max_tree_levels\n", id);
250 250 (*Judge)++;
251 251 }
... ... @@ -253,7 +253,7 @@ namespace stim {
253 253 }
254 254 }
255 255 for (int i = 0; i < next_search_nodes_pos; i++)
256   - next_nodes[id * 1000 + i] = next_search_nodes[id * 1000 + i];
  256 + next_nodes[id * 50 + i] = next_search_nodes[id * 50 + i];
257 257 next_nodes_pos = next_search_nodes_pos;
258 258 }
259 259 *d_distance = best_distance;
... ... @@ -392,8 +392,8 @@ namespace stim {
392 392 HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D));
393 393 HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size()));
394 394 HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size()));
395   - HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 1000 * sizeof(int))); // STACK size right now is 1000, you can change it if you mean to
396   - HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 1000 * sizeof(int)));
  395 + HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to
  396 + HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 50 * sizeof(int)));
397 397 HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice));
398 398  
399 399 SearchBatch<<<threads, blocks>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge);
... ...