Commit 16d171c9c122bce50bf06347e926419ec7da0f99

Authored by David Mayerich
2 parents fcd2eb7c a0f09cd5

Merge branch 'JACK' into 'master'

Add a `stack_size` #define to the kd-tree header, replacing the hard-coded per-thread stack size of 50

See merge request !15
Showing 1 changed file with 11 additions and 10 deletions
stim/structures/kdtree.cuh
@@ -7,6 +7,7 @@ @@ -7,6 +7,7 @@
7 7
8 #ifndef KDTREE_H 8 #ifndef KDTREE_H
9 #define KDTREE_H 9 #define KDTREE_H
  10 +#define stack_size 50
10 11
11 #include "device_launch_parameters.h" 12 #include "device_launch_parameters.h"
12 #include <cuda.h> 13 #include <cuda.h>
@@ -337,13 +338,13 @@ namespace stim { @@ -337,13 +338,13 @@ namespace stim {
337 size_t best_index = 0; 338 size_t best_index = 0;
338 339
339 int next_nodes_pos = 0; // initialize pop out order index 340 int next_nodes_pos = 0; // initialize pop out order index
340 - next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread 341 + next_nodes[id * stack_size + next_nodes_pos] = cur; // find data that belongs to the very specific thread
341 next_nodes_pos++; 342 next_nodes_pos++;
342 343
343 while (next_nodes_pos) { 344 while (next_nodes_pos) {
344 int next_search_nodes_pos = 0; // record push back order index 345 int next_search_nodes_pos = 0; // record push back order index
345 while (next_nodes_pos) { 346 while (next_nodes_pos) {
346 - cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out 347 + cur = next_nodes[id * stack_size + next_nodes_pos - 1]; // pop out the last push in one and keep poping out
347 next_nodes_pos--; 348 next_nodes_pos--;
348 int split_axis = nodes[cur].level % D; 349 int split_axis = nodes[cur].level % D;
349 350
@@ -362,20 +363,20 @@ namespace stim { @@ -362,20 +363,20 @@ namespace stim {
362 363
363 if (fabs(d) > range) { 364 if (fabs(d) > range) {
364 if (d < 0) { 365 if (d < 0) {
365 - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left; 366 + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].left;
366 next_search_nodes_pos++; 367 next_search_nodes_pos++;
367 } 368 }
368 else { 369 else {
369 - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right; 370 + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].right;
370 next_search_nodes_pos++; 371 next_search_nodes_pos++;
371 } 372 }
372 } 373 }
373 else { 374 else {
374 - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right; 375 + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].right;
375 next_search_nodes_pos++; 376 next_search_nodes_pos++;
376 - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left; 377 + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].left;
377 next_search_nodes_pos++; 378 next_search_nodes_pos++;
378 - if (next_search_nodes_pos > 50) { 379 + if (next_search_nodes_pos > stack_size) {
379 printf("Thread conflict might be caused by thread %d, so please try smaller input max_tree_levels\n", id); 380 printf("Thread conflict might be caused by thread %d, so please try smaller input max_tree_levels\n", id);
380 (*Judge)++; 381 (*Judge)++;
381 } 382 }
@@ -383,7 +384,7 @@ namespace stim { @@ -383,7 +384,7 @@ namespace stim {
383 } 384 }
384 } 385 }
385 for (int i = 0; i < next_search_nodes_pos; i++) 386 for (int i = 0; i < next_search_nodes_pos; i++)
386 - next_nodes[id * 50 + i] = next_search_nodes[id * 50 + i]; 387 + next_nodes[id * stack_size + i] = next_search_nodes[id * stack_size + i];
387 next_nodes_pos = next_search_nodes_pos; 388 next_nodes_pos = next_search_nodes_pos;
388 } 389 }
389 *d_distance = best_distance; 390 *d_distance = best_distance;
@@ -537,8 +538,8 @@ namespace stim { @@ -537,8 +538,8 @@ namespace stim {
537 HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D)); 538 HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D));
538 HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size())); 539 HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size()));
539 HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size())); 540 HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size()));
540 - HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to  
541 - HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 50 * sizeof(int))); 541 + HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * stack_size * sizeof(int))); // STACK size right now is 50, you can change it if you mean to
  542 + HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * stack_size * sizeof(int)));
542 HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice)); 543 HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice));
543 544
544 search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); 545 search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge);