Commit eae7559211457744e9fcab743bf1d204e7136e92
1 parent
5db80c2e
added GPU support for permutes
Showing
1 changed file
with
26 additions
and
3 deletions
Show diff stats
stim/envi/binary.h
... | ... | @@ -15,6 +15,12 @@ |
15 | 15 | #include <unistd.h> |
16 | 16 | #endif |
17 | 17 | |
18 | +#ifdef CUDA_FOUND | |
19 | +//CUDA externs | |
20 | +void gpu_permute(char* dest, char* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2, size_t typesize); | |
21 | +#include <stim/cuda/cudatools/error.h> | |
22 | +#endif | |
23 | + | |
18 | 24 | namespace stim{ |
19 | 25 | |
20 | 26 | /// This class calculates the optimal setting for independent parameter b (batch size) for |
... | ... | @@ -638,10 +644,27 @@ public: |
638 | 644 | // permutes a block of data from the current interleave to the interleave specified (re-arranged dimensions to the order specified by [d0, d1, d2]) |
639 | 645 | |
640 | 646 | size_t permute(T* dest, T* src, size_t sx, size_t sy, size_t sz, size_t d0, size_t d1, size_t d2){ |
641 | - auto t0 = std::chrono::high_resolution_clock::now(); | |
647 | + std::chrono::high_resolution_clock::time_point t0, t1; | |
648 | + t0 = std::chrono::high_resolution_clock::now(); | |
649 | + | |
650 | +#ifdef CUDA_FOUND | |
651 | + T* gpu_src; | |
652 | + HANDLE_ERROR( cudaMalloc(&gpu_src, sx*sy*sz*sizeof(T)) ); | |
653 | + HANDLE_ERROR( cudaMemcpy(gpu_src, src, sx*sy*sz*sizeof(T), cudaMemcpyHostToDevice) ); | |
654 | + T* gpu_dest; | |
655 | + HANDLE_ERROR( cudaMalloc(&gpu_dest, sx*sy*sz*sizeof(T)) ); | |
656 | + gpu_permute((char*)gpu_dest, (char*)gpu_src, sx, sy, sz, d0, d1, d2, sizeof(T)); | |
657 | + HANDLE_ERROR( cudaMemcpy(dest, gpu_dest, sx*sy*sz*sizeof(T), cudaMemcpyDeviceToHost) ); | |
658 | + HANDLE_ERROR( cudaFree(gpu_src) ); | |
659 | + HANDLE_ERROR( cudaFree(gpu_dest) ); | |
660 | + t1 = std::chrono::high_resolution_clock::now(); | |
661 | + return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count(); | |
662 | + | |
663 | +#endif | |
664 | + | |
642 | 665 | size_t d[3] = {d0, d1, d2}; |
643 | 666 | size_t s[3] = {sx, sy, sz}; |
644 | - size_t p[3];// = {x, y, z}; | |
667 | + size_t p[3]; | |
645 | 668 | |
646 | 669 | if(d[0] == 0 && d[1] == 1 && d[2] == 2){ |
647 | 670 | //this isn't actually a permute - just copy the data |
... | ... | @@ -680,7 +703,7 @@ public: |
680 | 703 | } |
681 | 704 | } |
682 | 705 | } |
683 | - auto t1 = std::chrono::high_resolution_clock::now(); | |
706 | + t1 = std::chrono::high_resolution_clock::now(); | |
684 | 707 | return std::chrono::duration_cast<std::chrono::milliseconds>(t1-t0).count(); |
685 | 708 | } |
686 | 709 | ... | ... |