Commit 0a575bb1055a25de4652c3825e0393db601e2749
1 parent
5435830b
updated the gradient descent algorithm used by the stream optimizer
Showing
2 changed files
with
115 additions
and
24 deletions
Show diff stats
stim/envi/binary.h
... | ... | @@ -21,26 +21,31 @@ namespace stim{ |
21 | 21 | /// maximizing the dependent parameter Bps (bytes per second) |
22 | 22 | class stream_optimizer{ |
23 | 23 | protected: |
24 | - size_t Bps; //bytes per second for the previous batch | |
24 | + size_t Bps[2]; //bytes per second for the previous batch | |
25 | 25 | size_t interval_B; //number of bytes processed this interval |
26 | 26 | size_t interval_ms; //number of milliseconds spent in the current interval |
27 | 27 | size_t n[2]; //current batch size (in bytes) |
28 | + size_t h; //spacing used for finite difference calculations | |
28 | 29 | size_t dn; //delta value (in bytes) for setting the batch size (minimum change in batch parameter) |
29 | 30 | size_t maxn; //maximum value for the batch size |
30 | 31 | |
32 | + double alpha; //alpha value controls the factor of the gradient that is used to calculate the next point (speed of convergence) | |
33 | + | |
31 | 34 | bool sample_step; //calculating the derivative (this class alternates between calculating dBps and B) |
32 | 35 | bool forward_diff; //evaluate the derivative using forward differences |
33 | 36 | |
34 | 37 | size_t window_ms; //size of the interval (in milliseconds) integrated to get a reliable bps value |
35 | 38 | |
36 | 39 | // This function rounds n0 to the nearest multiple of dn and limits it to [dn, maxn] |
37 | - size_t round_limit(size_t n0){ | |
38 | - if(n0 > maxn) n0 = maxn; //limit the returned size of x to within the specified bounds | |
39 | - if(n0 < dn) n0 = dn; | |
40 | + size_t round_limit(double n0){ | |
41 | + if(n0 < 0) return dn; //if n0 is less than zero, return the lowest possible n | |
42 | + | |
43 | + size_t new_n = (size_t)(n0 + 0.5); //now n0 must be positive, so round it to the nearest integer | |
44 | + if(new_n > maxn) new_n = maxn; //limit the returned size of x to within the specified bounds | |
40 | 45 | |
41 | - size_t lowest = n0 / dn; | |
46 | + size_t lowest = new_n / dn; | |
42 | 47 | size_t highest = lowest + dn; |
43 | - size_t diff[2] = {n0 - lowest, highest - n0}; //calculate the two differences | |
48 | + size_t diff[2] = {new_n - lowest, highest - new_n}; //calculate the two differences | |
44 | 49 | if(diff[0] < diff[1]) |
45 | 50 | return lowest; |
46 | 51 | return highest; |
... | ... | @@ -49,19 +54,79 @@ protected: |
49 | 54 | public: |
50 | 55 | |
51 | 56 | //constructor initializes a stream optimizer |
52 | - stream_optimizer(size_t current_batch_size, size_t min_batch_size, size_t max_batch_size, size_t window = 1000){ | |
53 | - Bps = 0; //initialize to zero bytes per second processed | |
57 | + stream_optimizer(size_t min_batch_size, size_t max_batch_size, double a = 0.0001, size_t window = 1000){ | |
58 | + //Bps = 0; //initialize to zero bytes per second processed | |
59 | + Bps[0] = Bps[1] = 0; //initialize the bytes per second to 0 | |
54 | 60 | interval_B = 0; //zero bytes have been processed at initialization |
55 | 61 | interval_ms = 0; //no time has been spent on the batch so far |
56 | 62 | dn = min_batch_size; //set the minimum batch size as the minimum change in batch size |
57 | 63 | maxn = max_batch_size; //set the maximum batch size |
58 | - n[0] = current_batch_size; //set B | |
64 | + n[0] = max_batch_size; //set B | |
65 | + h = (max_batch_size / min_batch_size) / 10 * dn; | |
66 | + std::cout<<"h = "<<h<<std::endl; | |
67 | + if(h < dn) h = dn; | |
68 | + alpha = a; | |
69 | + //n[0] = round_limit( (max_batch_size - min_batch_size)/2 ); | |
59 | 70 | window_ms = window; //minimum integration interval (for getting a reliable bps measure) |
60 | 71 | sample_step = true; //the first step is to calculate the derivative |
61 | 72 | forward_diff = true; //start with the forward difference (since we start at the maximum batch size) |
62 | 73 | } |
63 | 74 | |
64 | - // this function updates the optimizer, given the number of bytes processed in an interval and time spent processing | |
75 | + size_t update(size_t bytes_processed, size_t ms_spent){ | |
76 | + interval_B += bytes_processed; //increment the number of bytes processed | |
77 | + interval_ms += ms_spent; //increment the number of milliseconds spent processing | |
78 | + | |
79 | + //if we have sufficient information to evaluate the optimization function at this point | |
80 | + if(interval_ms < window_ms){ //if insufficient time has passed to get a reliable Bps measurement | |
81 | + return n[0]; | |
82 | + } | |
83 | + else{ //if we have collected enough information for a reliable Bps estimate | |
84 | size_t new_Bps = interval_B / interval_ms; //calculate the current throughput (bytes per millisecond) | |
85 | + | |
86 | + if(Bps[0] == 0){ //if n[0] hasn't been evaluated yet, this is the first step | |
87 | + Bps[0] = new_Bps; //set the initial Bps value | |
88 | + n[1] = n[0] - h; //set the position of the next sample point | |
89 | + std::cout<<"Bps value at n = "<<n[0]<<" is "<<Bps[0]<<" Bps, probing n = "<<n[1]<<std::endl; | |
90 | + return n[1]; //return the probe point | |
91 | + } | |
92 | + else{ | |
93 | + Bps[1] = new_Bps; //set the Bps for the current point (n[1]) | |
94 | + | |
95 | + double Bps_p; //allocate a variable for the derivative | |
96 | + //calculate the derivative | |
97 | + if(n[0] < n[1]){ //if the current point is less than the previous one (probably the most common) | |
98 | + Bps_p = ((double)Bps[1] - (double)Bps[0]) / (double)h; //calculate the derivative using the forward finite difference | |
99 | + } | |
100 | + else{ | |
101 | + Bps_p = ((double)Bps[0] - (double)Bps[1]) / (double)h; //calculate the derivative using the backward finite difference | |
102 | + } | |
103 | + | |
104 | + std::cout<<" probed n = "<<n[1]<<" with "<<Bps[1]<<" Bps, gradient = "<<Bps_p<<" Bps"<<std::endl; | |
105 | + | |
106 | double new_n_precise = n[0] + alpha * Bps_p; //take a gradient ascent step toward higher throughput (rounded to a valid batch size on the next line) | |
107 | + size_t new_n_nearest = round_limit(new_n_precise); //calculate the next point (given batch parameters) | |
108 | + | |
109 | + if(new_n_nearest == n[0]){ //if the newest point is the same as the original point | |
110 | + Bps[0] = Bps[1]; //update the Bps | |
111 | + //if(n[0] == dn) n[1] = n[0] + h; //if we're on the left edge, probe forward | |
112 | + //else n[1] = n[0] - h; //otherwise probe backwards | |
113 | + std::cout<<" staying at n = "<<n[0]<<" for now"<<std::endl; | |
114 | + //return n[1]; //return the probe point | |
115 | + | |
116 | + Bps[0] = 0; //reset the Bps for the current point | |
117 | + return n[0]; //return the current point for a re-calculation | |
118 | + } | |
119 | + else{ //if the newest point is different from the original point | |
120 | + n[0] = new_n_nearest; //move to the new point | |
121 | + Bps[0] = 0; //set the Bps to zero (point hasn't been tested) | |
122 | + std::cout<<" moving to n = "<<n[0]<<std::endl; | |
123 | + return n[0]; //return the new point | |
124 | + } | |
125 | + } | |
126 | + } | |
127 | + } | |
128 | + | |
129 | + /*// this function updates the optimizer, given the number of bytes processed in an interval and time spent processing | |
65 | 130 | size_t update(size_t bytes_processed, size_t ms_spent){ |
66 | 131 | interval_B += bytes_processed; //increment the number of bytes processed |
67 | 132 | interval_ms += ms_spent; //increment the number of milliseconds spent processing |
... | ... | @@ -98,11 +163,33 @@ public: |
98 | 163 | } |
99 | 164 | if(sample_step) return n[0]; |
100 | 165 | return n[1]; //insufficient information, keep the same batch size |
101 | - } | |
166 | + }*/ | |
167 | + | |
168 | + /*size_t update(size_t bytes_processed, size_t ms_spent){ | |
169 | + interval_B += bytes_processed; //increment the number of bytes processed | |
170 | + interval_ms += ms_spent; //increment the number of milliseconds spent processing | |
171 | + | |
172 | + //if( Bps[0] == 0 ){ //if the left boundary hasn't been processed | |
173 | + | |
174 | + | |
175 | + //if we have sufficient information to evaluate the optimization function at this point | |
176 | + if(interval_ms >= window_ms){ | |
177 | + size_t new_Bps = interval_B / interval_ms; //calculate the current Bps | |
178 | + | |
179 | + if(Bps[0] == 0) //if the left interval Bps hasn't been calculated | |
180 | + Bps[0] = interval_B / interval_ms; //that is the interval being processed | |
181 | + else | |
182 | + Bps[1] = interval_B / interval_ms; //otherwise the right interval is being processed | |
183 | + | |
184 | + if(Bps[0] != 0 && Bps[1] != 0){ //if both intervals have been processed | |
185 | + | |
186 | + | |
187 | + } | |
188 | + }*/ | |
102 | 189 | |
103 | 190 | size_t update(size_t bytes_processed, size_t ms_spent, size_t& data_rate){ |
104 | 191 | size_t time = update(bytes_processed, ms_spent); |
105 | - data_rate = Bps; | |
192 | + data_rate = Bps[0]; | |
106 | 193 | return time; |
107 | 194 | } |
108 | 195 | }; | ... | ... |
stim/envi/bsq.h
... | ... | @@ -392,17 +392,24 @@ public: |
392 | 392 | bool bil(std::string outname, bool PROGRESS = false, bool VERBOSE = false){ |
393 | 393 | |
394 | 394 | const size_t buffers = 4; //number of buffers required for this algorithm |
395 | + | |
395 | 396 | size_t mem_per_batch = binary<T>::buffer_size / buffers; //calculate the maximum memory available for a batch |
396 | 397 | |
397 | 398 | size_t slice_bytes = X() * Z() * sizeof(T); //number of bytes in an input batch slice (Y-slice in this case) |
398 | 399 | size_t max_slices_per_batch = mem_per_batch / slice_bytes; //maximum number of slices we can process in one batch given memory constraints |
400 | + | |
401 | + if(VERBOSE){ | |
402 | + std::cout<<"maximum memory available for processing: "<<(double)binary<T>::buffer_size/(double)1000000<<" MB"<<std::endl; | |
403 | + std::cout<<" this supports a batch size of "<<max_slices_per_batch<<" Y-axis slices"<<std::endl; | |
404 | + } | |
405 | + | |
399 | 406 | if(max_slices_per_batch == 0){ //if there is insufficient memory for a single slice, throw an error |
400 | 407 | std::cout<<"error, insufficient memory for stim::bsq::bil()"<<std::endl; |
401 | 408 | exit(1); |
402 | 409 | } |
403 | 410 | size_t max_batch_bytes = max_slices_per_batch * slice_bytes; //calculate the amount of memory that will be allocated for all four buffers |
404 | 411 | |
405 | - stream_optimizer O(max_slices_per_batch, 1, max_slices_per_batch); | |
412 | + stream_optimizer O(1, max_slices_per_batch); | |
406 | 413 | |
407 | 414 | T* src[2]; //source double-buffer for asynchronous batching |
408 | 415 | src[0] = (T*) malloc(max_batch_bytes); |
... | ... | @@ -421,10 +428,6 @@ public: |
421 | 428 | std::future<size_t> rthread; |
422 | 429 | std::future<std::ostream&> wthread; //create asynchronous threads for reading and writing |
423 | 430 | |
424 | - //readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer | |
425 | - //y_load += N[0]; //increment the loaded slice counter | |
426 | - //int b = 1; | |
427 | - | |
428 | 431 | std::chrono::high_resolution_clock::time_point t_start, pt_start; //high-resolution timers |
429 | 432 | std::chrono::high_resolution_clock::time_point t_end, pt_end; |
430 | 433 | size_t t_batch; //number of milliseconds to process a batch |
... | ... | @@ -435,15 +438,15 @@ public: |
435 | 438 | size_t data_rate; |
436 | 439 | |
437 | 440 | rt_total += readlines(src[0], 0, N[0]); //read the first batch into the 0 source buffer |
438 | - y_load += N[0]; //increment the loaded slice counter | |
439 | - int b = 1; //initialize the double buffer to 0 | |
441 | + y_load += N[0]; //increment the loaded slice counter | |
442 | + int b = 1; //initialize the double buffer to 0 | |
440 | 443 | while(y_proc < Y()){ //while there are still slices to be processed |
441 | 444 | t_start = std::chrono::high_resolution_clock::now(); //start the timer for this batch |
442 | 445 | if(y_load < Y()){ //if there are still slices to be loaded, load them |
443 | - if(y_proc > 0){ | |
444 | - N[b] = O.update(N[!b] * slice_bytes, t_batch, data_rate); //set the batch size based on optimization | |
445 | - std::cout<<"New N = "<<N[b]<<" at "<<(double)data_rate / 1000000<<" MB/s"<<std::endl; | |
446 | - } | |
446 | + //if(y_proc > 0){ | |
447 | + | |
448 | + | |
449 | + //} | |
447 | 450 | if(y_load + N[b] > Y()) N[b] = Y() - y_load; //if the next batch would process more than the total slices, adjust the batch size |
448 | 451 | rthread = std::async(std::launch::async, &stim::bsq<T>::readlines, this, src[b], y_load, N[b]); |
449 | 452 | rt_total += rthread.get(); |
... | ... | @@ -452,7 +455,6 @@ public: |
452 | 455 | |
453 | 456 | b = !b; //swap the double-buffer |
454 | 457 | pt_total += binary<T>::permute(dst[b], src[b], X(), N[b], Z(), 0, 2, 1); //permute the batch to a BIL file |
455 | - //target.write((char*)dst[b], N[b] * slice_bytes); //write the permuted data to the output file | |
456 | 458 | wt_total += writeblock(&target, dst[b], N[b] * slice_bytes); //write the permuted data to the output file |
457 | 459 | y_proc += N[b]; //increment the counter of processed pixels |
458 | 460 | if(PROGRESS) progress = (double)( y_proc + 1 ) / Y() * 100; //increment the progress counter based on the number of processed pixels |
... | ... | @@ -460,6 +462,8 @@ public: |
460 | 462 | t_batch = std::chrono::duration_cast<std::chrono::milliseconds>(t_end-t_start).count(); |
461 | 463 | t_total += t_batch; |
462 | 464 | //if(y_load < Y()) rt_total += rthread.get(); //if a new batch was set to load, make sure it loads after calculations |
465 | + N[b] = O.update(N[!b] * slice_bytes, t_batch, data_rate); //set the batch size based on optimization | |
466 | + //std::cout<<"New N = "<<N[!b]<<" selected with "<<(double)data_rate / 1000000<<" MB/s"<<std::endl; | |
463 | 467 | } |
464 | 468 | |
465 | 469 | free(src[0]); //free buffer resources | ... | ... |