00001 #include <vector.h>
00002
00003
00004 #include "MultiFunction.h"
00005 #include "Sym.h"
00006 #include "FunDef.h"
00007
00008 using namespace std;
00009
00010
// Iterator to a single value slot; setup() stores iterators into
// MultiFunctionImpl::data of this type.
typedef vector<double>::const_iterator data_ptr;
// Flat storage for the operand iterators of every compiled function.
typedef vector<data_ptr> data_ptrs;
// Iterator to the first operand iterator of one compiled function.
typedef data_ptrs::const_iterator arg_ptr;

// NOTE(review): MultiFuncs.cpp is textually included rather than compiled
// separately; it presumably defines the multi_function primitives in terms of
// the typedefs above, so keep it after them — confirm.
#include "MultiFuncs.cpp"

// Signature of a compiled primitive: it reads its operands through arg_ptr.
typedef double (*fptr)( arg_ptr );
00018
00019 string print_function( fptr f) {
00020 if (f == multi_function::plus) return "+";
00021 if (f == multi_function::mult) return "*";
00022 if (f == multi_function::min) return "-";
00023 if (f == multi_function::inv) return "/";
00024 if (f == multi_function::exp) return "e";
00025 return "unknown";
00026 }
00027
00028
00029 struct Function {
00030
00031 fptr function;
00032 arg_ptr args;
00033
00034 double operator()() const { return function(args); }
00035 };
00036
// NOTE(review): token_2_function is not referenced anywhere else in this file;
// it looks like a leftover and can probably be removed — confirm.
static vector<Function> token_2_function;
00038
00039 Sym make_binary(Sym sym) {
00040 if (sym.args().size() == 2) return sym;
00041 SymVec args = sym.args();
00042 Sym an = args.back();
00043 args.pop_back();
00044 Sym nw = make_binary( Sym( sym.token(), args) );
00045 args.resize(2);
00046 args[0] = nw;
00047 args[1] = an;
00048 return Sym(sym.token(), args);
00049 }
00050
00051 class Compiler {
00052
00053 public:
00054
00055 enum func_type {constant, variable, function};
00056
00057 typedef pair<func_type, unsigned> entry;
00058
00059 #if USE_TR1
00060 typedef std::tr1::unordered_map<Sym, entry, HashSym> HashMap;
00061 #else
00062 typedef hash_map<Sym, entry, HashSym> HashMap;
00063 #endif
00064
00065 HashMap map;
00066
00067 vector<double> constants;
00068 vector<unsigned> variables;
00069 vector< fptr > functions;
00070 vector< vector<entry> > function_args;
00071
00072 unsigned total_args;
00073
00074 vector<entry> outputs;
00075
00076 Compiler() : total_args(0) {}
00077
00078 entry do_add(Sym sym) {
00079
00080 HashMap::iterator it = map.find(sym);
00081
00082 if (it == map.end()) {
00083
00084 token_t token = sym.token();
00085
00086 if (is_constant(token)) {
00087 constants.push_back( get_constant_value(token) );
00088 entry e = make_pair(constant, constants.size()-1);
00089 map.insert( make_pair(sym, e) );
00090 return e;
00091
00092 } else if (is_variable(token)) {
00093 unsigned idx = get_variable_index(token);
00094 variables.push_back(idx);
00095 entry e = make_pair(variable, variables.size()-1);
00096 map.insert( make_pair(sym, e) );
00097 return e;
00098 }
00099
00100 fptr f;
00101 vector<entry> vec;
00102 const SymVec& args = sym.args();
00103
00104 switch (token) {
00105 case sum_token:
00106 {
00107 if (args.size() == 0) {
00108 return do_add( SymConst(0.0));
00109 }
00110 if (args.size() == 1) {
00111 return do_add(args[0]);
00112 }
00113 if (args.size() == 2) {
00114 vec.push_back(do_add(args[0]));
00115 vec.push_back(do_add(args[1]));
00116 f = multi_function::plus;
00117
00118 break;
00119
00120 } else {
00121 return do_add( make_binary(sym) );
00122 }
00123
00124 }
00125 case prod_token:
00126 {
00127 if (args.size() == 0) {
00128 return do_add( SymConst(1.0));
00129 }
00130 if (args.size() == 1) {
00131 return do_add(args[0]);
00132 }
00133 if (args.size() == 2) {
00134 vec.push_back(do_add(args[0]));
00135 vec.push_back(do_add(args[1]));
00136 f = multi_function::mult;
00137
00138 break;
00139
00140
00141 } else {
00142 return do_add( make_binary(sym) );
00143 }
00144 }
00145 case sqr_token:
00146 {
00147 SymVec newargs(2);
00148 newargs[0] = args[0];
00149 newargs[1] = args[0];
00150 return do_add( Sym(prod_token, newargs));
00151 }
00152 default :
00153 {
00154 if (args.size() != 1) {
00155 cerr << "Unknown function " << sym << " encountered" << endl;
00156 exit(1);
00157 }
00158
00159 vec.push_back(do_add(args[0]));
00160
00161 switch (token) {
00162 case min_token: f = multi_function::min; break;
00163 case inv_token: f = multi_function::inv; break;
00164 case exp_token :f = multi_function::exp; break;
00165 default :
00166 {
00167 cerr << "Unimplemented token encountered " << sym << endl;
00168 exit(1);
00169 }
00170 }
00171
00172
00173
00174
00175 }
00176
00177 }
00178
00179 total_args += vec.size();
00180 function_args.push_back(vec);
00181 functions.push_back(f);
00182
00183 entry e = make_pair(function, functions.size()-1);
00184 map.insert( make_pair(sym, e) );
00185 return e;
00186
00187 }
00188
00189 return it->second;
00190 }
00191
00192 void add(Sym sym) {
00193 entry e = do_add(sym);
00194 outputs.push_back(e);
00195 }
00196
00197 };
00198
00199 class MultiFunctionImpl {
00200 public:
00201
00202
00203 vector<unsigned> input_idx;
00204
00205 unsigned constant_offset;
00206 unsigned var_offset;
00207
00208
00209 vector<double> data;
00210 vector<Function> funcs;
00211 data_ptrs args;
00212
00213 vector<unsigned> output_idx;
00214
00215 MultiFunctionImpl() {}
00216
00217 void clear() {
00218 input_idx.clear();
00219 data.clear();
00220 funcs.clear();
00221 args.clear();
00222 output_idx.clear();
00223 constant_offset = 0;
00224 }
00225
00226 void eval(const double* x, double* y) {
00227 unsigned i;
00228
00229 for (i = constant_offset; i < constant_offset + input_idx.size(); ++i) {
00230 data[i] = x[input_idx[i-constant_offset]];
00231 }
00232
00233 for(; i < data.size(); ++i) {
00234 data[i] = funcs[i-var_offset]();
00235
00236 }
00237
00238 for (i = 0; i < output_idx.size(); ++i) {
00239 y[i] = data[output_idx[i]];
00240 }
00241 }
00242
00243 void eval(const vector<double>& x, vector<double>& y) {
00244 eval(&x[0], &y[0]);
00245 }
00246
00247 void setup(const vector<Sym>& pop) {
00248
00249 clear();
00250 Compiler compiler;
00251
00252 for (unsigned i = 0; i < pop.size(); ++i) {
00253 Sym sym = (expand_all(pop[i]));
00254 compiler.add(sym);
00255 }
00256
00257
00258 constant_offset = compiler.constants.size();
00259 var_offset = constant_offset + compiler.variables.size();
00260 int n = var_offset + compiler.functions.size();
00261
00262 data.resize(n);
00263 funcs.resize(compiler.functions.size());
00264 args.resize(compiler.total_args);
00265
00266
00267 for (unsigned i = 0; i < constant_offset; ++i) {
00268 data[i] = compiler.constants[i];
00269
00270 }
00271
00272
00273 input_idx = compiler.variables;
00274
00275
00276
00277
00278
00279
00280 unsigned which_arg = 0;
00281 for (unsigned i = 0; i < funcs.size(); ++i) {
00282
00283 Function f;
00284 f.function = compiler.functions[i];
00285
00286
00287
00288
00289 for (unsigned j = 0; j < compiler.function_args[i].size(); ++j) {
00290
00291 Compiler::entry e = compiler.function_args[i][j];
00292
00293 unsigned idx = e.second;
00294
00295 switch (e.first) {
00296 case Compiler::function: idx += compiler.variables.size();
00297 case Compiler::variable: idx += compiler.constants.size();
00298 case Compiler::constant: {}
00299 }
00300
00301 args[which_arg + j] = data.begin() + idx;
00302
00303 }
00304
00305
00306
00307 f.args = args.begin() + which_arg;
00308 which_arg += compiler.function_args[i].size();
00309 funcs[i] = f;
00310 }
00311
00312
00313 output_idx.resize(compiler.outputs.size());
00314 for (unsigned i = 0; i < output_idx.size(); ++i) {
00315 output_idx[i] = compiler.outputs[i].second;
00316 switch(compiler.outputs[i].first) {
00317 case Compiler::function: output_idx[i] += compiler.variables.size();
00318 case Compiler::variable: output_idx[i] += compiler.constants.size();
00319 case Compiler::constant: {}
00320 }
00321
00322 }
00323 }
00324
00325 };
00326
00327
00328
00329 MultiFunction::MultiFunction(const std::vector<Sym>& pop) : pimpl(new MultiFunctionImpl) {
00330 pimpl->setup(pop);
00331 }
00332
00333 MultiFunction::~MultiFunction() { delete pimpl; }
00334
00335 void MultiFunction::operator()(const std::vector<double>& x, std::vector<double>& y) {
00336 pimpl->eval(x,y);
00337 }
00338
00339 void MultiFunction::operator()(const double* x, double* y) {
00340 pimpl->eval(x,y);
00341 }