00001 #ifndef __RBFCLASSIFICATIONCONTROL_H
00002 #define __RBFCLASSIFICATIONCONTROL_H
00003
00004 #include <Control.hpp>
00005 #include <GNGCore/SGNGAlgorithm.hpp>
00006 #include "RBFNetwork.hpp"
00007 #include <Generators/IInputGenerator.hpp>
00008
00009 #include <iostream>
00010 #include <fstream>
00011
00012
00013
00014 #define RBF_DECISION_DELTA (0.25) // Default margin constant; not referenced in this header (RBFClassificationControl takes decisionDelta explicitly) - presumably a default for callers, TODO confirm.
00015
00022 class RBFClassificationControl : public Control
00023 {
00024 public:
00033 RBFClassificationControl(RBFNetwork & net, IInputGenerator * g, SGNGAlgorithm * alg,
00034 unsigned int max, unsigned int backLog, double decisionLimit, double decisionDelta, std::ofstream * log)
00035 : m_net(net), m_generator(*g), m_alg(alg)
00036 {
00037 m_decisionLimit = decisionLimit;
00038 m_decisionDelta = decisionDelta;
00039 m_maxMisclassificationsAllowed = max;
00040 m_numMisclassifications = m_maxMisclassificationsAllowed+1;
00041 m_backLog = backLog;
00042 m_iteration = 0;
00043
00044 m_log = log;
00045 }
00046
00047 virtual ~RBFClassificationControl()
00048 {
00049 delete m_alg;
00050 delete &m_generator;
00051 m_log->close();
00052 delete m_log;
00053 }
00054
00058 virtual bool Iterate()
00059 {
00060 static unsigned int backLog = m_backLog;
00061 static unsigned int numMisclassifications = 0;
00062
00063 if(IsStopCriteriaMet())
00064 return false;
00065
00066 if(IsReadyToIterate())
00067 {
00068 m_iteration++;
00069 backLog--;
00070
00071 Vector input = m_generator.GetInput();
00072 Vector output = m_generator.GetOutput();
00073
00074 std::cout << m_iteration;
00075 std::cout << ", I:[" << input << "], D:[" << output << "]";
00076 Vector netOutput = m_net.Recall( input );
00077
00078
00079
00080
00081
00082
00083
00084 unsigned int bestIndex = 0;
00085 unsigned int secondBestIndex = 1;
00086 for(unsigned int i=1; i < netOutput.size(); i++)
00087 {
00088 if(netOutput[i] > netOutput[bestIndex])
00089 {
00090 secondBestIndex = bestIndex;
00091 bestIndex = i;
00092 }
00093 else
00094 if(netOutput[i] > netOutput[secondBestIndex])
00095 secondBestIndex = i;
00096 }
00097
00098 std::cout << ", O:[" << netOutput << "]";
00099
00100
00101
00102
00103 if( output[bestIndex] == 1.0 &&
00104 netOutput[bestIndex] >= m_decisionLimit &&
00105 fabs(netOutput[bestIndex] - netOutput[secondBestIndex]) >= m_decisionDelta )
00106 {
00107 std::cout << ", 1 ";
00108 }
00109 else
00110 {
00111 numMisclassifications++;
00112 std::cout << ", 0 " ;
00113 }
00114 std::cout << std::endl;
00115
00116
00117 m_net.Train( output );
00118
00119 m_alg->SetSquaredError( netOutput.SquaredDistance(output) );
00120 m_alg->Iterate( input );
00121
00122
00123 if(backLog == 0)
00124 {
00125 m_numMisclassifications = numMisclassifications;
00126 backLog = m_backLog;
00127 numMisclassifications = 0;
00128 (*m_log) << m_iteration << " " << m_numMisclassifications << std::endl;
00129 }
00130
00131 return true;
00132 }
00133 return false;
00134 }
00135
00136 protected:
00137 virtual bool IsStopCriteriaMet()
00138 {
00139 return m_numMisclassifications <= m_maxMisclassificationsAllowed;
00140 }
00141
00142 protected:
00143 RBFNetwork & m_net;
00144 IInputGenerator & m_generator;
00145 SGNGAlgorithm * m_alg;
00146
00147 unsigned int m_maxMisclassificationsAllowed;
00148 double m_decisionLimit;
00149 double m_decisionDelta;
00150 unsigned int m_backLog;
00151 unsigned int m_numMisclassifications;
00152 std::ofstream * m_log;
00153 };
00154
00155
00156 #endif