Commit 16d171c9c122bce50bf06347e926419ec7da0f99
Merge branch 'JACK' into 'master'
add #define stack_size in kd-tree See merge request !15
Showing
1 changed file
with
11 additions
and
10 deletions
Show diff stats
stim/structures/kdtree.cuh
@@ -7,6 +7,7 @@ | @@ -7,6 +7,7 @@ | ||
7 | 7 | ||
8 | #ifndef KDTREE_H | 8 | #ifndef KDTREE_H |
9 | #define KDTREE_H | 9 | #define KDTREE_H |
10 | +#define stack_size 50 | ||
10 | 11 | ||
11 | #include "device_launch_parameters.h" | 12 | #include "device_launch_parameters.h" |
12 | #include <cuda.h> | 13 | #include <cuda.h> |
@@ -337,13 +338,13 @@ namespace stim { | @@ -337,13 +338,13 @@ namespace stim { | ||
337 | size_t best_index = 0; | 338 | size_t best_index = 0; |
338 | 339 | ||
339 | int next_nodes_pos = 0; // initialize pop out order index | 340 | int next_nodes_pos = 0; // initialize pop out order index |
340 | - next_nodes[id * 50 + next_nodes_pos] = cur; // find data that belongs to the very specific thread | 341 | + next_nodes[id * stack_size + next_nodes_pos] = cur; // find data that belongs to the very specific thread |
341 | next_nodes_pos++; | 342 | next_nodes_pos++; |
342 | 343 | ||
343 | while (next_nodes_pos) { | 344 | while (next_nodes_pos) { |
344 | int next_search_nodes_pos = 0; // record push back order index | 345 | int next_search_nodes_pos = 0; // record push back order index |
345 | while (next_nodes_pos) { | 346 | while (next_nodes_pos) { |
346 | - cur = next_nodes[id * 50 + next_nodes_pos - 1]; // pop out the last push in one and keep poping out | 347 | + cur = next_nodes[id * stack_size + next_nodes_pos - 1]; // pop out the last push in one and keep poping out |
347 | next_nodes_pos--; | 348 | next_nodes_pos--; |
348 | int split_axis = nodes[cur].level % D; | 349 | int split_axis = nodes[cur].level % D; |
349 | 350 | ||
@@ -362,20 +363,20 @@ namespace stim { | @@ -362,20 +363,20 @@ namespace stim { | ||
362 | 363 | ||
363 | if (fabs(d) > range) { | 364 | if (fabs(d) > range) { |
364 | if (d < 0) { | 365 | if (d < 0) { |
365 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left; | 366 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].left; |
366 | next_search_nodes_pos++; | 367 | next_search_nodes_pos++; |
367 | } | 368 | } |
368 | else { | 369 | else { |
369 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right; | 370 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].right; |
370 | next_search_nodes_pos++; | 371 | next_search_nodes_pos++; |
371 | } | 372 | } |
372 | } | 373 | } |
373 | else { | 374 | else { |
374 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].right; | 375 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].right; |
375 | next_search_nodes_pos++; | 376 | next_search_nodes_pos++; |
376 | - next_search_nodes[id * 50 + next_search_nodes_pos] = nodes[cur].left; | 377 | + next_search_nodes[id * stack_size + next_search_nodes_pos] = nodes[cur].left; |
377 | next_search_nodes_pos++; | 378 | next_search_nodes_pos++; |
378 | - if (next_search_nodes_pos > 50) { | 379 | + if (next_search_nodes_pos > stack_size) { |
379 | printf("Thread conflict might be caused by thread %d, so please try smaller input max_tree_levels\n", id); | 380 | printf("Thread conflict might be caused by thread %d, so please try smaller input max_tree_levels\n", id); |
380 | (*Judge)++; | 381 | (*Judge)++; |
381 | } | 382 | } |
@@ -383,7 +384,7 @@ namespace stim { | @@ -383,7 +384,7 @@ namespace stim { | ||
383 | } | 384 | } |
384 | } | 385 | } |
385 | for (int i = 0; i < next_search_nodes_pos; i++) | 386 | for (int i = 0; i < next_search_nodes_pos; i++) |
386 | - next_nodes[id * 50 + i] = next_search_nodes[id * 50 + i]; | 387 | + next_nodes[id * stack_size + i] = next_search_nodes[id * stack_size + i]; |
387 | next_nodes_pos = next_search_nodes_pos; | 388 | next_nodes_pos = next_search_nodes_pos; |
388 | } | 389 | } |
389 | *d_distance = best_distance; | 390 | *d_distance = best_distance; |
@@ -537,8 +538,8 @@ namespace stim { | @@ -537,8 +538,8 @@ namespace stim { | ||
537 | HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D)); | 538 | HANDLE_ERROR(cudaMalloc((void**)&d_query_points, sizeof(T) * query_points.size() * D)); |
538 | HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size())); | 539 | HANDLE_ERROR(cudaMalloc((void**)&d_indices, sizeof(size_t) * query_points.size())); |
539 | HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size())); | 540 | HANDLE_ERROR(cudaMalloc((void**)&d_distances, sizeof(T) * query_points.size())); |
540 | - HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * 50 * sizeof(int))); // STACK size right now is 50, you can change it if you mean to | ||
541 | - HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * 50 * sizeof(int))); | 541 | + HANDLE_ERROR(cudaMalloc((void**)&next_nodes, threads * blocks * stack_size * sizeof(int))); // STACK size right now is 50, you can change it if you mean to |
542 | + HANDLE_ERROR(cudaMalloc((void**)&next_search_nodes, threads * blocks * stack_size * sizeof(int))); | ||
542 | HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice)); | 543 | HANDLE_ERROR(cudaMemcpy(d_query_points, &query_points[0], sizeof(T) * query_points.size() * D, cudaMemcpyHostToDevice)); |
543 | 544 | ||
544 | search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); | 545 | search_batch<<<blocks, threads>>> (d_nodes, d_index, d_reference_points, d_query_points, query_points.size(), d_indices, d_distances, next_nodes, next_search_nodes, Judge); |