Commit eae7559211457744e9fcab743bf1d204e7136e92

Authored by David Mayerich
1 parent 5db80c2e

added GPU support for permutes

Showing 1 changed file with 26 additions and 3 deletions   Show diff stats
stim/envi/binary.h
@@ -15,6 +15,12 @@ @@ -15,6 +15,12 @@
15 #include <unistd.h> 15 #include <unistd.h>
16 #endif 16 #endif
17 17
  18 +#ifdef CUDA_FOUND
  19 +//CUDA externs
  20 +void gpu_permute(char* dest, char* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2, size_t typesize);
  21 +#include <stim/cuda/cudatools/error.h>
  22 +#endif
  23 +
18 namespace stim{ 24 namespace stim{
19 25
20 /// This class calculates the optimal setting for independent parameter b (batch size) for 26 /// This class calculates the optimal setting for independent parameter b (batch size) for
@@ -638,10 +644,27 @@ public: @@ -638,10 +644,27 @@ public:
638 // permutes a block of data from the current interleave to the specified interleave (rearranges the dimensions into the order given by [d0, d1, d2]) 644 // permutes a block of data from the current interleave to the specified interleave (rearranges the dimensions into the order given by [d0, d1, d2])
639 645
640 size_t permute(T* dest, T* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2){ 646 size_t permute(T* dest, T* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2){
641 - auto t0 = std::chrono::high_resolution_clock::now(); 647 + std::chrono::high_resolution_clock::time_point t0, t1;
  648 + t0 = std::chrono::high_resolution_clock::now();
  649 +
  650 +#ifdef CUDA_FOUND
  651 + T* gpu_src;
  652 + HANDLE_ERROR( cudaMalloc(&gpu_src, sx*sy*sz*sizeof(T)) );
  653 + HANDLE_ERROR( cudaMemcpy(gpu_src, src, sx*sy*sz*sizeof(T), cudaMemcpyHostToDevice) );
  654 + T* gpu_dest;
  655 + HANDLE_ERROR( cudaMalloc(&gpu_dest, sx*sy*sz*sizeof(T)) );
  656 + gpu_permute((char*)gpu_dest, (char*)gpu_src, sx, sy, sz, d0, d1, d2, sizeof(T));
  657 + HANDLE_ERROR( cudaMemcpy(dest, gpu_dest, sx*sy*sz*sizeof(T), cudaMemcpyDeviceToHost) );
  658 + HANDLE_ERROR( cudaFree(gpu_src) );
  659 + HANDLE_ERROR( cudaFree(gpu_dest) );
  660 + t1 = std::chrono::high_resolution_clock::now();
  661 + return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count();
  662 +
  663 +#endif
  664 +
642 size_t d[3] = {d0, d1, d2}; 665 size_t d[3] = {d0, d1, d2};
643 size_t s[3] = {sx, sy, sz}; 666 size_t s[3] = {sx, sy, sz};
644 - size_t p[3];// = {x, y, z}; 667 + size_t p[3];
645 668
646 if(d[0] == 0 && d[1] == 1 && d[2] == 2){ 669 if(d[0] == 0 && d[1] == 1 && d[2] == 2){
647 //this isn't actually a permute - just copy the data 670 //this isn't actually a permute - just copy the data
@@ -680,7 +703,7 @@ public: @@ -680,7 +703,7 @@ public:
680 } 703 }
681 } 704 }
682 } 705 }
683 - auto t1 = std::chrono::high_resolution_clock::now(); 706 + t1 = std::chrono::high_resolution_clock::now();
684 return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count(); 707 return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count();
685 } 708 }
686 709