Commit 0a575bb1055a25de4652c3825e0393db601e2749

Authored by David Mayerich
1 parent 5435830b

updated optimization gradient descent algorithm

Showing 2 changed files with 115 additions and 24 deletions   Show diff stats
stim/envi/binary.h
@@ -21,26 +21,31 @@ namespace stim{ @@ -21,26 +21,31 @@ namespace stim{
21 /// minimizing the dependent parameter bps (bytes per second) 21 /// minimizing the dependent parameter bps (bytes per second)
22 class stream_optimizer{ 22 class stream_optimizer{
23 protected: 23 protected:
24 - size_t Bps; //bytes per second for the previous batch 24 + size_t Bps[2]; //bytes per second for the previous batch
25 size_t interval_B; //number of bytes processed this interval 25 size_t interval_B; //number of bytes processed this interval
26 size_t interval_ms; //number of milliseconds spent in the current interval 26 size_t interval_ms; //number of milliseconds spent in the current interval
27 size_t n[2]; //current batch size (in bytes) 27 size_t n[2]; //current batch size (in bytes)
  28 + size_t h; //spacing used for finite difference calculations
28 size_t dn; //delta value (in bytes) for setting the batch size (minimum change in batch parameter) 29 size_t dn; //delta value (in bytes) for setting the batch size (minimum change in batch parameter)
29 size_t maxn; //maximum value for the batch size 30 size_t maxn; //maximum value for the batch size
30 31
  32 + double alpha; //alpha value controls the factor of the gradient that is used to calculate the next point (speed of convergence)
  33 +
31 bool sample_step; //calculating the derivative (this class alternates between calculating dBps and B) 34 bool sample_step; //calculating the derivative (this class alternates between calculating dBps and B)
32 bool forward_diff; //evaluate the derivative using forward differences 35 bool forward_diff; //evaluate the derivative using forward differences
33 36
34 size_t window_ms; //size of the interval (in milliseconds) integrated to get a reliable bps value 37 size_t window_ms; //size of the interval (in milliseconds) integrated to get a reliable bps value
35 38
36 // This function rounds x to the nearest value within dB 39 // This function rounds x to the nearest value within dB
37 - size_t round_limit(size_t n0){  
38 - if(n0 > maxn) n0 = maxn; //limit the returned size of x to within the specified bounds  
39 - if(n0 < dn) n0 = dn; 40 + size_t round_limit(double n0){
  41 + if(n0 < 0) return dn; //if n0 is less than zero, return the lowest possible n
  42 +
  43 + size_t new_n = (size_t)(n0 + 0.5); //now n0 must be positive, so round it to the nearest integer
  44 + if(new_n > maxn) new_n = maxn; //limit the returned size of x to within the specified bounds
40 45
41 - size_t lowest = n0 / dn; 46 + size_t lowest = new_n / dn;
42 size_t highest = lowest + dn; 47 size_t highest = lowest + dn;
43 - size_t diff[2] = {n0 - lowest, highest - n0}; //calculate the two differences 48 + size_t diff[2] = {new_n - lowest, highest - new_n}; //calculate the two differences
44 if(diff[0] < diff[1]) 49 if(diff[0] < diff[1])
45 return lowest; 50 return lowest;
46 return highest; 51 return highest;
@@ -49,19 +54,79 @@ protected: @@ -49,19 +54,79 @@ protected:
49 public: 54 public:
50 55
51 //constructor initializes a stream optimizer 56 //constructor initializes a stream optimizer
52 - stream_optimizer(size_t current_batch_size, size_t min_batch_size, size_t max_batch_size, size_t window = 1000){  
53 - Bps = 0; //initialize to zero bytes per second processed 57 + stream_optimizer(size_t min_batch_size, size_t max_batch_size, double a = 0.0001, size_t window = 1000){
  58 + //Bps = 0; //initialize to zero bytes per second processed
  59 + Bps[0] = Bps[1] = 0; //initialize the bits per second to 0
54 interval_B = 0; //zero bytes have been processed at initialization 60 interval_B = 0; //zero bytes have been processed at initialization
55 interval_ms = 0; //no time has been spent on the batch so far 61 interval_ms = 0; //no time has been spent on the batch so far
56 dn = min_batch_size; //set the minimum batch size as the minimum change in batch size 62 dn = min_batch_size; //set the minimum batch size as the minimum change in batch size
57 maxn = max_batch_size; //set the maximum batch size 63 maxn = max_batch_size; //set the maximum batch size
58 - n[0] = current_batch_size; //set B 64 + n[0] = max_batch_size; //set B
  65 + h = (max_batch_size / min_batch_size) / 10 * dn;
  66 + std::cout<<"h = "<<h<<std::endl;
  67 + if(h < dn) h = dn;
  68 + alpha = a;
  69 + //n[0] = round_limit( (max_batch_size - min_batch_size)/2 );
59 window_ms = window; //minimum integration interval (for getting a reliable bps measure) 70 window_ms = window; //minimum integration interval (for getting a reliable bps measure)
60 sample_step = true; //the first step is to calculate the derivative 71 sample_step = true; //the first step is to calculate the derivative
61 forward_diff = true; //start with the forward difference (since we start at the maximum batch size) 72 forward_diff = true; //start with the forward difference (since we start at the maximum batch size)
62 } 73 }
63 74
64 - // this function updates the optimizer, given the number of bytes processed in an interval and time spent processing 75 + size_t update(size_t bytes_processed, size_t ms_spent){
  76 + interval_B += bytes_processed; //increment the number of bytes processed
  77 + interval_ms += ms_spent; //increment the number of milliseconds spent processing
  78 +
  79 + //if we have sufficient information to evaluate the optimization function at this point
  80 + if(interval_ms < window_ms){ //if insufficient time has passed to get a reliable Bps measurement
  81 + return n[0];
  82 + }
  83 + else{ //if we have collected enough information for a reliable Bps estimate
  84 + size_t new_Bps = interval_B / interval_ms; //calculate the current Bps
  85 +
  86 + if(Bps[0] == 0){ //if n[0] hasn't been evaluated yet, this is the first step
  87 + Bps[0] = new_Bps; //set the initial Bps value
  88 + n[1] = n[0] - h; //set the position of the next sample point
  89 + std::cout<<"Bps value at n = "<<n[0]<<" is "<<Bps[0]<<" Bps, probing n = "<<n[1]<<std::endl;
  90 + return n[1]; //return the probe point
  91 + }
  92 + else{
  93 + Bps[1] = new_Bps; //set the Bps for the current point (n[1])
  94 +
  95 + double Bps_p; //allocate a variable for the derivative
  96 + //calculate the derivative
  97 + if(n[0] < n[1]){ //if the current point is less than the previous one (probably the most common)
  98 + Bps_p = ((double)Bps[1] - (double)Bps[0]) / (double)h; //calculate the derivative using the forward finite difference
  99 + }
  100 + else{
  101 + Bps_p = ((double)Bps[0] - (double)Bps[1]) / (double)h; //calculate the derivative using the backward finite difference
  102 + }
  103 +
  104 + std::cout<<" probed n = "<<n[1]<<" with "<<Bps[1]<<" Bps, gradient = "<<Bps_p<<" Bps"<<std::endl;
  105 +
  106 + double new_n_precise = n[0] + alpha * Bps_p; //calculate the next point (snap to closest integer)
  107 + size_t new_n_nearest = round_limit(new_n_precise); //calculate the next point (given batch parameters)
  108 +
  109 + if(new_n_nearest == n[0]){ //if the newest point is the same as the original point
  110 + Bps[0] = Bps[1]; //update the Bps
  111 + //if(n[0] == dn) n[1] = n[0] + h; //if we're on the left edge, probe forward
  112 + //else n[1] = n[0] - h; //otherwise probe backwards
  113 + std::cout<<" staying at n = "<<n[0]<<" for now"<<std::endl;
  114 + //return n[1]; //return the probe point
  115 +
  116 + Bps[0] = 0; //reset the Bps for the current point
  117 + return n[0]; //return the current point for a re-calculation
  118 + }
  119 + else{ //if the newest point is different from the original point
  120 + n[0] = new_n_nearest; //move to the new point
  121 + Bps[0] = 0; //set the Bps to zero (point hasn't been tested)
  122 + std::cout<<" moving to n = "<<n[0]<<std::endl;
  123 + return n[0]; //return the new point
  124 + }
  125 + }
  126 + }
  127 + }
  128 +
  129 + /*// this function updates the optimizer, given the number of bytes processed in an interval and time spent processing
65 size_t update(size_t bytes_processed, size_t ms_spent){ 130 size_t update(size_t bytes_processed, size_t ms_spent){
66 interval_B += bytes_processed; //increment the number of bytes processed 131 interval_B += bytes_processed; //increment the number of bytes processed
67 interval_ms += ms_spent; //increment the number of milliseconds spent processing 132 interval_ms += ms_spent; //increment the number of milliseconds spent processing
@@ -98,11 +163,33 @@ public: @@ -98,11 +163,33 @@ public:
98 } 163 }
99 if(sample_step) return n[0]; 164 if(sample_step) return n[0];
100 return n[1]; //insufficient information, keep the same batch size 165 return n[1]; //insufficient information, keep the same batch size
101 - } 166 + }*/
  167 +
  168 + /*size_t update(size_t bytes_processed, size_t ms_spent){
  169 + interval_B += bytes_processed; //increment the number of bytes processed
  170 + interval_ms += ms_spent; //increment the number of milliseconds spent processing
  171 +
  172 + //if( Bps[0] == 0 ){ //if the left boundary hasn't been processed
  173 +
  174 +
  175 + //if we have sufficient information to evaluate the optimization function at this point
  176 + if(interval_ms >= window_ms){
  177 + size_t new_Bps = interval_B / interval_ms; //calculate the current Bps
  178 +
  179 + if(Bps[0] == 0) //if the left interval Bps hasn't been calculated
  180 + Bps[0] = interval_B / interval_ms; //that is the interval being processed
  181 + else
  182 + Bps[1] = interval_B / interval_ms; //otherwise the right interval is being processed
  183 +
  184 + if(Bps[0] != 0 && Bps[1] != 0){ //if both intervals have been processed
  185 +
  186 +
  187 + }
  188 + }*/
102 189
103 size_t update(size_t bytes_processed, size_t ms_spent, size_t& data_rate){ 190 size_t update(size_t bytes_processed, size_t ms_spent, size_t& data_rate){
104 size_t time = update(bytes_processed, ms_spent); 191 size_t time = update(bytes_processed, ms_spent);
105 - data_rate = Bps; 192 + data_rate = Bps[0];
106 return time; 193 return time;
107 } 194 }
108 }; 195 };
@@ -392,17 +392,24 @@ public: @@ -392,17 +392,24 @@ public:
392 bool bil(std::string outname, bool PROGRESS = false, bool VERBOSE = false){ 392 bool bil(std::string outname, bool PROGRESS = false, bool VERBOSE = false){
393 393
394 const size_t buffers = 4; //number of buffers required for this algorithm 394 const size_t buffers = 4; //number of buffers required for this algorithm
  395 +
395 size_t mem_per_batch = binary<T>::buffer_size / buffers; //calculate the maximum memory available for a batch 396 size_t mem_per_batch = binary<T>::buffer_size / buffers; //calculate the maximum memory available for a batch
396 397
397 size_t slice_bytes = X() * Z() * sizeof(T); //number of bytes in an input batch slice (Y-slice in this case) 398 size_t slice_bytes = X() * Z() * sizeof(T); //number of bytes in an input batch slice (Y-slice in this case)
398 size_t max_slices_per_batch = mem_per_batch / slice_bytes; //maximum number of slices we can process in one batch given memory constraints 399 size_t max_slices_per_batch = mem_per_batch / slice_bytes; //maximum number of slices we can process in one batch given memory constraints
  400 +
  401 + if(VERBOSE){
  402 + std::cout<<"maximum memory available for processing: "<<(double)binary<T>::buffer_size/(double)1000000<<" MB"<<std::endl;
  403 + std::cout<<" this supports a batch size of "<<max_slices_per_batch<<" Y-axis slices"<<std::endl;
  404 + }
  405 +
399 if(max_slices_per_batch == 0){ //if there is insufficient memory for a single slice, throw an error 406 if(max_slices_per_batch == 0){ //if there is insufficient memory for a single slice, throw an error
400 std::cout<<"error, insufficient memory for stim::bsq::bil()"<<std::endl; 407 std::cout<<"error, insufficient memory for stim::bsq::bil()"<<std::endl;
401 exit(1); 408 exit(1);
402 } 409 }
403 size_t max_batch_bytes = max_slices_per_batch * slice_bytes; //calculate the amount of memory that will be allocated for all four buffers 410 size_t max_batch_bytes = max_slices_per_batch * slice_bytes; //calculate the amount of memory that will be allocated for all four buffers
404 411
405 - stream_optimizer O(max_slices_per_batch, 1, max_slices_per_batch); 412 + stream_optimizer O(1, max_slices_per_batch);
406 413
407 T* src[2]; //source double-buffer for asynchronous batching 414 T* src[2]; //source double-buffer for asynchronous batching
408 src[0] = (T*) malloc(max_batch_bytes); 415 src[0] = (T*) malloc(max_batch_bytes);
@@ -421,10 +428,6 @@ public: @@ -421,10 +428,6 @@ public:
421 std::future<size_t> rthread; 428 std::future<size_t> rthread;
422 std::future<std::ostream&> wthread; //create asynchronous threads for reading and writing 429 std::future<std::ostream&> wthread; //create asynchronous threads for reading and writing
423 430
424 - //readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer  
425 - //y_load += N[0]; //increment the loaded slice counter  
426 - //int b = 1;  
427 -  
428 std::chrono::high_resolution_clock::time_point t_start, pt_start; //high-resolution timers 431 std::chrono::high_resolution_clock::time_point t_start, pt_start; //high-resolution timers
429 std::chrono::high_resolution_clock::time_point t_end, pt_end; 432 std::chrono::high_resolution_clock::time_point t_end, pt_end;
430 size_t t_batch; //number of milliseconds to process a batch 433 size_t t_batch; //number of milliseconds to process a batch
@@ -435,15 +438,15 @@ public: @@ -435,15 +438,15 @@ public:
435 size_t data_rate; 438 size_t data_rate;
436 439
437 rt_total += readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer 440 rt_total += readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer
438 - y_load += N[0]; //increment the loaded slice counter  
439 - int b = 1; //initialize the double buffer to 0 441 + y_load += N[0]; //increment the loaded slice counter
  442 + int b = 1; //initialize the double buffer to 0
440 while(y_proc < Y()){ //while there are still slices to be processed 443 while(y_proc < Y()){ //while there are still slices to be processed
441 t_start = std::chrono::high_resolution_clock::now(); //start the timer for this batch 444 t_start = std::chrono::high_resolution_clock::now(); //start the timer for this batch
442 if(y_load < Y()){ //if there are still slices to be loaded, load them 445 if(y_load < Y()){ //if there are still slices to be loaded, load them
443 - if(y_proc > 0){  
444 - N[b] = O.update(N[!b] * slice_bytes, t_batch, data_rate); //set the batch size based on optimization  
445 - std::cout<<"New N = "<<N[b]<<" at "<<(double)data_rate / 1000000<<" MB/s"<<std::endl;  
446 - } 446 + //if(y_proc > 0){
  447 +
  448 +
  449 + //}
447 if(y_load + N[b] > Y()) N[b] = Y() - y_load; //if the next batch would process more than the total slices, adjust the batch size 450 if(y_load + N[b] > Y()) N[b] = Y() - y_load; //if the next batch would process more than the total slices, adjust the batch size
448 rthread = std::async(std::launch::async, &stim::bsq<T>::readlines, this, src[b], y_load, N[b]); 451 rthread = std::async(std::launch::async, &stim::bsq<T>::readlines, this, src[b], y_load, N[b]);
449 rt_total += rthread.get(); 452 rt_total += rthread.get();
@@ -452,7 +455,6 @@ public: @@ -452,7 +455,6 @@ public:
452 455
453 b = !b; //swap the double-buffer 456 b = !b; //swap the double-buffer
454 pt_total += binary<T>::permute(dst[b], src[b], X(), N[b], Z(), 0, 2, 1); //permute the batch to a BIL file 457 pt_total += binary<T>::permute(dst[b], src[b], X(), N[b], Z(), 0, 2, 1); //permute the batch to a BIL file
455 - //target.write((char*)dst[b], N[b] * slice_bytes); //write the permuted data to the output file  
456 wt_total += writeblock(&target, dst[b], N[b] * slice_bytes); //write the permuted data to the output file 458 wt_total += writeblock(&target, dst[b], N[b] * slice_bytes); //write the permuted data to the output file
457 y_proc += N[b]; //increment the counter of processed pixels 459 y_proc += N[b]; //increment the counter of processed pixels
458 if(PROGRESS) progress = (double)( y_proc + 1 ) / Y() * 100; //increment the progress counter based on the number of processed pixels 460 if(PROGRESS) progress = (double)( y_proc + 1 ) / Y() * 100; //increment the progress counter based on the number of processed pixels
@@ -460,6 +462,8 @@ public: @@ -460,6 +462,8 @@ public:
460 t_batch = std::chrono::duration_cast<std::chrono::milliseconds>(t_end-t_start).count(); 462 t_batch = std::chrono::duration_cast<std::chrono::milliseconds>(t_end-t_start).count();
461 t_total += t_batch; 463 t_total += t_batch;
462 //if(y_load < Y()) rt_total += rthread.get(); //if a new batch was set to load, make sure it loads after calculations 464 //if(y_load < Y()) rt_total += rthread.get(); //if a new batch was set to load, make sure it loads after calculations
  465 + N[b] = O.update(N[!b] * slice_bytes, t_batch, data_rate); //set the batch size based on optimization
  466 + //std::cout<<"New N = "<<N[!b]<<" selected with "<<(double)data_rate / 1000000<<" MB/s"<<std::endl;
463 } 467 }
464 468
465 free(src[0]); //free buffer resources 469 free(src[0]); //free buffer resources