00001 #ifndef __RBFFUNCTIONAPPROXIMATIONCONTROL_H
00002 #define __RBFFUNCTIONAPPROXIMATIONCONTROL_H
00003
00004 #include <Control.hpp>
00005 #include <GNGCore/SGNGAlgorithm.hpp>
00006 #include "RBFNetwork.hpp"
00007 #include <Generators/IInputGenerator.hpp>
00008
00009 #include <iostream>
00010 #include <fstream>
00011
00018 class RBFFunctionApproximationControl : public Control
00019 {
00020 public:
00029 RBFFunctionApproximationControl(RBFNetwork & net, IInputGenerator * g, SGNGAlgorithm * alg, double max, unsigned int backLog, std::ofstream * log)
00030 : m_net(net), m_generator(*g), m_alg(alg)
00031 {
00032 m_maxErrorAllowed = max;
00033 m_currentError = m_maxErrorAllowed+1;
00034 m_backLog = backLog;
00035 m_iteration = 0;
00036
00037 m_log = log;
00038 }
00039
00040 virtual ~RBFFunctionApproximationControl()
00041 {
00042 delete m_alg;
00043 delete &m_generator;
00044 m_log->close();
00045 delete m_log;
00046 }
00047
00051 virtual bool Iterate()
00052 {
00053 static unsigned int backLog = m_backLog;
00054 static double loggedError = 0.0;
00055
00056 if(IsStopCriteriaMet())
00057 return false;
00058
00059 if(IsReadyToIterate())
00060 {
00061 m_iteration++;
00062 backLog--;
00063 Vector input = m_generator.GetInput();
00064 Vector output = m_generator.GetOutput();
00065
00066 std::cout << m_iteration;
00067 std::cout << ", I:" << input << ", D:" << output;
00068 Vector netOutput = m_net.Recall( input );
00069 std::cout << ", O:" << netOutput;
00070
00071 m_net.Train( output );
00072
00073 double currentError = netOutput.SquaredDistance(output);
00074 std::cout << ", sqE:" << currentError << std::endl;
00075 loggedError += currentError;
00076 m_alg->SetSquaredError( currentError );
00077 m_alg->Iterate( input );
00078
00079
00080 if(backLog == 0)
00081 {
00082 m_currentError = loggedError / (double)m_backLog;
00083 backLog = m_backLog;
00084 loggedError = 0.0;
00085 (*m_log) << m_iteration << " " << m_currentError << std::endl;
00086 }
00087
00088 return true;
00089 }
00090 return false;
00091 }
00092
00097 virtual void Snapshot()
00098 {
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113 }
00114
00115 protected:
00116 virtual bool IsStopCriteriaMet()
00117 {
00118 return m_currentError <= m_maxErrorAllowed;
00119 }
00120
00121 protected:
00122 RBFNetwork & m_net;
00123 IInputGenerator & m_generator;
00124 SGNGAlgorithm * m_alg;
00125
00126 unsigned int m_backLog;
00127 double m_maxErrorAllowed;
00128 double m_currentError;
00129 std::ofstream * m_log;
00130 };
00131
00132
00133 #endif