Commit eaaf04c2aa0d8cbad1830fe51e30ee4f72ac7141
Merged branch master into master
Showing 1 changed file with 99 additions and 29 deletions. Show diff stats
stim/structures/kdtree.cuh
... | ... | @@ -63,6 +63,7 @@ namespace stim { |
63 | 63 | cur_tree_ptr = this; // create a class pointer points to the current class value |
64 | 64 | n_id = 0; // set total number of points to default 0 |
65 | 65 | } |
66 | + | |
66 | 67 | ~cpu_kdtree() { // destructor of cpu_kdtree |
67 | 68 | std::vector <kdtree::kdnode<T>*> next_nodes; |
68 | 69 | next_nodes.push_back(root); |
... | ... | @@ -81,6 +82,7 @@ namespace stim { |
81 | 82 | } |
82 | 83 | root = NULL; |
83 | 84 | } |
85 | + | |
84 | 86 | void cpu_create(std::vector < typename kdtree::point<T, D> > &reference_points, size_t max_levels) { |
85 | 87 | tmp_points = &reference_points; |
86 | 88 | root = new kdtree::kdnode<T>(); // initializing the root node |
... | ... | @@ -121,10 +123,12 @@ namespace stim { |
121 | 123 | next_nodes = next_search_nodes; // go deeper within the tree |
122 | 124 | } |
123 | 125 | } |
126 | + | |
// Comparator handed to std::sort: orders two point *indices* by their
// coordinate along the axis currently being split. Because std::sort needs a
// plain function, the tree state is reached through the global cur_tree_ptr.
static bool sort_points(const size_t a, const size_t b) {
	std::vector< typename kdtree::point<T, D> > &points = *(cur_tree_ptr->tmp_points);
	const size_t axis = cur_tree_ptr->current_axis;	// judicative dimension for this split
	return points[a].dim[axis] < points[b].dim[axis];
}
131 | + | |
128 | 132 | void split(kdtree::kdnode<T> *cur, kdtree::kdnode<T> *left, kdtree::kdnode<T> *right) { |
129 | 133 | std::vector < typename kdtree::point<T, D> > &pts = *tmp_points; |
130 | 134 | current_axis = cur->level % D; // indicate the judicative dimension or axis |
... | ... | @@ -145,6 +149,7 @@ namespace stim { |
145 | 149 | right->indices.push_back(idx); |
146 | 150 | } |
147 | 151 | } |
152 | + | |
148 | 153 | void create(T *h_reference_points, size_t reference_count, size_t max_levels) { |
149 | 154 | std::vector < typename kdtree::point<T, D> > reference_points(reference_count); // restore the reference points in particular way |
150 | 155 | for (size_t j = 0; j < reference_count; j++) |
... | ... | @@ -153,13 +158,16 @@ namespace stim { |
153 | 158 | cpu_create(reference_points, max_levels); |
154 | 159 | cpu_tmp_points = *tmp_points; |
155 | 160 | } |
161 | + | |
// Returns the total number of nodes currently in the tree.
int get_num_nodes() const {
	return this->n_id;
}
165 | + | |
// Returns a pointer to the root node of the tree (NULL if not built).
kdtree::kdnode<T>* get_root() const {
	return this->root;
}
162 | - T cpu_distance(const kdtree::point<T, D> &a, const kdtree::point<T, D> &b) { | |
169 | + | |
170 | + T cpu_distance(const kdtree::point<T, D> &a, const kdtree::point<T, D> &b) { | |
163 | 171 | T distance = 0; |
164 | 172 | |
165 | 173 | for (size_t i = 0; i < D; i++) { |
... | ... | @@ -168,6 +176,7 @@ namespace stim { |
168 | 176 | } |
169 | 177 | return distance; |
170 | 178 | } |
179 | + | |
171 | 180 | void cpu_search_at_node(kdtree::kdnode<T> *cur, const kdtree::point<T, D> &query, size_t *index, T *distance, kdtree::kdnode<T> **node) { |
172 | 181 | T best_distance = FLT_MAX; // initialize the best distance to max of floating point |
173 | 182 | size_t best_index = 0; |
... | ... | @@ -198,6 +207,7 @@ namespace stim { |
198 | 207 | *index = best_index; |
199 | 208 | *distance = best_distance; |
200 | 209 | } |
210 | + | |
201 | 211 | void cpu_search_at_node_range(kdtree::kdnode<T> *cur, const kdtree::point<T, D> &query, T range, size_t *index, T *distance) { |
202 | 212 | T best_distance = FLT_MAX; // initialize the best distance to max of floating point |
203 | 213 | size_t best_index = 0; |
... | ... | @@ -240,6 +250,7 @@ namespace stim { |
240 | 250 | *index = best_index; |
241 | 251 | *distance = best_distance; |
242 | 252 | } |
253 | + | |
243 | 254 | void cpu_search(T *h_query_points, size_t query_count, size_t *h_indices, T *h_distances) { |
244 | 255 | /// first convert the input query point into specific type |
245 | 256 | kdtree::point<T, D> query; |
... | ... | @@ -303,6 +314,7 @@ namespace stim { |
303 | 314 | } |
304 | 315 | return distance; |
305 | 316 | } |
317 | + | |
306 | 318 | template <typename T, int D> |
307 | 319 | __device__ void search_at_node(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, int cur, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, int *d_node) { |
308 | 320 | T best_distance = FLT_MAX; |
... | ... | @@ -332,6 +344,7 @@ namespace stim { |
332 | 344 | *d_distance = best_distance; |
333 | 345 | *d_index = best_index; |
334 | 346 | } |
347 | + | |
335 | 348 | template <typename T, int D> |
336 | 349 | __device__ void search_at_node_range(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, int cur, T range, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) { |
337 | 350 | T best_distance = FLT_MAX; |
... | ... | @@ -390,6 +403,7 @@ namespace stim { |
390 | 403 | *d_distance = best_distance; |
391 | 404 | *d_index = best_index; |
392 | 405 | } |
406 | + | |
393 | 407 | template <typename T, int D> |
394 | 408 | __device__ void search(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> &d_query_point, size_t *d_index, T *d_distance, size_t id, int *next_nodes, int *next_search_nodes, int *Judge) { |
395 | 409 | int best_node = 0; |
... | ... | @@ -422,6 +436,7 @@ namespace stim { |
422 | 436 | *d_distance = sqrt(best_distance); |
423 | 437 | *d_index = best_index; |
424 | 438 | } |
439 | + | |
425 | 440 | template <typename T, int D> |
426 | 441 | __global__ void search_batch(cuda_kdnode<T> *nodes, size_t *indices, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> *d_query_points, size_t d_query_count, size_t *d_indices, T *d_distances, int *next_nodes, int *next_search_nodes, int *Judge) { |
427 | 442 | size_t idx = blockIdx.x * blockDim.x + threadIdx.x; |
... | ... | @@ -429,6 +444,41 @@ namespace stim { |
429 | 444 | |
430 | 445 | search<T, D>(nodes, indices, d_reference_points, d_query_points[idx], &d_indices[idx], &d_distances[idx], idx, next_nodes, next_search_nodes, Judge); // every query points are independent |
431 | 446 | } |
447 | + | |
/// Search one chunk ("stream") of query points against a kd-tree already
/// resident on the GPU: upload the chunk, launch search_batch over it, and
/// copy the nearest-neighbor results back into the host output buffers.
///   d_nodes             - device array of kd-tree nodes (already uploaded)
///   d_index             - device array of reference-point indices (already uploaded)
///   d_reference_points  - device array of reference points (already uploaded)
///   query_stream_points - host pointer to the first query point of this chunk
///   stream_count        - number of query points in this chunk
///   indices, distances  - host output buffers for this chunk (stream_count entries each)
template <typename T, int D>
void search_stream(cuda_kdnode<T> *d_nodes, size_t *d_index, kdtree::point<T, D> *d_reference_points, kdtree::point<T, D> *query_stream_points, size_t stream_count, size_t *indices, T *distances) {
	if (stream_count == 0) return;						// guard: threads would be 0 below, causing a divide-by-zero

	unsigned int threads = (unsigned int)(stream_count > 1024 ? 1024 : stream_count);
	unsigned int blocks = (unsigned int)(stream_count / threads + (stream_count % threads ? 1 : 0));	// ceil-divide so the tail chunk is covered

	kdtree::point<T, D> *d_query_points;				// device copy of this chunk's query points
	size_t *d_indices;									// device output: nearest-neighbor indices
	T *d_distances;										// device output: nearest-neighbor distances

	int *next_nodes;									// per-thread traversal stacks, stack_size entries each
	int *next_search_nodes;

	// NOTE(review): sizeof(T) * stream_count * D assumes kdtree::point<T, D> is exactly
	// D packed values of T — same layout assumption made by the rest of this file.
	HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * stream_count * D));
	HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * stream_count));
	HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * stream_count));
	HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * stack_size * sizeof(int)));
	HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * stack_size * sizeof(int)));
	HANDLE_ERROR(cudaMemcpy(d_query_points, query_stream_points, sizeof(T) * stream_count * D, cudaMemcpyHostToDevice));

	int *Judge = NULL;									// debug flag checked by the kernel to detect stack overwrites

	search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, stream_count, d_indices, d_distances, next_nodes, next_search_nodes, Judge);
	HANDLE_ERROR(cudaGetLastError());					// surface launch-configuration errors instead of swallowing them

	// NOTE(review): Judge is never modified on the host, so this branch always runs;
	// kept to mirror the identical pattern in cuda_kdtree::search(). The blocking
	// cudaMemcpy below also synchronizes with the kernel.
	if (Judge == NULL) {
		HANDLE_ERROR(cudaMemcpy(indices, d_indices, sizeof(size_t) * stream_count, cudaMemcpyDeviceToHost));
		HANDLE_ERROR(cudaMemcpy(distances, d_distances, sizeof(T) * stream_count, cudaMemcpyDeviceToHost));
	}

	HANDLE_ERROR(cudaFree(next_nodes));
	HANDLE_ERROR(cudaFree(next_search_nodes));
	HANDLE_ERROR(cudaFree(d_query_points));
	HANDLE_ERROR(cudaFree(d_indices));
	HANDLE_ERROR(cudaFree(d_distances));
}
432 | 482 | |
433 | 483 | template <typename T, int D = 3> |
434 | 484 | class cuda_kdtree { |
... | ... | @@ -457,7 +507,7 @@ namespace stim { |
457 | 507 | //bb.init(&h_reference_points[0]); |
458 | 508 | //aaboundingboxing<T, D>(bb, h_reference_points, reference_count); |
459 | 509 | |
460 | - std::vector < typename kdtree::point<T, D> > reference_points(reference_count); // restore the reference points in particular way | |
510 | + std::vector < typename kdtree::point<T, D>> reference_points(reference_count); // restore the reference points in particular way | |
461 | 511 | for (size_t j = 0; j < reference_count; j++) |
462 | 512 | for (size_t i = 0; i < D; i++) |
463 | 513 | reference_points[j].dim[i] = h_reference_points[j * D + i]; |
... | ... | @@ -509,7 +559,7 @@ namespace stim { |
509 | 559 | } |
510 | 560 | HANDLE_ERROR(cudaMemcpy(d_nodes, &tmp_nodes[0], sizeof(cuda_kdnode<T>) * tmp_nodes.size(), cudaMemcpyHostToDevice)); |
511 | 561 | HANDLE_ERROR(cudaMemcpy(d_index, &indices[0], sizeof(size_t) * indices.size(), cudaMemcpyHostToDevice)); |
512 | - HANDLE_ERROR(cudaMemcpy(d_reference_points, &reference_points[0], sizeof(kdtree::point<T, D>) * reference_points.size(), cudaMemcpyHostToDevice)); | |
562 | + HANDLE_ERROR(cudaMemcpy(d_reference_points, &reference_points[0], sizeof(kdtree::point<T, D>) * reference_count, cudaMemcpyHostToDevice)); | |
513 | 563 | } |
514 | 564 | |
515 | 565 | /// Search the KD tree for nearest neighbors to a set of specified query points |
... | ... | @@ -523,37 +573,57 @@ namespace stim { |
523 | 573 | for (size_t i = 0; i < D; i++) |
524 | 574 | query_points[j].dim[i] = h_query_points[j * D + i]; |
525 | 575 | |
526 | - unsigned int threads = (unsigned int)(query_points.size() > 1024 ? 1024 : query_points.size()); | |
527 | - unsigned int blocks = (unsigned int)(query_points.size() / threads + (query_points.size() % threads ? 1 : 0)); | |
576 | + cudaDeviceProp prop; | |
577 | + cudaGetDeviceProperties(&prop, 0); | |
578 | + | |
579 | + size_t query_memory = D * sizeof(T) * query_count; | |
580 | + size_t N = 3 * query_memory / prop.totalGlobalMem; //consider index and distance, roughly 3 times | |
581 | + if (N > 1) { | |
582 | + N++; | |
583 | + size_t stream_count = query_count / N; | |
584 | + for (size_t n = 0; n < N; n++) { | |
585 | + size_t query_stream_start = n * stream_count; | |
586 | + search_stream(d_nodes, d_index, d_reference_points, &query_points[query_stream_start], stream_count, &indices[query_stream_start], &distances[query_stream_start]); | |
587 | + } | |
588 | + size_t stream_remain_count = query_count - N * stream_count; | |
589 | + if (stream_remain_count > 0) { | |
590 | + size_t query_remain_start = N * stream_count; | |
591 | + search_stream(d_nodes, d_index, d_reference_points, &query_points[query_remain_start], stream_remain_count, &indices[query_remain_start], &distances[query_remain_start]); | |
592 | + } | |
593 | + } | |
594 | + else { | |
595 | + unsigned int threads = (unsigned int)(query_count > 1024 ? 1024 : query_count); | |
596 | + unsigned int blocks = (unsigned int)(query_count / threads + (query_count % threads ? 1 : 0)); | |
528 | 597 | |
529 | - kdtree::point<T, D> *d_query_points; // create a pointer pointing to query points on gpu | |
530 | - size_t *d_indices; | |
531 | - T *d_distances; | |
598 | + kdtree::point<T, D> *d_query_points; // create a pointer pointing to query points on gpu | |
599 | + size_t *d_indices; | |
600 | + T *d_distances; | |
532 | 601 | |
533 | - int *next_nodes; // create two STACK-like array | |
534 | - int *next_search_nodes; | |
602 | + int *next_nodes; // create two STACK-like array | |
603 | + int *next_search_nodes; | |
535 | 604 | |
536 | - int *Judge = NULL; // judge variable to see whether one thread is overwrite another thread's memory | |
605 | + int *Judge = NULL; // judge variable to see whether one thread is overwrite another thread's memory | |
537 | 606 | |
538 | - HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D)); | |
539 | - HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size())); | |
540 | - HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size())); | |
541 | - HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * stack_size * sizeof(int))); // STACK size right now is 50, you can change it if you mean to | |
542 | - HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * stack_size * sizeof(int))); | |
543 | - HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice)); | |
544 | - | |
545 | - search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); | |
546 | - | |
547 | - if (Judge == NULL) { // do the following work if the thread works safely | |
548 | - HANDLE_ERROR(cudaMemcpy(indices, d_indices, sizeof(size_t) * query_points.size(), cudaMemcpyDeviceToHost)); | |
549 | - HANDLE_ERROR(cudaMemcpy(distances, d_distances, sizeof(T) * query_points.size(), cudaMemcpyDeviceToHost)); | |
550 | - } | |
607 | + HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_count * D)); | |
608 | + HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_count)); | |
609 | + HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_count)); | |
610 | + HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * stack_size * sizeof(int))); // STACK size right now is 50, you can change it if you mean to | |
611 | + HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * stack_size * sizeof(int))); | |
612 | + HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_count * D, cudaMemcpyHostToDevice)); | |
613 | + | |
614 | + search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_count, d_indices, d_distances, next_nodes, next_search_nodes, Judge); | |
615 | + | |
616 | + if (Judge == NULL) { // do the following work if the thread works safely | |
617 | + HANDLE_ERROR(cudaMemcpy(indices, d_indices, sizeof(size_t) * query_count, cudaMemcpyDeviceToHost)); | |
618 | + HANDLE_ERROR(cudaMemcpy(distances, d_distances, sizeof(T) * query_count, cudaMemcpyDeviceToHost)); | |
619 | + } | |
551 | 620 | |
552 | - HANDLE_ERROR(cudaFree(next_nodes)); | |
553 | - HANDLE_ERROR(cudaFree(next_search_nodes)); | |
554 | - HANDLE_ERROR(cudaFree(d_query_points)); | |
555 | - HANDLE_ERROR(cudaFree(d_indices)); | |
556 | - HANDLE_ERROR(cudaFree(d_distances)); | |
621 | + HANDLE_ERROR(cudaFree(next_nodes)); | |
622 | + HANDLE_ERROR(cudaFree(next_search_nodes)); | |
623 | + HANDLE_ERROR(cudaFree(d_query_points)); | |
624 | + HANDLE_ERROR(cudaFree(d_indices)); | |
625 | + HANDLE_ERROR(cudaFree(d_distances)); | |
626 | + } | |
557 | 627 | } |
558 | 628 | |
559 | 629 | /// Return the number of points in the KD tree | ... | ... |