Commit 43b34ee057eab91bac772d21b44bbbe5894da3df
Merge branch 'JACK_netmets' into 'master'
add cpu kdtree version please do the test on linux to see whether it works or not See merge request !13
Showing
2 changed files
with
222 additions
and
67 deletions
Show diff stats
stim/biomodels/network.h
... | ... | @@ -13,6 +13,7 @@ |
13 | 13 | #include <stim/visualization/cylinder.h> |
14 | 14 | #include <stim/structures/kdtree.cuh> |
15 | 15 | #include <boost/tuple/tuple.hpp> |
16 | +#include <stim/cuda/cudatools/timer.h> | |
16 | 17 | |
17 | 18 | |
18 | 19 | namespace stim{ |
... | ... | @@ -415,19 +416,14 @@ public: |
415 | 416 | |
416 | 417 | /// @param A is the network to compare to - the field is generated for A |
417 | 418 | /// @param sigma is the user-defined tolerance value - smaller values provide a stricter comparison |
418 | - stim::network<T> compare(stim::network<T> A, float sigma){ | |
419 | + stim::network<T> compare(stim::network<T> A, float sigma, int device){ | |
419 | 420 | |
420 | 421 | stim::network<T> R; //generate a network storing the result of the comparison |
421 | 422 | R = (*this); //initialize the result with the current network |
422 | 423 | |
423 | - //generate a KD-tree for network A | |
424 | - //float metric = 0.0; // initialize metric to be returned after comparing the networks | |
425 | - stim::cuda_kdtree<T, 3> kdt; // initialize a pointer to a kd tree | |
426 | - size_t MaxTreeLevels = 3; // max tree level | |
427 | - | |
428 | - float *c; // centerline (array of double pointers) - points on kdtree must be double | |
429 | - unsigned int n_data = A.total_points(); // set the number of points | |
430 | - c = (float*) malloc(sizeof(float) * n_data * 3); | |
424 | + T *c; // centerline (array of double pointers) - points on kdtree must be double | |
425 | + size_t n_data = A.total_points(); // set the number of points | |
426 | + c = (T*) malloc(sizeof(T) * n_data * 3); | |
431 | 427 | |
432 | 428 | unsigned t = 0; |
433 | 429 | for(unsigned e = 0; e < A.E.size(); e++){ //for each edge in the network |
... | ... | @@ -440,16 +436,24 @@ public: |
440 | 436 | } |
441 | 437 | } |
442 | 438 | |
439 | + //generate a KD-tree for network A | |
440 | + //float metric = 0.0; // initialize metric to be returned after comparing the network | |
441 | + size_t MaxTreeLevels = 3; // max tree level | |
442 | + | |
443 | +#ifdef __CUDACC__ | |
444 | + cudaSetDevice(device); | |
445 | + stim::cuda_kdtree<T, 3> kdt; // initialize a pointer to a kd tree | |
446 | + | |
443 | 447 | //compare each point in the current network to the field produced by A |
444 | - kdt.CreateKDTree(c, n_data, 3, MaxTreeLevels); // build a KD tree | |
445 | - float *dists = new float[1]; // near neighbor distances | |
448 | + kdt.create(c, n_data, MaxTreeLevels); // build a KD tree | |
449 | + T *dists = new T[1]; // near neighbor distances | |
446 | 450 | size_t *nnIdx = new size_t[1]; // near neighbor indices // allocate near neigh indices |
447 | 451 | |
448 | 452 | stim::vec3<T> p0, p1; |
449 | - float m1; | |
453 | + T m1; | |
450 | 454 | //float M = 0; //stores the total metric value |
451 | 455 | //float L = 0; //stores the total network length |
452 | - float* queryPt = new float[3]; | |
456 | + T* queryPt = new T[3]; | |
453 | 457 | for(unsigned e = 0; e < R.E.size(); e++){ //for each edge in A |
454 | 458 | R.E[e].add_mag(0); //add a new magnitude for the metric |
455 | 459 | |
... | ... | @@ -457,14 +461,36 @@ public: |
457 | 461 | |
458 | 462 | p1 = R.E[e][p]; //get the next point in the edge |
459 | 463 | stim2array(queryPt, p1); |
460 | - kdt.Search(queryPt, 1, 3, dists, nnIdx); //find the distance between A and the current network | |
464 | + kdt.search(queryPt, 1, nnIdx, dists); //find the distance between A and the current network | |
461 | 465 | |
462 | - m1 = 1.0f - gaussianFunction((float)dists[0], sigma); //calculate the metric value based on the distance | |
466 | + m1 = 1.0f - gaussianFunction((T)dists[0], sigma); //calculate the metric value based on the distance | |
463 | 467 | R.E[e].set_mag(m1, p, 1); //set the error for the second point in the segment |
464 | 468 | |
465 | 469 | } |
466 | 470 | } |
471 | +#else | |
472 | + stim::cpu_kdtree<T, 3> kdt; | |
473 | + kdt.create(c, n_data, MaxTreeLevels); | |
474 | + T *dists = new T[1]; // near neighbor distances | |
475 | + size_t *nnIdx = new size_t[1]; // near neighbor indices // allocate near neigh indices | |
476 | + | |
477 | + stim::vec3<T> p0, p1; | |
478 | + T m1; | |
479 | + T* queryPt = new T[3]; | |
480 | + for(unsigned e = 0; e < R.E.size(); e++){ //for each edge in A | |
481 | + R.E[e].add_mag(0); //add a new magnitude for the metric | |
482 | + | |
483 | + for(unsigned p = 0; p < R.E[e].size(); p++){ //for each point in the edge | |
467 | 484 | |
485 | + p1 = R.E[e][p]; //get the next point in the edge | |
486 | + stim2array(queryPt, p1); | |
487 | + kdt.cpu_search(queryPt, 1, nnIdx, dists); //find the distance between A and the current network | |
488 | + | |
489 | + m1 = 1.0f - gaussianFunction((T)dists[0], sigma); //calculate the metric value based on the distance | |
490 | + R.E[e].set_mag(m1, p, 1); //set the error for the second point in the segment | |
491 | + } | |
492 | + } | |
493 | +#endif | |
468 | 494 | return R; //return the resulting network |
469 | 495 | } |
470 | 496 | ... | ... |
stim/structures/kdtree.cuh
1 | -// right now the size of CUDA STACK is set to 50, increase it if you mean to make deeper tree | |
1 | +// right now the size of CUDA STACK is set to 1000, increase it if you mean to make deeper tree | |
2 | 2 | // data should be stored in row-major |
3 | 3 | // x1,x2,x3,x4,x5...... |
4 | 4 | // y1,y2,y3,y4,y5...... |
... | ... | @@ -13,6 +13,7 @@ |
13 | 13 | #include <cuda_runtime_api.h> |
14 | 14 | #include "cuda_runtime.h" |
15 | 15 | #include <vector> |
16 | +#include <cstring> | |
16 | 17 | #include <float.h> |
17 | 18 | #include <iostream> |
18 | 19 | #include <algorithm> |
... | ... | @@ -50,14 +51,14 @@ namespace stim { |
50 | 51 | class cpu_kdtree { |
51 | 52 | protected: |
52 | 53 | int current_axis; // current judging axis |
53 | - int cmps; // count how many time of comparisons (just for cpu-kdtree) | |
54 | 54 | int n_id; // store the total number of nodes |
55 | - std::vector < typename kdtree::point<T, D> > *tmp_points; // transfer or temp points | |
55 | + std::vector < typename kdtree::point<T, D> > *tmp_points; // transfer or temperary points | |
56 | + std::vector < typename kdtree::point<T, D> > cpu_tmp_points; // for cpu searching | |
56 | 57 | kdtree::kdnode<T> *root; // root node |
57 | 58 | static cpu_kdtree<T, D> *cur_tree_ptr; |
58 | 59 | public: |
59 | 60 | cpu_kdtree() { // constructor for creating a cpu_kdtree |
60 | - cur_tree_ptr = this; // create a class pointer points to the current class value | |
61 | + cur_tree_ptr = this; // create a class pointer points to the current class value | |
61 | 62 | n_id = 0; // set total number of points to default 0 |
62 | 63 | } |
63 | 64 | ~cpu_kdtree() { // destructor of cpu_kdtree |
... | ... | @@ -78,8 +79,8 @@ namespace stim { |
78 | 79 | } |
79 | 80 | root = NULL; |
80 | 81 | } |
81 | - void Create(std::vector < typename kdtree::point<T, D> > &reference_points, size_t max_levels) { | |
82 | - tmp_points = &reference_points; | |
82 | + void cpu_create(std::vector < typename kdtree::point<T, D> > &reference_points, size_t max_levels) { | |
83 | + tmp_points = &reference_points; | |
83 | 84 | root = new kdtree::kdnode<T>(); // initializing the root node |
84 | 85 | root->idx = n_id++; // the index of root is 0 |
85 | 86 | root->level = 0; // tree level begins at 0 |
... | ... | @@ -100,7 +101,7 @@ namespace stim { |
100 | 101 | kdtree::kdnode<T> *right = new kdtree::kdnode<T>(); |
101 | 102 | left->idx = n_id++; // set the index of current node's left node |
102 | 103 | right->idx = n_id++; |
103 | - Split(current_node, left, right); // split left and right and determine a node | |
104 | + split(current_node, left, right); // split left and right and determine a node | |
104 | 105 | std::vector <size_t> temp; // empty vecters of int |
105 | 106 | //temp.resize(current_node->indices.size()); |
106 | 107 | current_node->indices.swap(temp); // clean up current node's indices |
... | ... | @@ -118,14 +119,14 @@ namespace stim { |
118 | 119 | next_nodes = next_search_nodes; // go deeper within the tree |
119 | 120 | } |
120 | 121 | } |
121 | - static bool SortPoints(const size_t a, const size_t b) { // create functor for std::sort | |
122 | + static bool sort_points(const size_t a, const size_t b) { // create functor for std::sort | |
122 | 123 | std::vector < typename kdtree::point<T, D> > &pts = *cur_tree_ptr->tmp_points; // put cur_tree_ptr to current input points' pointer |
123 | 124 | return pts[a].dim[cur_tree_ptr->current_axis] < pts[b].dim[cur_tree_ptr->current_axis]; |
124 | 125 | } |
125 | - void Split(kdtree::kdnode<T> *cur, kdtree::kdnode<T> *left, kdtree::kdnode<T> *right) { | |
126 | + void split(kdtree::kdnode<T> *cur, kdtree::kdnode<T> *left, kdtree::kdnode<T> *right) { | |
126 | 127 | std::vector < typename kdtree::point<T, D> > &pts = *tmp_points; |
127 | 128 | current_axis = cur->level % D; // indicate the judicative dimension or axis |
128 | - std::sort(cur->indices.begin(), cur->indices.end(), SortPoints); // using SortPoints as comparison function to sort the data | |
129 | + std::sort(cur->indices.begin(), cur->indices.end(), sort_points); // using SortPoints as comparison function to sort the data | |
129 | 130 | size_t mid_value = cur->indices[cur->indices.size() / 2]; // odd in the mid_value, even take the floor |
130 | 131 | cur->split_value = pts[mid_value].dim[current_axis]; // get the parent node |
131 | 132 | left->parent = cur; // set the parent of the next search nodes to current node |
... | ... | @@ -142,48 +143,176 @@ namespace stim { |
142 | 143 | right->indices.push_back(idx); |
143 | 144 | } |
144 | 145 | } |
145 | - int GetNumNodes() const { // get the total number of nodes | |
146 | + void create(T *h_reference_points, size_t reference_count, size_t max_levels) { | |
147 | + std::vector < typename kdtree::point<T, D> > reference_points(reference_count); // restore the reference points in particular way | |
148 | + for (size_t j = 0; j < reference_count; j++) | |
149 | + for (size_t i = 0; i < D; i++) | |
150 | + reference_points[j].dim[i] = h_reference_points[j * D + i]; | |
151 | + cpu_create(reference_points, max_levels); | |
152 | + cpu_tmp_points = *tmp_points; | |
153 | + } | |
154 | + int get_num_nodes() const { // get the total number of nodes | |
146 | 155 | return n_id; |
147 | 156 | } |
148 | - kdtree::kdnode<T>* GetRoot() const { // get the root node of tree | |
157 | + kdtree::kdnode<T>* get_root() const { // get the root node of tree | |
149 | 158 | return root; |
150 | 159 | } |
160 | + T cpu_distance(const kdtree::point<T, D> &a, const kdtree::point<T, D> &b) { | |
161 | + T distance = 0; | |
162 | + | |
163 | + for (size_t i = 0; i < D; i++) { | |
164 | + T d = a.dim[i] - b.dim[i]; | |
165 | + distance += d*d; | |
166 | + } | |
167 | + return distance; | |
168 | + } | |
169 | + void cpu_search_at_node(kdtree::kdnode<T> *cur, const kdtree::point<T, D> &query, size_t *index, T *distance, kdtree::kdnode<T> **node) { | |
170 | + T best_distance = FLT_MAX; // initialize the best distance to max of floating point | |
171 | + size_t best_index = 0; | |
172 | + std::vector < typename kdtree::point<T, D> > pts = cpu_tmp_points; | |
173 | + while (true) { | |
174 | + size_t split_axis = cur->level % D; | |
175 | + if (cur->left == NULL) { // risky but acceptable, same goes for right because left and right are in same pace | |
176 | + *node = cur; // pointer points to a pointer | |
177 | + for (size_t i = 0; i < cur->indices.size(); i++) { | |
178 | + size_t idx = cur->indices[i]; | |
179 | + T d = cpu_distance(query, pts[idx]); // compute distances | |
180 | + /// if we want to compute k nearest neighbor, we can input the last resul | |
181 | + /// (last_best_dist < dist < best_dist) to select the next point until reaching to k | |
182 | + if (d < best_distance) { | |
183 | + best_distance = d; | |
184 | + best_index = idx; // record the nearest neighbor index | |
185 | + } | |
186 | + } | |
187 | + break; // find the target point then break the loop | |
188 | + } | |
189 | + else if (query.dim[split_axis] < cur->split_value) { // if it has son node, visit the next node on either left side or right side | |
190 | + cur = cur->left; | |
191 | + } | |
192 | + else { | |
193 | + cur = cur->right; | |
194 | + } | |
195 | + } | |
196 | + *index = best_index; | |
197 | + *distance = best_distance; | |
198 | + } | |
199 | + void cpu_search_at_node_range(kdtree::kdnode<T> *cur, const kdtree::point<T, D> &query, T range, size_t *index, T *distance) { | |
200 | + T best_distance = FLT_MAX; // initialize the best distance to max of floating point | |
201 | + size_t best_index = 0; | |
202 | + std::vector < typename kdtree::point<T, D> > pts = cpu_tmp_points; | |
203 | + std::vector < typename kdtree::kdnode<T>*> next_node; | |
204 | + next_node.push_back(cur); | |
205 | + while (next_node.size()) { | |
206 | + std::vector<typename kdtree::kdnode<T>*> next_search; | |
207 | + while (next_node.size()) { | |
208 | + cur = next_node.back(); | |
209 | + next_node.pop_back(); | |
210 | + size_t split_axis = cur->level % D; | |
211 | + if (cur->left == NULL) { | |
212 | + for (size_t i = 0; i < cur->indices.size(); i++) { | |
213 | + size_t idx = cur->indices[i]; | |
214 | + T d = cpu_distance(query, pts[idx]); | |
215 | + if (d < best_distance) { | |
216 | + best_distance = d; | |
217 | + best_index = idx; | |
218 | + } | |
219 | + } | |
220 | + } | |
221 | + else { | |
222 | + T d = query.dim[split_axis] - cur->split_value; // computer distance along specific axis or dimension | |
223 | + /// there are three possibilities: on either left or right, and on both left and right | |
224 | + if (fabs(d) > range) { // absolute value of floating point to see if distance will be larger that best_dist | |
225 | + if (d < 0) | |
226 | + next_search.push_back(cur->left); // every left[split_axis] is less and equal to cur->split_value, so it is possible to find the nearest point in this region | |
227 | + else | |
228 | + next_search.push_back(cur->right); | |
229 | + } | |
230 | + else { // it is possible that nereast neighbor will appear on both left and right | |
231 | + next_search.push_back(cur->left); | |
232 | + next_search.push_back(cur->right); | |
233 | + } | |
234 | + } | |
235 | + } | |
236 | + next_node = next_search; // pop out at least one time | |
237 | + } | |
238 | + *index = best_index; | |
239 | + *distance = best_distance; | |
240 | + } | |
241 | + void cpu_search(T *h_query_points, size_t query_count, size_t *h_indices, T *h_distances) { | |
242 | + /// first convert the input query point into specific type | |
243 | + kdtree::point<T, D> query; | |
244 | + for (size_t j = 0; j < query_count; j++) { | |
245 | + for (size_t i = 0; i < D; i++) | |
246 | + query.dim[i] = h_query_points[j * D + i]; | |
247 | + /// find the nearest node, this will be the upper bound for the next time searching | |
248 | + kdtree::kdnode<T> *best_node = NULL; | |
249 | + T best_distance = FLT_MAX; | |
250 | + size_t best_index = 0; | |
251 | + T radius = 0; // radius for range | |
252 | + cpu_search_at_node(root, query, &best_index, &best_distance, &best_node); // simple search to rougly determine a result for next search step | |
253 | + radius = sqrt(best_distance); // It is possible that nearest will appear in another region | |
254 | + /// find other possibilities | |
255 | + kdtree::kdnode<T> *cur = best_node; | |
256 | + while (cur->parent != NULL) { // every node that you pass will be possible to be the best node | |
257 | + /// go up | |
258 | + kdtree::kdnode<T> *parent = cur->parent; // travel back to every node that we pass through | |
259 | + size_t split_axis = (parent->level) % D; | |
260 | + /// search other nodes | |
261 | + size_t tmp_index; | |
262 | + T tmp_distance = FLT_MAX; | |
263 | + if (fabs(parent->split_value - query.dim[split_axis]) <= radius) { | |
264 | + /// search opposite node | |
265 | + if (parent->left != cur) | |
266 | + cpu_search_at_node_range(parent->left, query, radius, &tmp_index, &tmp_distance); // to see whether it is its mother node's left son node | |
267 | + else | |
268 | + cpu_search_at_node_range(parent->right, query, radius, &tmp_index, &tmp_distance); | |
269 | + } | |
270 | + if (tmp_distance < best_distance) { | |
271 | + best_distance = tmp_distance; | |
272 | + best_index = tmp_index; | |
273 | + } | |
274 | + cur = parent; | |
275 | + } | |
276 | + h_indices[j] = best_index; | |
277 | + h_distances[j] = best_distance; | |
278 | + } | |
279 | + } | |
151 | 280 | }; //end class kdtree |
152 | 281 | |
153 | 282 | template <typename T, int D> |
154 | - cpu_kdtree<T, D>* cpu_kdtree<T, D>::cur_tree_ptr = NULL; // definition of cur_tree_ptr pointer points to the current class | |
283 | + cpu_kdtree<T, D>* cpu_kdtree<T, D>::cur_tree_ptr = NULL; // definition of cur_tree_ptr pointer points to the current class | |
155 | 284 | |
156 | 285 | template <typename T> |
157 | 286 | struct cuda_kdnode { |
158 | 287 | int parent, left, right; |
159 | 288 | T split_value; |
160 | - size_t num_index; // number of indices it has | |
161 | - int index; // the beginning index | |
289 | + size_t num_index; // number of indices it has | |
290 | + int index; // the beginning index | |
162 | 291 | size_t level; |
163 | 292 | }; |
164 | 293 | |
165 | 294 | template <typename T, int D> |
166 | - __device__ T Distance(kdtree::point<T, D> &a, kdtree::point<T, D> &b) { | |
167 | - T dist = 0; | |
295 | + __device__ T gpu_distance(kdtree::point<T, D> &a, kdtree::point<T, D> &b) { | |
296 | + T distance = 0; | |
168 | 297 | |
169 | 298 | for (size_t i = 0; i < D; i++) { |
170 | 299 | T d = a.dim[i] - b.dim[i]; |
171 | - dist += d*d; | |
300 | + distance += d*d; | |
172 | 301 | } |
173 | - return dist; | |
302 | + return distance; | |
174 | 303 | } |
175 | 304 | template <typename T, int D> |
176 | - __device__ void SearchAtNode(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, int cur, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, int *d_node) { | |
305 | + __device__ void search_at_node(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, int cur, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, int *d_node) { | |
177 | 306 | T best_distance = FLT_MAX; |
178 | 307 | size_t best_index = 0; |
179 | 308 | |
180 | - while (true) { // break until reach the bottom | |
309 | + while (true) { // break until reach the bottom | |
181 | 310 | int split_axis = nodes[cur].level % D; |
182 | - if (nodes[cur].left == -1) { // check whether it has left node or not | |
311 | + if (nodes[cur].left == -1) { // check whether it has left node or not | |
183 | 312 | *d_node = cur; |
184 | 313 | for (int i = 0; i < nodes[cur].num_index; i++) { |
185 | 314 | size_t idx = indices[nodes[cur].index + i]; |
186 | - T dist = Distance<T, D>(d_query_point, d_reference_points[idx]); | |
315 | + T dist = gpu_distance<T, D>(d_query_point, d_reference_points[idx]); | |
187 | 316 | if (dist < best_distance) { |
188 | 317 | best_distance = dist; |
189 | 318 | best_index = idx; |
... | ... | @@ -191,7 +320,7 @@ namespace stim { |
191 | 320 | } |
192 | 321 | break; |
193 | 322 | } |
194 | - else if (d_query_point.dim[split_axis] < nodes[cur].split_value) { // jump into specific son node | |
323 | + else if (d_query_point.dim[split_axis] < nodes[cur].split_value) { // jump into specific son node | |
195 | 324 | cur = nodes[cur].left; |
196 | 325 | } |
197 | 326 | else { |
... | ... | @@ -202,25 +331,25 @@ namespace stim { |
202 | 331 | *d_index = best_index; |
203 | 332 | } |
204 | 333 | template <typename T, int D> |
205 | - __device__ void SearchAtNodeRange(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, int cur, T range, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) { | |
334 | + __device__ void search_at_node_range(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, int cur, T range, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) { | |
206 | 335 | T best_distance = FLT_MAX; |
207 | 336 | size_t best_index = 0; |
208 | 337 | |
209 | - int next_nodes_pos = 0; // initialize pop out order index | |
210 | - next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread | |
338 | + int next_nodes_pos = 0; // initialize pop out order index | |
339 | + next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread | |
211 | 340 | next_nodes_pos++; |
212 | 341 | |
213 | 342 | while (next_nodes_pos) { |
214 | - int next_search_nodes_pos = 0; // record push back order index | |
343 | + int next_search_nodes_pos = 0; // record push back order index | |
215 | 344 | while (next_nodes_pos) { |
216 | - cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out | |
345 | + cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out | |
217 | 346 | next_nodes_pos--; |
218 | 347 | int split_axis = nodes[cur].level % D; |
219 | 348 | |
220 | 349 | if (nodes[cur].left == -1) { |
221 | 350 | for (int i = 0; i < nodes[cur].num_index; i++) { |
222 | - int idx = indices[nodes[cur].index + i]; // all indices are stored in one array, pick up from every node's beginning index | |
223 | - T d = Distance<T>(d_query_point, d_reference_points[idx]); | |
351 | + int idx = indices[nodes[cur].index + i]; // all indices are stored in one array, pick up from every node's beginning index | |
352 | + T d = gpu_distance<T>(d_query_point, d_reference_points[idx]); | |
224 | 353 | if (d < best_distance) { |
225 | 354 | best_distance = d; |
226 | 355 | best_index = idx; |
... | ... | @@ -260,13 +389,13 @@ namespace stim { |
260 | 389 | *d_index = best_index; |
261 | 390 | } |
262 | 391 | template <typename T, int D> |
263 | - __device__ void Search(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) { | |
392 | + __device__ void search(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) { | |
264 | 393 | int best_node = 0; |
265 | 394 | T best_distance = FLT_MAX; |
266 | 395 | size_t best_index = 0; |
267 | 396 | T radius = 0; |
268 | 397 | |
269 | - SearchAtNode<T, D>(nodes, indices, d_reference_points, 0, d_query_point, &best_index, &best_distance, &best_node); | |
398 | + search_at_node<T, D>(nodes, indices, d_reference_points, 0, d_query_point, &best_index, &best_distance, &best_node); | |
270 | 399 | radius = sqrt(best_distance); // get range |
271 | 400 | int cur = best_node; |
272 | 401 | |
... | ... | @@ -278,9 +407,9 @@ namespace stim { |
278 | 407 | size_t tmp_idx; |
279 | 408 | if (fabs(nodes[parent].split_value - d_query_point.dim[split_axis]) <= radius) { |
280 | 409 | if (nodes[parent].left != cur) |
281 | - SearchAtNodeRange(nodes, indices, d_reference_points, d_query_point, nodes[parent].left, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge); | |
410 | + search_at_node_range(nodes, indices, d_reference_points, d_query_point, nodes[parent].left, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge); | |
282 | 411 | else |
283 | - SearchAtNodeRange(nodes, indices, d_reference_points, d_query_point, nodes[parent].right, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge); | |
412 | + search_at_node_range(nodes, indices, d_reference_points, d_query_point, nodes[parent].right, radius, &tmp_idx, &tmp_dist, id, next_nodes, next_search_nodes, Judge); | |
284 | 413 | } |
285 | 414 | if (tmp_dist < best_distance) { |
286 | 415 | best_distance = tmp_dist; |
... | ... | @@ -292,11 +421,11 @@ namespace stim { |
292 | 421 | *d_index = best_index; |
293 | 422 | } |
294 | 423 | template <typename T, int D> |
295 | - __global__ void SearchBatch(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> *d_query_points, size_t d_query_count, size_t *d_indices, T *d_distances, int *next_nodes, int *next_search_nodes, int *Judge) { | |
424 | + __global__ void search_batch(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> *d_query_points, size_t d_query_count, size_t *d_indices, T *d_distances, int *next_nodes, int *next_search_nodes, int *Judge) { | |
296 | 425 | size_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
297 | 426 | if (idx >= d_query_count) return; // avoid segfault |
298 | 427 | |
299 | - Search<T, D>(nodes, indices, d_reference_points, d_query_points[idx], &d_indices[idx], &d_distances[idx], idx, next_nodes, next_search_nodes, Judge); // every query points are independent | |
428 | + search<T, D>(nodes, indices, d_reference_points, d_query_points[idx], &d_indices[idx], &d_distances[idx], idx, next_nodes, next_search_nodes, Judge); // every query points are independent | |
300 | 429 | } |
301 | 430 | |
302 | 431 | template <typename T, int D = 3> |
... | ... | @@ -312,20 +441,20 @@ namespace stim { |
312 | 441 | HANDLE_ERROR(cudaFree(d_index)); |
313 | 442 | HANDLE_ERROR(cudaFree(d_reference_points)); |
314 | 443 | } |
315 | - void CreateKDTree(T *h_reference_points, size_t reference_count, size_t dim_count, size_t max_levels) { | |
444 | + void create(T *h_reference_points, size_t reference_count, size_t max_levels) { | |
316 | 445 | if (max_levels > 10) { |
317 | 446 | std::cout<<"The max_tree_levels should be smaller!"<<std::endl; |
318 | 447 | exit(1); |
319 | - } | |
320 | - std::vector < typename kdtree::point<T, D> > reference_points(reference_count); // restore the reference points in particular way | |
448 | + } | |
449 | + std::vector <kdtree::point<T, D>> reference_points(reference_count); // restore the reference points in particular way | |
321 | 450 | for (size_t j = 0; j < reference_count; j++) |
322 | - for (size_t i = 0; i < dim_count; i++) | |
323 | - reference_points[j].dim[i] = h_reference_points[j * dim_count + i]; | |
451 | + for (size_t i = 0; i < D; i++) | |
452 | + reference_points[j].dim[i] = h_reference_points[j * D + i]; | |
324 | 453 | cpu_kdtree<T, D> tree; // creating a tree on cpu |
325 | - tree.Create(reference_points, max_levels); // building a tree on cpu | |
326 | - kdtree::kdnode<T> *d_root = tree.GetRoot(); | |
327 | - int num_nodes = tree.GetNumNodes(); | |
328 | - d_reference_count = reference_points.size(); // also equals to reference_count | |
454 | + tree.cpu_create(reference_points, max_levels); // building a tree on cpu | |
455 | + kdtree::kdnode<T> *d_root = tree.get_root(); | |
456 | + int num_nodes = tree.get_num_nodes(); | |
457 | + d_reference_count = reference_count; // also equals to reference_count | |
329 | 458 | |
330 | 459 | HANDLE_ERROR(cudaMalloc((void**)&d_nodes, sizeof(cuda_kdnode<T>) * num_nodes)); // copy data from host to device |
331 | 460 | HANDLE_ERROR(cudaMalloc((void**)&d_index, sizeof(size_t) * d_reference_count)); |
... | ... | @@ -371,11 +500,11 @@ namespace stim { |
371 | 500 | HANDLE_ERROR(cudaMemcpy(d_index, &indices[0], sizeof(size_t) * indices.size(), cudaMemcpyHostToDevice)); |
372 | 501 | HANDLE_ERROR(cudaMemcpy(d_reference_points, &reference_points[0], sizeof(kdtree::point<T, D>) * reference_points.size(), cudaMemcpyHostToDevice)); |
373 | 502 | } |
374 | - void Search(T *h_query_points, size_t query_count, size_t dim_count, T *dists, size_t *indices) { | |
503 | + void search(T *h_query_points, size_t query_count, size_t *indices, T *distances) { | |
375 | 504 | std::vector < typename kdtree::point<T, D> > query_points(query_count); |
376 | 505 | for (size_t j = 0; j < query_count; j++) |
377 | - for (size_t i = 0; i < dim_count; i++) | |
378 | - query_points[j].dim[i] = h_query_points[j * dim_count + i]; | |
506 | + for (size_t i = 0; i < D; i++) | |
507 | + query_points[j].dim[i] = h_query_points[j * D + i]; | |
379 | 508 | |
380 | 509 | unsigned int threads = (unsigned int)(query_points.size() > 1024 ? 1024 : query_points.size()); |
381 | 510 | unsigned int blocks = (unsigned int)(query_points.size() / threads + (query_points.size() % threads ? 1 : 0)); |
... | ... | @@ -392,15 +521,15 @@ namespace stim { |
392 | 521 | HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D)); |
393 | 522 | HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size())); |
394 | 523 | HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size())); |
395 | - HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to | |
524 | + HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to | |
396 | 525 | HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 50 * sizeof(int))); |
397 | 526 | HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice)); |
398 | 527 | |
399 | - SearchBatch<<<threads, blocks>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); | |
528 | + search_batch<<<threads, blocks>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); | |
400 | 529 | |
401 | 530 | if (Judge == NULL) { // do the following work if the thread works safely |
402 | 531 | HANDLE_ERROR(cudaMemcpy(indices, d_indices, sizeof(size_t) * query_points.size(), cudaMemcpyDeviceToHost)); |
403 | - HANDLE_ERROR(cudaMemcpy(dists, d_distances, sizeof(T) * query_points.size(), cudaMemcpyDeviceToHost)); | |
532 | + HANDLE_ERROR(cudaMemcpy(distances, d_distances, sizeof(T) * query_points.size(), cudaMemcpyDeviceToHost)); | |
404 | 533 | } |
405 | 534 | |
406 | 535 | HANDLE_ERROR(cudaFree(next_nodes)); | ... | ... |