RBFClassificationControl.hpp

Go to the documentation of this file.
00001 #ifndef __RBFCLASSIFICATIONCONTROL_H
00002 #define __RBFCLASSIFICATIONCONTROL_H
00003 
00004 #include <Control.hpp>
00005 #include <GNGCore/SGNGAlgorithm.hpp>
00006 #include "RBFNetwork.hpp"
00007 #include <Generators/IInputGenerator.hpp>
00008 
00009 #include <iostream>
00010 #include <fstream>
00011 
// Default minimum margin required between the largest and the second-largest
// value in the network's output vector for an answer to be accepted.
// NOTE(review): not referenced in this header; presumably callers pass it as
// the decisionDelta argument of RBFClassificationControl -- confirm.
#define RBF_DECISION_DELTA (0.25)
00015 
00022 class RBFClassificationControl : public Control
00023 {
00024 public:
00033     RBFClassificationControl(RBFNetwork & net, IInputGenerator * g, SGNGAlgorithm * alg, 
00034         unsigned int max, unsigned int backLog, double decisionLimit, double decisionDelta, std::ofstream * log) 
00035         : m_net(net), m_generator(*g), m_alg(alg)
00036     {
00037         m_decisionLimit = decisionLimit;
00038         m_decisionDelta = decisionDelta;
00039         m_maxMisclassificationsAllowed = max;
00040         m_numMisclassifications = m_maxMisclassificationsAllowed+1;
00041         m_backLog = backLog;
00042         m_iteration = 0;
00043 
00044         m_log = log;
00045     }
00046 
00047     virtual ~RBFClassificationControl() 
00048     { 
00049         delete m_alg; 
00050         delete &m_generator; 
00051         m_log->close();
00052         delete m_log;
00053     }
00054 
00058     virtual bool Iterate()
00059     {
00060         static unsigned int backLog = m_backLog;
00061         static unsigned int numMisclassifications = 0;
00062 
00063         if(IsStopCriteriaMet())
00064             return false;
00065 
00066         if(IsReadyToIterate())
00067         {
00068             m_iteration++;
00069             backLog--;
00070 
00071             Vector input = m_generator.GetInput();
00072             Vector output = m_generator.GetOutput();
00073 
00074             std::cout << m_iteration;
00075             std::cout << ", I:[" << input << "], D:[" << output << "]";
00076             Vector netOutput = m_net.Recall( input );
00077 
00078             
00079             // decide if the output is correct. 
00080             //  For a correct answer to be accepted, the differece
00081             //  between the answer and second best answer must be
00082             //  at least m_decisionLimit.
00083             //  Find the best and second best.
00084             unsigned int bestIndex = 0;
00085             unsigned int secondBestIndex = 1;
00086             for(unsigned int i=1; i < netOutput.size(); i++)
00087             {
00088               if(netOutput[i] > netOutput[bestIndex])
00089               {
00090                 secondBestIndex = bestIndex;
00091                 bestIndex = i;
00092               }
00093               else
00094                 if(netOutput[i] > netOutput[secondBestIndex])
00095                   secondBestIndex = i;
00096             }
00097 
00098             std::cout << ", O:[" << netOutput << "]";
00099 
00100             // The Decision Limit check.
00101             //   Reject the answer if diff between best and second best is
00102             //   less than m_decisionDelta or of best answer is too low.                                             
00103             if( output[bestIndex] == 1.0 &&
00104                     netOutput[bestIndex] >= m_decisionLimit && 
00105                     fabs(netOutput[bestIndex] - netOutput[secondBestIndex]) >= m_decisionDelta )
00106             {
00107                 std::cout << ", 1 "; //<< winner; //netOutput;
00108             }
00109             else
00110             {
00111                 numMisclassifications++;
00112                 std::cout << ", 0 " ;
00113             }
00114             std::cout << std::endl;
00115 
00116             // train the network
00117             m_net.Train( output );
00118 
00119             m_alg->SetSquaredError( netOutput.SquaredDistance(output) );
00120             m_alg->Iterate( input );
00121 
00122             // calculate the set the number of misclassifications in the last backlog steps.
00123             if(backLog == 0)
00124             {
00125                 m_numMisclassifications = numMisclassifications;
00126                 backLog = m_backLog;
00127                 numMisclassifications = 0;
00128                 (*m_log) << m_iteration << " " << m_numMisclassifications << std::endl;
00129             }
00130 
00131             return true;
00132         }
00133         return false;
00134     }
00135 
00136 protected:
00137     virtual bool IsStopCriteriaMet() 
00138     {
00139         return m_numMisclassifications <= m_maxMisclassificationsAllowed;
00140     }
00141 
00142 protected:
00143     RBFNetwork &       m_net;
00144     IInputGenerator &  m_generator;
00145     SGNGAlgorithm *    m_alg;
00146 
00147     unsigned int       m_maxMisclassificationsAllowed; // in backlog steps.
00148     double             m_decisionLimit;
00149     double             m_decisionDelta;
00150     unsigned int       m_backLog;
00151     unsigned int       m_numMisclassifications;
00152     std::ofstream *    m_log;
00153 };
00154 
00155 
00156 #endif

Generated on Mon Mar 22 16:40:48 2004 for GNG_GL by doxygen 1.3.6