Commit 43b34ee057eab91bac772d21b44bbbe5894da3df

Authored by David Mayerich
2 parents d8072434 41100e51

Merge branch 'JACK_netmets' into 'master'

add cpu kdtree version

Please run the test on Linux to see whether it works.

See merge request !13
Showing 2 changed files with 222 additions and 67 deletions
stim/biomodels/network.h
@@ -13,6 +13,7 @@
 #include <stim/visualization/cylinder.h>
 #include <stim/structures/kdtree.cuh>
 #include <boost/tuple/tuple.hpp>
+#include <stim/cuda/cudatools/timer.h>


 namespace stim{
@@ -415,19 +416,14 @@ public:

 	/// @param A is the network to compare to - the field is generated for A
 	/// @param sigma is the user-defined tolerance value - smaller values provide a stricter comparison
-	stim::network<T> compare(stim::network<T> A, float sigma){
+	stim::network<T> compare(stim::network<T> A, float sigma, int device){

 		stim::network<T> R; //generate a network storing the result of the comparison
 		R = (*this); //initialize the result with the current network

-		//generate a KD-tree for network A
-		//float metric = 0.0; // initialize metric to be returned after comparing the networks
-		stim::cuda_kdtree<T, 3> kdt; // initialize a pointer to a kd tree
-		size_t MaxTreeLevels = 3; // max tree level
-
-		float *c; // centerline (array of double pointers) - points on kdtree must be double
-		unsigned int n_data = A.total_points(); // set the number of points
-		c = (float*) malloc(sizeof(float) * n_data * 3);
+		T *c; // centerline (array of double pointers) - points on kdtree must be double
+		size_t n_data = A.total_points(); // set the number of points
+		c = (T*) malloc(sizeof(T) * n_data * 3);

 		unsigned t = 0;
 		for(unsigned e = 0; e < A.E.size(); e++){ //for each edge in the network
@@ -440,16 +436,24 @@ public:
 			}
 		}

+		//generate a KD-tree for network A
+		//float metric = 0.0; // initialize metric to be returned after comparing the network
+		size_t MaxTreeLevels = 3; // max tree level
+
+#ifdef __CUDACC__
+		cudaSetDevice(device);
+		stim::cuda_kdtree<T, 3> kdt; // initialize a pointer to a kd tree
+
 		//compare each point in the current network to the field produced by A
-		kdt.CreateKDTree(c, n_data, 3, MaxTreeLevels); // build a KD tree
-		float *dists = new float[1]; // near neighbor distances
+		kdt.create(c, n_data, MaxTreeLevels); // build a KD tree
+		T *dists = new T[1]; // near neighbor distances
 		size_t *nnIdx = new size_t[1]; // near neighbor indices // allocate near neigh indices

 		stim::vec3<T> p0, p1;
-		float m1;
+		T m1;
 		//float M = 0; //stores the total metric value
 		//float L = 0; //stores the total network length
-		float* queryPt = new float[3];
+		T* queryPt = new T[3];
 		for(unsigned e = 0; e < R.E.size(); e++){ //for each edge in A
 			R.E[e].add_mag(0); //add a new magnitude for the metric

@@ -457,14 +461,36 @@ public:

 				p1 = R.E[e][p]; //get the next point in the edge
 				stim2array(queryPt, p1);
-				kdt.Search(queryPt, 1, 3, dists, nnIdx); //find the distance between A and the current network
+				kdt.search(queryPt, 1, nnIdx, dists); //find the distance between A and the current network

-				m1 = 1.0f - gaussianFunction((float)dists[0], sigma); //calculate the metric value based on the distance
+				m1 = 1.0f - gaussianFunction((T)dists[0], sigma); //calculate the metric value based on the distance
 				R.E[e].set_mag(m1, p, 1); //set the error for the second point in the segment

 			}
 		}
+#else
+		stim::cpu_kdtree<T, 3> kdt;
+		kdt.create(c, n_data, MaxTreeLevels);
+		T *dists = new T[1]; // near neighbor distances
+		size_t *nnIdx = new size_t[1]; // near neighbor indices // allocate near neigh indices
+
+		stim::vec3<T> p0, p1;
+		T m1;
+		T* queryPt = new T[3];
+		for(unsigned e = 0; e < R.E.size(); e++){ //for each edge in A
+			R.E[e].add_mag(0); //add a new magnitude for the metric
+
+			for(unsigned p = 0; p < R.E[e].size(); p++){ //for each point in the edge

+				p1 = R.E[e][p]; //get the next point in the edge
+				stim2array(queryPt, p1);
+				kdt.cpu_search(queryPt, 1, nnIdx, dists); //find the distance between A and the current network
+
+				m1 = 1.0f - gaussianFunction((T)dists[0], sigma); //calculate the metric value based on the distance
+				R.E[e].set_mag(m1, p, 1); //set the error for the second point in the segment
+			}
+		}
+#endif
 		return R; //return the resulting network
 	}

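A minimal usage sketch of the new compare() signature (not part of this diff; the variable names and the sigma value are assumptions for illustration). The third argument selects the CUDA device in the __CUDACC__ path and is unused by the CPU (#else) path:

    // hypothetical example: compare a traced network against a ground-truth network
    stim::network<float> gt;      // ground-truth network A (assumed already loaded)
    stim::network<float> test;    // network to be evaluated (assumed already loaded)
    float sigma = 10.0f;          // tolerance for the Gaussian metric
    int device = 0;               // CUDA device index; ignored by the CPU branch
    stim::network<float> R = test.compare(gt, sigma, device);
    // the per-point error 1 - gaussianFunction(d, sigma) is stored as an extra
    // magnitude on each edge of R
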
stim/structures/kdtree.cuh
-// right now the size of CUDA STACK is set to 50, increase it if you mean to make deeper tree
+// right now the size of CUDA STACK is set to 1000, increase it if you mean to make deeper tree
 // data should be stored in row-major
 // x1,x2,x3,x4,x5......
 // y1,y2,y3,y4,y5......
@@ -13,6 +13,7 @@
 #include <cuda_runtime_api.h>
 #include "cuda_runtime.h"
 #include <vector>
+#include <cstring>
 #include <float.h>
 #include <iostream>
 #include <algorithm>
@@ -50,14 +51,14 @@ namespace stim {
 	class cpu_kdtree {
 	protected:
 		int current_axis; // current judging axis
-		int cmps; // count how many time of comparisons (just for cpu-kdtree)
 		int n_id; // store the total number of nodes
-		std::vector < typename kdtree::point<T, D> > *tmp_points; // transfer or temp points
+		std::vector < typename kdtree::point<T, D> > *tmp_points; // transfer or temperary points
+		std::vector < typename kdtree::point<T, D> > cpu_tmp_points; // for cpu searching
 		kdtree::kdnode<T> *root; // root node
 		static cpu_kdtree<T, D> *cur_tree_ptr;
 	public:
 		cpu_kdtree() { // constructor for creating a cpu_kdtree
-			cur_tree_ptr = this; // create a class pointer points to the current class value
+			cur_tree_ptr = this; // create a class pointer points to the current class value
 			n_id = 0; // set total number of points to default 0
 		}
 		~cpu_kdtree() { // destructor of cpu_kdtree
@@ -78,8 +79,8 @@ namespace stim {
 			}
 			root = NULL;
 		}
-		void Create(std::vector < typename kdtree::point<T, D> > &reference_points, size_t max_levels) {
-			tmp_points = &reference_points;
+		void cpu_create(std::vector < typename kdtree::point<T, D> > &reference_points, size_t max_levels) {
+			tmp_points = &reference_points;
 			root = new kdtree::kdnode<T>(); // initializing the root node
 			root->idx = n_id++; // the index of root is 0
 			root->level = 0; // tree level begins at 0
@@ -100,7 +101,7 @@ namespace stim {
 					kdtree::kdnode<T> *right = new kdtree::kdnode<T>();
 					left->idx = n_id++; // set the index of current node's left node
 					right->idx = n_id++;
-					Split(current_node, left, right); // split left and right and determine a node
+					split(current_node, left, right); // split left and right and determine a node
 					std::vector <size_t> temp; // empty vecters of int
 					//temp.resize(current_node->indices.size());
 					current_node->indices.swap(temp); // clean up current node's indices
@@ -118,14 +119,14 @@ namespace stim {
 				next_nodes = next_search_nodes; // go deeper within the tree
 			}
 		}
-		static bool SortPoints(const size_t a, const size_t b) { // create functor for std::sort
+		static bool sort_points(const size_t a, const size_t b) { // create functor for std::sort
 			std::vector < typename kdtree::point<T, D> > &pts = *cur_tree_ptr->tmp_points; // put cur_tree_ptr to current input points' pointer
 			return pts[a].dim[cur_tree_ptr->current_axis] < pts[b].dim[cur_tree_ptr->current_axis];
 		}
-		void Split(kdtree::kdnode<T> *cur, kdtree::kdnode<T> *left, kdtree::kdnode<T> *right) {
+		void split(kdtree::kdnode<T> *cur, kdtree::kdnode<T> *left, kdtree::kdnode<T> *right) {
 			std::vector < typename kdtree::point<T, D> > &pts = *tmp_points;
 			current_axis = cur->level % D; // indicate the judicative dimension or axis
-			std::sort(cur->indices.begin(), cur->indices.end(), SortPoints); // using SortPoints as comparison function to sort the data
+			std::sort(cur->indices.begin(), cur->indices.end(), sort_points); // using SortPoints as comparison function to sort the data
 			size_t mid_value = cur->indices[cur->indices.size() / 2]; // odd in the mid_value, even take the floor
 			cur->split_value = pts[mid_value].dim[current_axis]; // get the parent node
 			left->parent = cur; // set the parent of the next search nodes to current node
@@ -142,48 +143,176 @@ namespace stim {
 					right->indices.push_back(idx);
 			}
 		}
-		int GetNumNodes() const { // get the total number of nodes
+		void create(T *h_reference_points, size_t reference_count, size_t max_levels) {
+			std::vector < typename kdtree::point<T, D> > reference_points(reference_count); // restore the reference points in particular way
+			for (size_t j = 0; j < reference_count; j++)
+				for (size_t i = 0; i < D; i++)
+					reference_points[j].dim[i] = h_reference_points[j * D + i];
+			cpu_create(reference_points, max_levels);
+			cpu_tmp_points = *tmp_points;
+		}
+		int get_num_nodes() const { // get the total number of nodes
 			return n_id;
 		}
-		kdtree::kdnode<T>* GetRoot() const { // get the root node of tree
+		kdtree::kdnode<T>* get_root() const { // get the root node of tree
 			return root;
 		}
+		T cpu_distance(const kdtree::point<T, D> &a, const kdtree::point<T, D> &b) {
+			T distance = 0;
+
+			for (size_t i = 0; i < D; i++) {
+				T d = a.dim[i] - b.dim[i];
+				distance += d*d;
+			}
+			return distance;
+		}
+		void cpu_search_at_node(kdtree::kdnode<T> *cur, const kdtree::point<T, D> &query, size_t *index, T *distance, kdtree::kdnode<T> **node) {
+			T best_distance = FLT_MAX; // initialize the best distance to max of floating point
+			size_t best_index = 0;
+			std::vector < typename kdtree::point<T, D> > pts = cpu_tmp_points;
+			while (true) {
+				size_t split_axis = cur->level % D;
+				if (cur->left == NULL) { // risky but acceptable, same goes for right because left and right are in same pace
+					*node = cur; // pointer points to a pointer
+					for (size_t i = 0; i < cur->indices.size(); i++) {
+						size_t idx = cur->indices[i];
+						T d = cpu_distance(query, pts[idx]); // compute distances
+						/// if we want to compute k nearest neighbor, we can input the last resul
+						/// (last_best_dist < dist < best_dist) to select the next point until reaching to k
+						if (d < best_distance) {
+							best_distance = d;
+							best_index = idx; // record the nearest neighbor index
+						}
+					}
+					break; // find the target point then break the loop
+				}
+				else if (query.dim[split_axis] < cur->split_value) { // if it has son node, visit the next node on either left side or right side
+					cur = cur->left;
+				}
+				else {
+					cur = cur->right;
+				}
+			}
+			*index = best_index;
+			*distance = best_distance;
+		}
+		void cpu_search_at_node_range(kdtree::kdnode<T> *cur, const kdtree::point<T, D> &query, T range, size_t *index, T *distance) {
+			T best_distance = FLT_MAX; // initialize the best distance to max of floating point
+			size_t best_index = 0;
+			std::vector < typename kdtree::point<T, D> > pts = cpu_tmp_points;
+			std::vector < typename kdtree::kdnode<T>*> next_node;
+			next_node.push_back(cur);
+			while (next_node.size()) {
+				std::vector<typename kdtree::kdnode<T>*> next_search;
+				while (next_node.size()) {
+					cur = next_node.back();
+					next_node.pop_back();
+					size_t split_axis = cur->level % D;
+					if (cur->left == NULL) {
+						for (size_t i = 0; i < cur->indices.size(); i++) {
+							size_t idx = cur->indices[i];
+							T d = cpu_distance(query, pts[idx]);
+							if (d < best_distance) {
+								best_distance = d;
+								best_index = idx;
+							}
+						}
+					}
+					else {
+						T d = query.dim[split_axis] - cur->split_value; // computer distance along specific axis or dimension
+						/// there are three possibilities: on either left or right, and on both left and right
+						if (fabs(d) > range) { // absolute value of floating point to see if distance will be larger that best_dist
+							if (d < 0)
+								next_search.push_back(cur->left); // every left[split_axis] is less and equal to cur->split_value, so it is possible to find the nearest point in this region
+							else
+								next_search.push_back(cur->right);
+						}
+						else { // it is possible that nereast neighbor will appear on both left and right
+							next_search.push_back(cur->left);
+							next_search.push_back(cur->right);
+						}
+					}
+				}
+				next_node = next_search; // pop out at least one time
+			}
+			*index = best_index;
+			*distance = best_distance;
+		}
+		void cpu_search(T *h_query_points, size_t query_count, size_t *h_indices, T *h_distances) {
+			/// first convert the input query point into specific type
+			kdtree::point<T, D> query;
+			for (size_t j = 0; j < query_count; j++) {
+				for (size_t i = 0; i < D; i++)
+					query.dim[i] = h_query_points[j * D + i];
+				/// find the nearest node, this will be the upper bound for the next time searching
+				kdtree::kdnode<T> *best_node = NULL;
+				T best_distance = FLT_MAX;
+				size_t best_index = 0;
+				T radius = 0; // radius for range
+				cpu_search_at_node(root, query, &best_index, &best_distance, &best_node); // simple search to rougly determine a result for next search step
+				radius = sqrt(best_distance); // It is possible that nearest will appear in another region
+				/// find other possibilities
+				kdtree::kdnode<T> *cur = best_node;
+				while (cur->parent != NULL) { // every node that you pass will be possible to be the best node
+					/// go up
+					kdtree::kdnode<T> *parent = cur->parent; // travel back to every node that we pass through
+					size_t split_axis = (parent->level) % D;
+					/// search other nodes
+					size_t tmp_index;
+					T tmp_distance = FLT_MAX;
+					if (fabs(parent->split_value - query.dim[split_axis]) <= radius) {
+						/// search opposite node
+						if (parent->left != cur)
+							cpu_search_at_node_range(parent->left, query, radius, &tmp_index, &tmp_distance); // to see whether it is its mother node's left son node
+						else
+							cpu_search_at_node_range(parent->right, query, radius, &tmp_index, &tmp_distance);
+					}
+					if (tmp_distance < best_distance) {
+						best_distance = tmp_distance;
+						best_index = tmp_index;
+					}
+					cur = parent;
+				}
+				h_indices[j] = best_index;
+				h_distances[j] = best_distance;
+			}
+		}
 	}; //end class kdtree

 	template <typename T, int D>
-	cpu_kdtree<T, D>* cpu_kdtree<T, D>::cur_tree_ptr = NULL; // definition of cur_tree_ptr pointer points to the current class
+	cpu_kdtree<T, D>* cpu_kdtree<T, D>::cur_tree_ptr = NULL; // definition of cur_tree_ptr pointer points to the current class

 	template <typename T>
 	struct cuda_kdnode {
 		int parent, left, right;
 		T split_value;
-		size_t num_index; // number of indices it has
-		int index; // the beginning index
+		size_t num_index; // number of indices it has
+		int index; // the beginning index
 		size_t level;
 	};

 	template <typename T, int D>
-	__device__ T Distance(kdtree::point<T, D> &a, kdtree::point<T, D> &b) {
-		T dist = 0;
+	__device__ T gpu_distance(kdtree::point<T, D> &a, kdtree::point<T, D> &b) {
+		T distance = 0;

 		for (size_t i = 0; i < D; i++) {
 			T d = a.dim[i] - b.dim[i];
-			dist += d*d;
+			distance += d*d;
 		}
-		return dist;
+		return distance;
 	}
 	template <typename T, int D>
-	__device__ void SearchAtNode(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, int cur, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, int *d_node) {
+	__device__ void search_at_node(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, int cur, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, int *d_node) {
 		T best_distance = FLT_MAX;
 		size_t best_index = 0;

-		while (true) { // break until reach the bottom
+		while (true) { // break until reach the bottom
 			int split_axis = nodes[cur].level % D;
-			if (nodes[cur].left == -1) { // check whether it has left node or not
+			if (nodes[cur].left == -1) { // check whether it has left node or not
 				*d_node = cur;
 				for (int i = 0; i < nodes[cur].num_index; i++) {
 					size_t idx = indices[nodes[cur].index + i];
-					T dist = Distance<T, D>(d_query_point, d_reference_points[idx]);
+					T dist = gpu_distance<T, D>(d_query_point, d_reference_points[idx]);
 					if (dist < best_distance) {
 						best_distance = dist;
 						best_index = idx;
@@ -191,7 +320,7 @@ namespace stim {
 				}
 				break;
 			}
-			else if (d_query_point.dim[split_axis] < nodes[cur].split_value) { // jump into specific son node
+			else if (d_query_point.dim[split_axis] < nodes[cur].split_value) { // jump into specific son node
 				cur = nodes[cur].left;
 			}
 			else {
@@ -202,25 +331,25 @@ namespace stim {
 		*d_index = best_index;
 	}
 	template <typename T, int D>
-	__device__ void SearchAtNodeRange(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, int cur, T range, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) {
+	__device__ void search_at_node_range(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, int cur, T range, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) {
 		T best_distance = FLT_MAX;
 		size_t best_index = 0;

-		int next_nodes_pos = 0; // initialize pop out order index
-		next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread
+		int next_nodes_pos = 0; // initialize pop out order index
+		next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread
 		next_nodes_pos++;

 		while (next_nodes_pos) {
-			int next_search_nodes_pos = 0; // record push back order index
+			int next_search_nodes_pos = 0; // record push back order index
 			while (next_nodes_pos) {
-				cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out
+				cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out
 				next_nodes_pos--;
 				int split_axis = nodes[cur].level % D;

 				if (nodes[cur].left == -1) {
 					for (int i = 0; i < nodes[cur].num_index; i++) {
-						int idx = indices[nodes[cur].index + i]; // all indices are stored in one array, pick up from every node's beginning index
-						T d = Distance<T>(d_query_point, d_reference_points[idx]);
+						int idx = indices[nodes[cur].index + i]; // all indices are stored in one array, pick up from every node's beginning index
+						T d = gpu_distance<T>(d_query_point, d_reference_points[idx]);
 						if (d < best_distance) {
 							best_distance = d;
 							best_index = idx;
@@ -260,13 +389,13 @@ namespace stim {
 		*d_index = best_index;
 	}
 	template <typename T, int D>
-	__device__ void Search(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) {
+	__device__ void search(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) {
 		int best_node = 0;
 		T best_distance = FLT_MAX;
 		size_t best_index = 0;
 		T radius = 0;

-		SearchAtNode<T, D>(nodes, indices, d_reference_points, 0, d_query_point, &best_index, &best_distance, &best_node);
+		search_at_node<T, D>(nodes, indices, d_reference_points, 0, d_query_point, &best_index, &best_distance, &best_node);
 		radius = sqrt(best_distance); // get range
 		int cur = best_node;

@@ -278,9 +407,9 @@ namespace stim {
 			size_t tmp_idx;
 			if (fabs(nodes[parent].split_value - d_query_point.dim[split_axis]) <= radius) {
 				if (nodes[parent].left != cur)
-					SearchAtNodeRange(nodes, indices, d_reference_points, d_query_point, nodes[parent].left, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge);
+					search_at_node_range(nodes, indices, d_reference_points, d_query_point, nodes[parent].left, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge);
 				else
-					SearchAtNodeRange(nodes, indices, d_reference_points, d_query_point, nodes[parent].right, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge);
+					search_at_node_range(nodes, indices, d_reference_points, d_query_point, nodes[parent].right, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge);
 			}
 			if (tmp_dist < best_distance) {
 				best_distance = tmp_dist;
@@ -292,11 +421,11 @@ namespace stim {
 		*d_index = best_index;
 	}
 	template <typename T, int D>
-	__global__ void SearchBatch(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> *d_query_points, size_t d_query_count, size_t *d_indices, T *d_distances, int *next_nodes, int *next_search_nodes, int *Judge) {
+	__global__ void search_batch(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> *d_query_points, size_t d_query_count, size_t *d_indices, T *d_distances, int *next_nodes, int *next_search_nodes, int *Judge) {
 		size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 		if (idx >= d_query_count) return; // avoid segfault

-		Search<T, D>(nodes, indices, d_reference_points, d_query_points[idx], &d_indices[idx], &d_distances[idx], idx, next_nodes, next_search_nodes, Judge); // every query points are independent
+		search<T, D>(nodes, indices, d_reference_points, d_query_points[idx], &d_indices[idx], &d_distances[idx], idx, next_nodes, next_search_nodes, Judge); // every query points are independent
 	}

 	template <typename T, int D = 3>
@@ -312,20 +441,20 @@ namespace stim {
 			HANDLE_ERROR(cudaFree(d_index));
 			HANDLE_ERROR(cudaFree(d_reference_points));
 		}
-		void CreateKDTree(T *h_reference_points, size_t reference_count, size_t dim_count, size_t max_levels) {
+		void create(T *h_reference_points, size_t reference_count, size_t max_levels) {
 			if (max_levels > 10) {
 				std::cout<<"The max_tree_levels should be smaller!"<<std::endl;
 				exit(1);
-			}
-			std::vector < typename kdtree::point<T, D> > reference_points(reference_count); // restore the reference points in particular way
+			}
+			std::vector <kdtree::point<T, D>> reference_points(reference_count); // restore the reference points in particular way
 			for (size_t j = 0; j < reference_count; j++)
-				for (size_t i = 0; i < dim_count; i++)
-					reference_points[j].dim[i] = h_reference_points[j * dim_count + i];
+				for (size_t i = 0; i < D; i++)
+					reference_points[j].dim[i] = h_reference_points[j * D + i];
 			cpu_kdtree<T, D> tree; // creating a tree on cpu
-			tree.Create(reference_points, max_levels); // building a tree on cpu
-			kdtree::kdnode<T> *d_root = tree.GetRoot();
-			int num_nodes = tree.GetNumNodes();
-			d_reference_count = reference_points.size(); // also equals to reference_count
+			tree.cpu_create(reference_points, max_levels); // building a tree on cpu
+			kdtree::kdnode<T> *d_root = tree.get_root();
+			int num_nodes = tree.get_num_nodes();
+			d_reference_count = reference_count; // also equals to reference_count

 			HANDLE_ERROR(cudaMalloc((void**)&d_nodes, sizeof(cuda_kdnode<T>) * num_nodes)); // copy data from host to device
 			HANDLE_ERROR(cudaMalloc((void**)&d_index, sizeof(size_t) * d_reference_count));
@@ -371,11 +500,11 @@ namespace stim {
 			HANDLE_ERROR(cudaMemcpy(d_index, &indices[0], sizeof(size_t) * indices.size(), cudaMemcpyHostToDevice));
 			HANDLE_ERROR(cudaMemcpy(d_reference_points, &reference_points[0], sizeof(kdtree::point<T, D>) * reference_points.size(), cudaMemcpyHostToDevice));
 		}
-		void Search(T *h_query_points, size_t query_count, size_t dim_count, T *dists, size_t *indices) {
+		void search(T *h_query_points, size_t query_count, size_t *indices, T *distances) {
 			std::vector < typename kdtree::point<T, D> > query_points(query_count);
 			for (size_t j = 0; j < query_count; j++)
-				for (size_t i = 0; i < dim_count; i++)
-					query_points[j].dim[i] = h_query_points[j * dim_count + i];
+				for (size_t i = 0; i < D; i++)
+					query_points[j].dim[i] = h_query_points[j * D + i];

 			unsigned int threads = (unsigned int)(query_points.size() > 1024 ? 1024 : query_points.size());
 			unsigned int blocks = (unsigned int)(query_points.size() / threads + (query_points.size() % threads ? 1 : 0));
@@ -392,15 +521,15 @@ namespace stim {
 			HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D));
 			HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size()));
 			HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size()));
-			HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to
+			HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to
 			HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 50 * sizeof(int)));
 			HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice));

-			SearchBatch<<<threads, blocks>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge);
+			search_batch<<<threads, blocks>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge);

 			if (Judge == NULL) { // do the following work if the thread works safely
 				HANDLE_ERROR(cudaMemcpy(indices, d_indices, sizeof(size_t) * query_points.size(), cudaMemcpyDeviceToHost));
-				HANDLE_ERROR(cudaMemcpy(dists, d_distances, sizeof(T) * query_points.size(), cudaMemcpyDeviceToHost));
+				HANDLE_ERROR(cudaMemcpy(distances, d_distances, sizeof(T) * query_points.size(), cudaMemcpyDeviceToHost));
 			}

 			HANDLE_ERROR(cudaFree(next_nodes));
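
For context, a minimal host-side sketch of the renamed CPU kd-tree interface added in this commit (not part of the diff; the point data and counts below are made-up values for illustration):

    // hypothetical example: build a 3D tree from row-major points and run one query
    float reference[] = { 0.0f, 0.0f, 0.0f,   1.0f, 1.0f, 1.0f,   2.0f, 0.0f, 1.0f }; // x,y,z per point
    float query[] = { 0.9f, 1.1f, 0.8f };
    size_t nn_index;     // index of the nearest reference point
    float nn_distance;   // squared distance returned by the search (cpu_distance accumulates d*d)

    stim::cpu_kdtree<float, 3> kdt;
    kdt.create(reference, 3, 3);                         // 3 reference points, max tree level 3
    kdt.cpu_search(query, 1, &nn_index, &nn_distance);   // 1 query point

The GPU class follows the same pattern with cuda_kdtree<float, 3>, create(), and search(), as used in the network.h change above.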