Commit 0a575bb1055a25de4652c3825e0393db601e2749

Authored by David Mayerich
1 parent 5435830b

updated optimization gradient descent algorithm

Showing 2 changed files with 115 additions and 24 deletions   Show diff stats
stim/envi/binary.h
... ... @@ -21,26 +21,31 @@ namespace stim{
21 21 /// minimizing the dependent parameter bps (bytes per second)
22 22 class stream_optimizer{
23 23 protected:
24   - size_t Bps; //bytes per second for the previous batch
  24 + size_t Bps[2]; //bytes per second for the previous batch
25 25 size_t interval_B; //number of bytes processed this interval
26 26 size_t interval_ms; //number of milliseconds spent in the current interval
27 27 size_t n[2]; //current batch size (in bytes)
  28 + size_t h; //spacing used for finite difference calculations
28 29 size_t dn; //delta value (in bytes) for setting the batch size (minimum change in batch parameter)
29 30 size_t maxn; //maximum value for the batch size
30 31  
  32 + double alpha; //alpha value controls the factor of the gradient that is used to calculate the next point (speed of convergence)
  33 +
31 34 bool sample_step; //calculating the derivative (this class alternates between calculating dBps and B)
32 35 bool forward_diff; //evaluate the derivative using forward differences
33 36  
34 37 size_t window_ms; //size of the interval (in milliseconds) integrated to get a reliable bps value
35 38  
36 39 // This function rounds x to the nearest value within dB
37   - size_t round_limit(size_t n0){
38   - if(n0 > maxn) n0 = maxn; //limit the returned size of x to within the specified bounds
39   - if(n0 < dn) n0 = dn;
  40 + size_t round_limit(double n0){
  41 + if(n0 < 0) return dn; //if n0 is less than zero, return the lowest possible n
  42 +
  43 + size_t new_n = (size_t)(n0 + 0.5); //now n0 must be positive, so round it to the nearest integer
  44 + if(new_n > maxn) new_n = maxn; //limit the returned size of x to within the specified bounds
40 45  
41   - size_t lowest = n0 / dn;
  46 + size_t lowest = new_n / dn;
42 47 size_t highest = lowest + dn;
43   - size_t diff[2] = {n0 - lowest, highest - n0}; //calculate the two differences
  48 + size_t diff[2] = {new_n - lowest, highest - new_n}; //calculate the two differences
44 49 if(diff[0] < diff[1])
45 50 return lowest;
46 51 return highest;
... ... @@ -49,19 +54,79 @@ protected:
49 54 public:
50 55  
51 56 //constructor initializes a stream optimizer
52   - stream_optimizer(size_t current_batch_size, size_t min_batch_size, size_t max_batch_size, size_t window = 1000){
53   - Bps = 0; //initialize to zero bytes per second processed
  57 + stream_optimizer(size_t min_batch_size, size_t max_batch_size, double a = 0.0001, size_t window = 1000){
  58 + //Bps = 0; //initialize to zero bytes per second processed
  59 + Bps[0] = Bps[1] = 0; //initialize the bits per second to 0
54 60 interval_B = 0; //zero bytes have been processed at initialization
55 61 interval_ms = 0; //no time has been spent on the batch so far
56 62 dn = min_batch_size; //set the minimum batch size as the minimum change in batch size
57 63 maxn = max_batch_size; //set the maximum batch size
58   - n[0] = current_batch_size; //set B
  64 + n[0] = max_batch_size; //set B
  65 + h = (max_batch_size / min_batch_size) / 10 * dn;
  66 + std::cout<<"h = "<<h<<std::endl;
  67 + if(h < dn) h = dn;
  68 + alpha = a;
  69 + //n[0] = round_limit( (max_batch_size - min_batch_size)/2 );
59 70 window_ms = window; //minimum integration interval (for getting a reliable bps measure)
60 71 sample_step = true; //the first step is to calculate the derivative
61 72 forward_diff = true; //start with the forward difference (since we start at the maximum batch size)
62 73 }
63 74  
64   - // this function updates the optimizer, given the number of bytes processed in an interval and time spent processing
  75 + size_t update(size_t bytes_processed, size_t ms_spent){
  76 + interval_B += bytes_processed; //increment the number of bytes processed
  77 + interval_ms += ms_spent; //increment the number of milliseconds spent processing
  78 +
  79 + //if we have sufficient information to evaluate the optimization function at this point
  80 + if(interval_ms < window_ms){ //if insufficient time has passed to get a reliable Bps measurement
  81 + return n[0];
  82 + }
  83 + else{ //if we have collected enough information for a reliable Bps estimate
  84 + size_t new_Bps = interval_B / interval_ms; //calculate the current Bps
  85 +
  86 + if(Bps[0] == 0){ //if n[0] hasn't been evaluated yet, this is the first step
  87 + Bps[0] = new_Bps; //set the initial Bps value
  88 + n[1] = n[0] - h; //set the position of the next sample point
  89 + std::cout<<"Bps value at n = "<<n[0]<<" is "<<Bps[0]<<" Bps, probing n = "<<n[1]<<std::endl;
  90 + return n[1]; //return the probe point
  91 + }
  92 + else{
  93 + Bps[1] = new_Bps; //set the Bps for the current point (n[1])
  94 +
  95 + double Bps_p; //allocate a variable for the derivative
  96 + //calculate the derivative
  97 + if(n[0] < n[1]){ //if the current point is less than the previous one (probably the most common)
  98 + Bps_p = ((double)Bps[1] - (double)Bps[0]) / (double)h; //calculate the derivative using the forward finite difference
  99 + }
  100 + else{
  101 + Bps_p = ((double)Bps[0] - (double)Bps[1]) / (double)h; //calculate the derivative using the backward finite difference
  102 + }
  103 +
  104 + std::cout<<" probed n = "<<n[1]<<" with "<<Bps[1]<<" Bps, gradient = "<<Bps_p<<" Bps"<<std::endl;
  105 +
  106 + double new_n_precise = n[0] + alpha * Bps_p; //calculate the next point (snap to closest integer)
  107 + size_t new_n_nearest = round_limit(new_n_precise); //calculate the next point (given batch parameters)
  108 +
  109 + if(new_n_nearest == n[0]){ //if the newest point is the same as the original point
  110 + Bps[0] = Bps[1]; //update the Bps
  111 + //if(n[0] == dn) n[1] = n[0] + h; //if we're on the left edge, probe forward
  112 + //else n[1] = n[0] - h; //otherwise probe backwards
  113 + std::cout<<" staying at n = "<<n[0]<<" for now"<<std::endl;
  114 + //return n[1]; //return the probe point
  115 +
  116 + Bps[0] = 0; //reset the Bps for the current point
  117 + return n[0]; //return the current point for a re-calculation
  118 + }
  119 + else{ //if the newest point is different from the original point
  120 + n[0] = new_n_nearest; //move to the new point
  121 + Bps[0] = 0; //set the Bps to zero (point hasn't been tested)
  122 + std::cout<<" moving to n = "<<n[0]<<std::endl;
  123 + return n[0]; //return the new point
  124 + }
  125 + }
  126 + }
  127 + }
  128 +
  129 + /*// this function updates the optimizer, given the number of bytes processed in an interval and time spent processing
65 130 size_t update(size_t bytes_processed, size_t ms_spent){
66 131 interval_B += bytes_processed; //increment the number of bytes processed
67 132 interval_ms += ms_spent; //increment the number of milliseconds spent processing
... ... @@ -98,11 +163,33 @@ public:
98 163 }
99 164 if(sample_step) return n[0];
100 165 return n[1]; //insufficient information, keep the same batch size
101   - }
  166 + }*/
  167 +
  168 + /*size_t update(size_t bytes_processed, size_t ms_spent){
  169 + interval_B += bytes_processed; //increment the number of bytes processed
  170 + interval_ms += ms_spent; //increment the number of milliseconds spent processing
  171 +
  172 + //if( Bps[0] == 0 ){ //if the left boundary hasn't been processed
  173 +
  174 +
  175 + //if we have sufficient information to evaluate the optimization function at this point
  176 + if(interval_ms >= window_ms){
  177 + size_t new_Bps = interval_B / interval_ms; //calculate the current Bps
  178 +
  179 + if(Bps[0] == 0) //if the left interval Bps hasn't been calculated
  180 + Bps[0] = interval_B / interval_ms; //that is the interval being processed
  181 + else
  182 + Bps[1] = interval_B / interval_ms; //otherwise the right interval is being processed
  183 +
  184 + if(Bps[0] != 0 && Bps[1] != 0){ //if both intervals have been processed
  185 +
  186 +
  187 + }
  188 + }*/
102 189  
103 190 size_t update(size_t bytes_processed, size_t ms_spent, size_t& data_rate){
104 191 size_t time = update(bytes_processed, ms_spent);
105   - data_rate = Bps;
  192 + data_rate = Bps[0];
106 193 return time;
107 194 }
108 195 };
... ...
stim/envi/bsq.h
... ... @@ -392,17 +392,24 @@ public:
392 392 bool bil(std::string outname, bool PROGRESS = false, bool VERBOSE = false){
393 393  
394 394 const size_t buffers = 4; //number of buffers required for this algorithm
  395 +
395 396 size_t mem_per_batch = binary<T>::buffer_size / buffers; //calculate the maximum memory available for a batch
396 397  
397 398 size_t slice_bytes = X() * Z() * sizeof(T); //number of bytes in an input batch slice (Y-slice in this case)
398 399 size_t max_slices_per_batch = mem_per_batch / slice_bytes; //maximum number of slices we can process in one batch given memory constraints
  400 +
  401 + if(VERBOSE){
  402 + std::cout<<"maximum memory available for processing: "<<(double)binary<T>::buffer_size/(double)1000000<<" MB"<<std::endl;
  403 + std::cout<<" this supports a batch size of "<<max_slices_per_batch<<" Y-axis slices"<<std::endl;
  404 + }
  405 +
399 406 if(max_slices_per_batch == 0){ //if there is insufficient memory for a single slice, throw an error
400 407 std::cout<<"error, insufficient memory for stim::bsq::bil()"<<std::endl;
401 408 exit(1);
402 409 }
403 410 size_t max_batch_bytes = max_slices_per_batch * slice_bytes; //calculate the amount of memory that will be allocated for all four buffers
404 411  
405   - stream_optimizer O(max_slices_per_batch, 1, max_slices_per_batch);
  412 + stream_optimizer O(1, max_slices_per_batch);
406 413  
407 414 T* src[2]; //source double-buffer for asynchronous batching
408 415 src[0] = (T*) malloc(max_batch_bytes);
... ... @@ -421,10 +428,6 @@ public:
421 428 std::future<size_t> rthread;
422 429 std::future<std::ostream&> wthread; //create asynchronous threads for reading and writing
423 430  
424   - //readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer
425   - //y_load += N[0]; //increment the loaded slice counter
426   - //int b = 1;
427   -
428 431 std::chrono::high_resolution_clock::time_point t_start, pt_start; //high-resolution timers
429 432 std::chrono::high_resolution_clock::time_point t_end, pt_end;
430 433 size_t t_batch; //number of milliseconds to process a batch
... ... @@ -435,15 +438,15 @@ public:
435 438 size_t data_rate;
436 439  
437 440 rt_total += readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer
438   - y_load += N[0]; //increment the loaded slice counter
439   - int b = 1; //initialize the double buffer to 0
  441 + y_load += N[0]; //increment the loaded slice counter
  442 + int b = 1; //initialize the double buffer to 0
440 443 while(y_proc < Y()){ //while there are still slices to be processed
441 444 t_start = std::chrono::high_resolution_clock::now(); //start the timer for this batch
442 445 if(y_load < Y()){ //if there are still slices to be loaded, load them
443   - if(y_proc > 0){
444   - N[b] = O.update(N[!b] * slice_bytes, t_batch, data_rate); //set the batch size based on optimization
445   - std::cout<<"New N = "<<N[b]<<" at "<<(double)data_rate / 1000000<<" MB/s"<<std::endl;
446   - }
  446 + //if(y_proc > 0){
  447 +
  448 +
  449 + //}
447 450 if(y_load + N[b] > Y()) N[b] = Y() - y_load; //if the next batch would process more than the total slices, adjust the batch size
448 451 rthread = std::async(std::launch::async, &stim::bsq<T>::readlines, this, src[b], y_load, N[b]);
449 452 rt_total += rthread.get();
... ... @@ -452,7 +455,6 @@ public:
452 455  
453 456 b = !b; //swap the double-buffer
454 457 pt_total += binary<T>::permute(dst[b], src[b], X(), N[b], Z(), 0, 2, 1); //permute the batch to a BIL file
455   - //target.write((char*)dst[b], N[b] * slice_bytes); //write the permuted data to the output file
456 458 wt_total += writeblock(&target, dst[b], N[b] * slice_bytes); //write the permuted data to the output file
457 459 y_proc += N[b]; //increment the counter of processed pixels
458 460 if(PROGRESS) progress = (double)( y_proc + 1 ) / Y() * 100; //increment the progress counter based on the number of processed pixels
... ... @@ -460,6 +462,8 @@ public:
460 462 t_batch = std::chrono::duration_cast<std::chrono::milliseconds>(t_end-t_start).count();
461 463 t_total += t_batch;
462 464 //if(y_load < Y()) rt_total += rthread.get(); //if a new batch was set to load, make sure it loads after calculations
  465 + N[b] = O.update(N[!b] * slice_bytes, t_batch, data_rate); //set the batch size based on optimization
  466 + //std::cout<<"New N = "<<N[!b]<<" selected with "<<(double)data_rate / 1000000<<" MB/s"<<std::endl;
463 467 }
464 468  
465 469 free(src[0]); //free buffer resources
... ...