00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
#include <ctime>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <LanguageTable.h>
#include <TreeBuilder.h>
#include <FunDef.h>
#include <Dataset.h>

#include <eoSymInit.h>
#include <eoSym.h>
#include <eoPop.h>
#include <eoSymMutate.h>

#include <eoSymCrossover.h>
#include <eoSymEval.h>
#include <eoOpContainer.h>
#include <eoDetTournamentSelect.h>
#include <eoMergeReduce.h>
#include <eoGenContinue.h>
#include <eoEasyEA.h>
#include <eoGeneralBreeder.h>

#include <utils/eoParser.h>
#include <utils/eoCheckPoint.h>
#include <utils/eoStat.h>
#include <utils/eoStdoutMonitor.h>
#include <utils/eoRNG.h>

using namespace std;
00045
00046 typedef EoSym<eoMinimizingFitness> EoType;
00047
00048 static int functions_added = 0;
00049
00050 void add_function(LanguageTable& table, eoParser& parser, string name, unsigned arity, token_t token, const FunDef& fun);
00051 void setup_language(LanguageTable& table, eoParser& parser);
00052
/// Picks one of two lvalues: a reference to @p a when @p check is true,
/// otherwise a reference to @p b. No copy is made.
template <class T>
T& select(bool check, T& a, T& b)
{
    return check ? a : b;
}
00055
00056 class eoBestIndividualStat : public eoSortedStat<EoType, string> {
00057 public:
00058 eoBestIndividualStat() : eoSortedStat<EoType, string>("", "best individual") {}
00059
00060 void operator()(const vector<const EoType*>& _pop) {
00061 ostringstream os;
00062 os << (Sym) *_pop[0];
00063 value() = os.str();
00064 }
00065
00066 };
00067
00068 class AverageSizeStat : public eoStat<EoType, double> {
00069 public:
00070 AverageSizeStat() : eoStat<EoType, double>(0.0, "Average size population") {}
00071
00072 void operator()(const eoPop<EoType>& _pop) {
00073 double total = 0.0;
00074 for (unsigned i = 0; i < _pop.size(); ++i) {
00075 total += _pop[i].size();
00076 }
00077 value() = total/_pop.size();
00078 }
00079 };
00080
00081 class SumSizeStat : public eoStat<EoType, unsigned> {
00082 public:
00083 SumSizeStat() : eoStat<EoType, unsigned>(0u, "Number of subtrees") {}
00084
00085 void operator()(const eoPop<EoType>& _pop) {
00086 unsigned total = 0;
00087 for (unsigned i = 0; i < _pop.size(); ++i) {
00088 total += _pop[i].size();
00089 }
00090 value() = total;
00091 }
00092 };
00093
00094 class DagSizeStat : public eoStat<EoType, unsigned> {
00095 public:
00096 DagSizeStat() : eoStat<EoType, unsigned>(0u, "Number of distinct subtrees") {}
00097
00098 void operator()(const eoPop<EoType>& _pop) {
00099 value() = Sym::get_dag().size();
00100 }
00101 };
00102
00103 int main(int argc, char* argv[]) {
00104
00105 eoParser parser(argc, argv);
00106
00107
00108 LanguageTable table;
00109 setup_language(table, parser);
00110
00111
00112
00113 eoValueParam<string> datafile = parser.createParam(string(""), "datafile", "Training data", 'd', string("Regression"), true);
00114 double train_percentage = parser.createParam(1.0, "trainperc", "Percentage of data used for training", 0, string("Regression")).value();
00115
00116
00117
00118 unsigned pop_size = parser.createParam(500u, "population-size", "Population Size", 'p', string("Population")).value();
00119
00120 uint32_t seed = parser.createParam( uint32_t(time(0)), "random-seed", "Seed for rng", 'D').value();
00121
00122 cout << "Seed " << seed << endl;
00123 rng.reseed(seed);
00124
00125 double var_prob = parser.createParam(
00126 0.9,
00127 "var-prob",
00128 "Probability of selecting a var vs. const when creating a terminal",
00129 0,
00130 "Population").value();
00131
00132
00133 double grow_prob = parser.createParam(
00134 0.5,
00135 "grow-prob",
00136 "Probability of selecting 'grow' method instead of 'full' in initialization and mutation",
00137 0,
00138 "Population").value();
00139
00140 unsigned max_depth = parser.createParam(
00141 8u,
00142 "max-depth",
00143 "Maximum depth used in initialization and mutation",
00144 0,
00145 "Population").value();
00146
00147
00148 bool use_uniform = parser.createParam(
00149 false,
00150 "use-uniform",
00151 "Use uniform node selection instead of bias towards internal nodes (functions)",
00152 0,
00153 "Population").value();
00154
00155 double constant_mut_prob = parser.createParam(
00156 0.1,
00157 "constant-mut-rate",
00158 "Probability of performing constant mutation",
00159 0,
00160 "Population").value();
00161
00162
00163 double subtree_mut_prob = parser.createParam(
00164 0.2,
00165 "subtree-mut-rate",
00166 "Probability of performing subtree mutation",
00167 0,
00168 "Population").value();
00169
00170 double node_mut_prob = parser.createParam(
00171 0.2,
00172 "node-mut-rate",
00173 "Probability of performing node mutation",
00174 0,
00175 "Population").value();
00176
00177
00178
00179
00180
00181
00182
00183
00184 double subtree_xover_prob = parser.createParam(
00185 0.4,
00186 "xover-rate",
00187 "Probability of performing subtree crossover",
00188 0,
00189 "Population").value();
00190
00191 double homologous_prob = parser.createParam(
00192 0.4,
00193 "homologous-rate",
00194 "Probability of performing homologous crossover",
00195 0,
00196 "Population").value();
00197
00198 unsigned max_gens = parser.createParam(
00199 50,
00200 "max-gens",
00201 "Maximum number of generations to run",
00202 'g',
00203 "Population").value();
00204
00205 unsigned tournamentsize = parser.createParam(
00206 5,
00207 "tournament-size",
00208 "Tournament size used for selection",
00209 't',
00210 "Population").value();
00211
00212 unsigned maximumSize = parser.createParam(
00213 -1u,
00214 "maximum-size",
00215 "Maximum size after crossover",
00216 's',
00217 "Population").value();
00218
00219 unsigned meas_param = parser.createParam(
00220 2u,
00221 "measure",
00222 "Error measure:\n\
00223 0 -> absolute error\n\
00224 1 -> mean squared error\n\
00225 2 -> mean squared error scaled (equivalent with correlation)\n\
00226 ",
00227 'm',
00228 "Regression").value();
00229
00230
00231 ErrorMeasure::measure meas = ErrorMeasure::mean_squared_scaled;
00232 if (meas_param == 0) meas = ErrorMeasure::absolute;
00233 if (meas_param == 1) meas = ErrorMeasure::mean_squared;
00234
00235
00236
00237 if (parser.userNeedsHelp())
00238 {
00239 parser.printHelp(std::cout);
00240 return 1;
00241 }
00242
00243 if (functions_added == 0) {
00244 cout << "ERROR: no functions defined" << endl;
00245 exit(1);
00246 }
00247
00248
00249 Dataset dataset;
00250 dataset.load_data(datafile.value());
00251
00252 cout << "Data " << datafile.value() << " loaded " << endl;
00253
00254
00255 unsigned nvars = dataset.n_fields();
00256 for (unsigned i = 0; i < nvars; ++i) {
00257 table.add_function( SymVar(i).token(), 0);
00258 }
00259
00260 TreeBuilder builder(table, var_prob);
00261 eoSymInit<EoType> init(builder, grow_prob, max_depth);
00262
00263 eoPop<EoType> pop(pop_size, init);
00264
00265 BiasedNodeSelector biased_sel;
00266 RandomNodeSelector random_sel;
00267
00268 NodeSelector& node_selector = select<NodeSelector>(use_uniform, random_sel, biased_sel);
00269
00270
00271 eoSequentialOp<EoType> genetic_operator;
00272
00273 eoSymSubtreeMutate<EoType> submutate(builder, node_selector);
00274 genetic_operator.add( submutate, subtree_mut_prob);
00275
00276
00277 double std = 1.0;
00278 eoSymConstantMutate<EoType> constmutate(std);
00279 genetic_operator.add(constmutate, constant_mut_prob);
00280
00281 eoSymNodeMutate<EoType> nodemutate(table);
00282 genetic_operator.add(nodemutate, node_mut_prob);
00283
00284
00285
00286
00287
00288 eoSizeLevelCrossover<EoType> bin;
00289
00290 genetic_operator.add(bin, subtree_xover_prob);
00291
00292 eoBinHomologousCrossover<EoType> hom;
00293 genetic_operator.add(hom, homologous_prob);
00294
00295
00296 IntervalBoundsCheck check(dataset.input_minima(), dataset.input_maxima());
00297 ErrorMeasure measure(dataset, train_percentage, meas);
00298
00299 eoSymPopEval<EoType> evaluator(check, measure, maximumSize);
00300
00301 eoDetTournamentSelect<EoType> selectOne(tournamentsize);
00302 eoGeneralBreeder<EoType> breeder(selectOne, genetic_operator,1);
00303 eoPlusReplacement<EoType> replace;
00304
00305
00306 eoGenContinue<EoType> term(max_gens);
00307 eoCheckPoint<EoType> checkpoint(term);
00308
00309 eoBestFitnessStat<EoType> beststat;
00310 checkpoint.add(beststat);
00311
00312 eoBestIndividualStat printer;
00313 AverageSizeStat avgSize;
00314 DagSizeStat dagSize;
00315 SumSizeStat sumSize;
00316
00317 checkpoint.add(printer);
00318 checkpoint.add(avgSize);
00319 checkpoint.add(dagSize);
00320 checkpoint.add(sumSize);
00321
00322 eoStdoutMonitor genmon;
00323 genmon.add(beststat);
00324 genmon.add(printer);
00325 genmon.add(avgSize);
00326 genmon.add(dagSize);
00327 genmon.add(sumSize);
00328 genmon.add(term);
00329
00330 checkpoint.add(genmon);
00331
00332 eoPop<EoType> dummy;
00333 evaluator(pop, dummy);
00334
00335 eoEasyEA<EoType> ea(checkpoint, evaluator, breeder, replace);
00336
00337 ea(pop);
00338
00339 }
00340
00341 void add_function(LanguageTable& table, eoParser& parser, string name, unsigned arity, token_t token, const FunDef& fun, bool all) {
00342 ostringstream desc;
00343 desc << "Enable function " << name << " arity = " << arity;
00344 bool enabled = parser.createParam(false, name, desc.str(), 0, "Language").value();
00345
00346 if (enabled || all) {
00347 cout << "Func " << name << " enabled" << endl;
00348 table.add_function(token, arity);
00349 if (arity > 0) functions_added++;
00350 }
00351 }
00352
00353 void setup_language(LanguageTable& table, eoParser& parser) {
00354
00355 bool all = parser.createParam(false,"all", "Enable all functions").value();
00356 bool ratio = parser.createParam(false,"ratio","Enable rational functions (inv,min,sum,prod)").value();
00357 bool poly = parser.createParam(false,"poly","Enable polynomial functions (min,sum,prod)").value();
00358
00359
00360 vector<const FunDef*> lang = get_defined_functions();
00361
00362 for (token_t i = 0; i < lang.size(); ++i) {
00363
00364 if (lang[i] == 0) continue;
00365
00366 bool is_poly = false;
00367 if (poly && (i == prod_token || i == sum_token || i == min_token) ) {
00368 is_poly = true;
00369 }
00370
00371 bool is_ratio = false;
00372 if (ratio && (is_poly || i == inv_token)) {
00373 is_ratio = true;
00374 }
00375
00376 const FunDef& fun = *lang[i];
00377
00378 if (fun.has_varargs() ) {
00379
00380 for (unsigned j = fun.min_arity(); j < fun.min_arity() + 8; ++j) {
00381 if (j==1) continue;
00382 ostringstream nm;
00383 nm << fun.name() << j;
00384 bool addanyway = (all || is_ratio || is_poly) && j == 2;
00385 add_function(table, parser, nm.str(), j, i, fun, addanyway);
00386 }
00387 }
00388 else {
00389 add_function(table, parser, fun.name(), fun.min_arity(), i, fun, all || is_ratio || is_poly);
00390 }
00391 }
00392 }
00393
00394