Commit eae7559211457744e9fcab743bf1d204e7136e92

Authored by David Mayerich
1 parent 5db80c2e

added GPU support for permutes

Showing 1 changed file with 26 additions and 3 deletions   Show diff stats
stim/envi/binary.h
... ... @@ -15,6 +15,12 @@
15 15 #include <unistd.h>
16 16 #endif
17 17  
  18 +#ifdef CUDA_FOUND
  19 +//CUDA externs
  20 +void gpu_permute(char* dest, char* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2, size_t typesize);
  21 +#include <stim/cuda/cudatools/error.h>
  22 +#endif
  23 +
18 24 namespace stim{
19 25  
20 26 /// This class calculates the optimal setting for independent parameter b (batch size) for
... ... @@ -638,10 +644,27 @@ public:
638 644 // permutes a block of data from the current interleave to the interleave specified (re-arranged dimensions to the order specified by [d0, d1, d2])
639 645  
640 646 size_t permute(T* dest, T* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2){
641   - auto t0 = std::chrono::high_resolution_clock::now();
  647 + std::chrono::high_resolution_clock::time_point t0, t1;
  648 + t0 = std::chrono::high_resolution_clock::now();
  649 +
  650 +#ifdef CUDA_FOUND
  651 + T* gpu_src;
  652 + HANDLE_ERROR( cudaMalloc(&gpu_src, sx*sy*sz*sizeof(T)) );
  653 + HANDLE_ERROR( cudaMemcpy(gpu_src, src, sx*sy*sz*sizeof(T), cudaMemcpyHostToDevice) );
  654 + T* gpu_dest;
  655 + HANDLE_ERROR( cudaMalloc(&gpu_dest, sx*sy*sz*sizeof(T)) );
  656 + gpu_permute((char*)gpu_dest, (char*)gpu_src, sx, sy, sz, d0, d1, d2, sizeof(T));
  657 + HANDLE_ERROR( cudaMemcpy(dest, gpu_dest, sx*sy*sz*sizeof(T), cudaMemcpyDeviceToHost) );
  658 + HANDLE_ERROR( cudaFree(gpu_src) );
  659 + HANDLE_ERROR( cudaFree(gpu_dest) );
  660 + t1 = std::chrono::high_resolution_clock::now();
  661 + return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count();
  662 +
  663 +#endif
  664 +
642 665 size_t d[3] = {d0, d1, d2};
643 666 size_t s[3] = {sx, sy, sz};
644   - size_t p[3];// = {x, y, z};
  667 + size_t p[3];
645 668  
646 669 if(d[0] == 0 && d[1] == 1 && d[2] == 2){
647 670 //this isn't actually a permute - just copy the data
... ... @@ -680,7 +703,7 @@ public:
680 703 }
681 704 }
682 705 }
683   - auto t1 = std::chrono::high_resolution_clock::now();
  706 + t1 = std::chrono::high_resolution_clock::now();
684 707 return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count();
685 708 }
686 709  
... ...