Commit a0f09cd529ce6afba149fd491f781ad8dc540e83
1 parent
fcd2eb7c
add #define stack_size in kd-tree
Showing
1 changed file
with
11 additions
and
10 deletions
Show diff stats
stim/structures/kdtree.cuh
... | ... | @@ -7,6 +7,7 @@ |
7 | 7 | |
8 | 8 | #ifndef KDTREE_H |
9 | 9 | #define KDTREE_H |
10 | +#define stack_size 50 | |
10 | 11 | |
11 | 12 | #include "device_launch_parameters.h" |
12 | 13 | #include <cuda.h> |
... | ... | @@ -337,13 +338,13 @@ namespace stim { |
337 | 338 | size_t best_index = 0; |
338 | 339 | |
339 | 340 | int next_nodes_pos = 0; // initialize pop out order index |
340 | - next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread | |
341 | + next_nodes[id * stack_size + next_nodes_pos] = cur; // find data that belongs to the very specific thread | |
341 | 342 | next_nodes_pos++; |
342 | 343 | |
343 | 344 | while (next_nodes_pos) { |
344 | 345 | int next_search_nodes_pos = 0; // record push back order index |
345 | 346 | while (next_nodes_pos) { |
346 | - cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep popping out | |
347 | + cur = next_nodes[id * stack_size + next_nodes_pos - 1]; // pop out the last push in one and keep popping out | |
347 | 348 | next_nodes_pos--; |
348 | 349 | int split_axis = nodes[cur].level % D; |
349 | 350 | |
... | ... | @@ -362,20 +363,20 @@ namespace stim { |
362 | 363 | |
363 | 364 | if (fabs(d) > range) { |
364 | 365 | if (d < 0) { |
365 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left; | |
366 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].left; | |
366 | 367 | next_search_nodes_pos++; |
367 | 368 | } |
368 | 369 | else { |
369 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right; | |
370 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].right; | |
370 | 371 | next_search_nodes_pos++; |
371 | 372 | } |
372 | 373 | } |
373 | 374 | else { |
374 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right; | |
375 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].right; | |
375 | 376 | next_search_nodes_pos++; |
376 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left; | |
377 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].left; | |
377 | 378 | next_search_nodes_pos++; |
378 | - if (next_search_nodes_pos > 50) { | |
379 | + if (next_search_nodes_pos > stack_size) { | |
379 | 380 | printf("Thread conflict might be caused by thread %d, so please try smaller input max_tree_levels\n", id); |
380 | 381 | (*Judge)++; |
381 | 382 | } |
... | ... | @@ -383,7 +384,7 @@ namespace stim { |
383 | 384 | } |
384 | 385 | } |
385 | 386 | for (int i = 0; i < next_search_nodes_pos; i++) |
386 | - next_nodes[id * 50 + i] = next_search_nodes[id * 50 + i]; | |
387 | + next_nodes[id * stack_size + i] = next_search_nodes[id * stack_size + i]; | |
387 | 388 | next_nodes_pos = next_search_nodes_pos; |
388 | 389 | } |
389 | 390 | *d_distance = best_distance; |
... | ... | @@ -537,8 +538,8 @@ namespace stim { |
537 | 538 | HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D)); |
538 | 539 | HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size())); |
539 | 540 | HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size())); |
540 | - HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to | |
541 | - HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 50 * sizeof(int))); | |
541 | + HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * stack_size * sizeof(int))); // per-thread stack capacity is stack_size (currently 50); change the #define to resize it | |
542 | + HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * stack_size * sizeof(int))); | |
542 | 543 | HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice)); |
543 | 544 | |
544 | 545 | search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); | ... | ... |