Commit 2418bcc36e6329504a80abb22c56edbb83445246

Authored by David Mayerich
1 parent 7a2d0012

optimized function class

Showing 1 changed file with 140 additions and 96 deletions   Show diff stats
math/function.h
... ... @@ -5,145 +5,166 @@
5 5  
6 6 namespace rts{
7 7  
8   -template <class X, class Y>
  8 +//template class for a one-dimensional function
  9 +template <class Tx, class Ty>
9 10 class function
10 11 {
11   - //datapoint class for storing function points
12   - struct dataPoint
13   - {
14   - X x;
15   - Y y;
16   - };
17   -
18   - //function data
19   - std::vector<dataPoint> f;
  12 + std::vector<Tx> X;
  13 + std::vector<Ty> Y;
20 14  
21 15 //comparison function for searching lambda
22   - static bool findCeiling(dataPoint a, dataPoint b)
23   - {
24   - return (a.x > b.x);
25   - }
  16 + static bool findCeiling(Tx ax, Tx bx){
  17 + return (ax > bx);
  18 + }
26 19  
  20 + //process a string to extract a function (generally delimited by tabs, commas, spaces, etc.)
  21 + void process_string(std::string s){
  22 + std::stringstream ss(s);
27 23  
28   -public:
29   - function()
30   - {
31   - //insert(0, 0);
32   - }
  24 + //std::string test;
  25 + //std::getline(ss, test);
  26 + //std::cout<<test;
  27 + //exit(1);
33 28  
34   - Y linear(X x) const
35   - {
36   - if(f.size() == 0) return (Y)0; //return zero if the function is empty
37   - //declare an iterator
38   - typedef typename std::vector< dataPoint >::iterator f_iter;
39   - f_iter it;
40   -
41   - //dataPoint s;
42   - //s.x = x;
43   -
44   - //it = search(f.begin(), f.end(), &s, &s + 1, &function<X, Y>::findCeiling);
45   - unsigned int i;
46   - for(i = 0; i<f.size(); i++)
47   - {
48   - if(f[i].x > x)
49   - break;
50   - }
  29 + Tx x;
  30 + Ty y;
  31 + std::string line;
51 32  
52   - //if the wavelength is past the end of the list, return the back
53   - if(i == f.size())
54   - return f.back().y;
55   - //if the wavelength is before the beginning of the list, return the front
56   - else if(i == 0)
57   - return f.front().y;
58   - //otherwise interpolate
59   - else
60   - {
61   - X xMax = f[i].x;
62   - X xMin = f[i - 1].x;
63   - //std::cout<<lMin<<"----------"<<lMax<<std::endl;
64   -
65   - X a = (x - xMin) / (xMax - xMin);
66   - Y riMin = f[i - 1].y;
67   - Y riMax = f[i].y;
68   - Y interp;
69   - interp = riMax * a + riMin * (1.0 - a);
70   - return interp;
71   - }
  33 + while(!ss.eof()){
  34 +
  35 + std::getline(ss, line);
  36 + if(line[0] == '#') continue;
  37 +
  38 + std::stringstream lstream(line);
  39 +
  40 + lstream>>x; //read the x value
  41 + lstream>>y; //read the y value
  42 +
  43 + //std::cout<<x<<", "<<y<<std::endl;
  44 +
  45 + if(ss.eof()) break;
  46 + insert(x, y); //insert the read value into the function
  47 +
  48 + }
72 49 }
73 50  
74   - ///add a data point to a function
75   - void insert(X x, Y y)
  51 +
  52 +public:
  53 +
  54 + //linear interpolation
  55 + Ty linear(Tx x) const
76 56 {
77   - dataPoint s;
78   - s.x = x;
79   - s.y = y;
  57 + if(X.size() == 0) return (Ty)0; //return zero if the function is empty
80 58  
81   - if(f.size() == 0 || f.back().x < x)
82   - return f.push_back(s);
  59 + unsigned int N = X.size(); //number of sample points
83 60  
84 61 //declare an iterator
85   - typename std::vector< dataPoint >::iterator it;
  62 + typedef typename std::vector< Tx >::iterator f_iter; //declare an iterator
  63 + f_iter it;
  64 +
  65 + //find the first X-coordinate that is greater than x
  66 + unsigned int i;
  67 + for(i = 0; i<N; i++){
  68 + if(X[i] > x)
  69 + break;
  70 + }
  71 + //i currently holds the ceiling
  72 +
  73 + //if the wavelength is past the end of the list, return the last sample point
  74 + if(i == N) return Y[N];
  75 + //if the wavelength is before the beginning of the list, return the front
  76 + else if(i == 0) return Y[0];
  77 + //otherwise interpolate
  78 + else{
  79 + Tx xMax = X[i];
  80 + Tx xMin = X[i-1];
  81 +
  82 + Tx a = (x - xMin) / (xMax - xMin);
  83 + Ty riMin = Y[i - 1];
  84 + Ty riMax = Y[i];
  85 + Ty interp = riMax * a + riMin * (1 - a);
  86 + return interp;
  87 + }
  88 + }
  89 +
  90 + ///add a data point to a function
  91 + void insert(Tx x, Ty y)
  92 + {
  93 + unsigned int N = X.size(); //number of sample points
86 94  
87   -
  95 + if(N == 0 || X[N-1] < x){
  96 + X.push_back(x);
  97 + Y.push_back(y);
  98 + return;
  99 + }
88 100  
89   - it = search(f.begin(), f.end(), &s, &s + 1, &function<X, Y>::findCeiling);
  101 + //declare an iterator and search for the x value
  102 + typename std::vector< Tx >::iterator it;
  103 + it = search(X.begin(), X.end(), &x, &x + 1, &function<Tx, Ty>::findCeiling);
90 104  
91 105 //if the function value is past the end of the vector, add it to the back
92   - if(it == f.end())
93   - return f.push_back(s);
  106 + if(*it == N){
  107 + X.push_back(x);
  108 + Y.push_back(y);
  109 + }
94 110 //otherwise add the value at the iterator position
95   - else
96   - {
97   - f.insert(it, s);
  111 + else{
  112 + X.insert(it, x);
  113 + Y.insert(Y.begin() + *it, y);
98 114 }
99 115  
100 116 }
101 117  
102   - X getX(unsigned int i) const
103   - {
104   - return f[i].x;
  118 + Tx getX(unsigned int i) const{
  119 + return X[i];
105 120 }
106 121  
107   - Y getY(unsigned int i) const
108   - {
109   - return f[i].y;
  122 + Ty getY(unsigned int i) const{
  123 + return Y[i];
110 124 }
111 125  
112 126 ///get the number of data points in the function
113   - unsigned int getN() const
114   - {
115   - return f.size();
  127 + unsigned int getN() const{
  128 + return X.size();
116 129 }
117 130  
118 131 //look up an indexed component
119   - dataPoint operator[](int i) const
120   - {
121   - return f[i];
  132 + Ty operator[](int i) const{
  133 + if(i <= X.size()){
  134 + std::cout<<"ERROR: accessing non-existing sample point in 'function'"<<std::endl;
  135 + exit(1);
  136 + }
  137 + return Y[i];
122 138 }
123 139  
124 140 ///linear interpolation
125   - Y operator()(X x) const
126   - {
  141 + Tx operator()(Tx x) const{
127 142 return linear(x);
128 143 }
129 144  
130   - function<X, Y> operator+(Y r) const
131   - {
132   - function<X, Y> result;
  145 + //add a constant to the function
  146 + function<Tx, Ty> operator+(Ty r) const{
  147 +
  148 + function<Tx, Ty> result;
133 149  
134   - //add r to every point in f
135   - for(int i=0; i<f.size(); i++)
136   - {
137   - result.f.push_back(f[i]);
138   - result.f[i].y += r;
  150 + //if there are points in the function
  151 + if(X.size() > 0){
  152 + //add r to every point in f
  153 + for(unsigned int i=0; i<X.size(); i++){
  154 + result.X.push_back(X[i]);
  155 + result.Y.push_back(Y[i] + r);
  156 + }
  157 + }
  158 + else{
  159 + result = r;
139 160 }
140 161  
141 162 return result;
142 163 }
143 164  
144   - function<X, Y> & operator= (const Y & rhs)
145   - {
146   - f.clear();
  165 + function<Tx, Ty> & operator= (const Ty & rhs){
  166 + X.clear();
  167 + Y.clear();
147 168 if(rhs != 0) //if the RHS is zero, just clear, otherwise add one value of RHS
148 169 insert(0, rhs);
149 170  
... ... @@ -151,6 +172,29 @@ public:
151 172 }
152 173  
153 174  
  175 +
  176 + std::string str(){
  177 + stringstream ss;
  178 +
  179 + unsigned int N = X.size(); //number of sample points
  180 + //output each function value
  181 + for(unsigned int i = 0; i<N; i++){
  182 + ss<<X[i]<<", "<<Y[i]<<std::endl;
  183 + }
  184 +
  185 + return ss.str();
  186 +
  187 + }
  188 +
  189 + void load(std::string filename){
  190 + std::ifstream t(filename.c_str());
  191 + std::string str((std::istreambuf_iterator<char>(t)),
  192 + std::istreambuf_iterator<char>());
  193 +
  194 + process_string(str);
  195 + }
  196 +
  197 +
154 198 };
155 199  
156 200 } //end namespace rts
... ...