MeanWalker

histogrambuildertemplate.h

Go to the documentation of this file.
00001 /*
00002 
00003   This file is a part of MeanWalker, a library which provides utilities to
00004   sample from probability distributions using methods like the 
00005   Metropolis-Hastings algorithm and Goodman-Weare algorithm.
00006 
00007   Copyright (C) 2012 Jori Liesenborgs
00008 
00009   Contact: jori.liesenborgs@gmail.com
00010   
00011   This program is free software; you can redistribute it and/or modify
00012   it under the terms of the GNU General Public License as published by
00013   the Free Software Foundation; either version 2 of the License, or
00014   (at your option) any later version.
00015   
00016   This program is distributed in the hope that it will be useful,
00017   but WITHOUT ANY WARRANTY; without even the implied warranty of
00018   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00019   GNU General Public License for more details.
00020   
00021   You should have received a copy of the GNU General Public License
00022   along with this program; if not, write to the Free Software
00023   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
00024   
00025 */
00026 
00031 #ifndef MEANWALKER_HISTOGRAMBUILDERTEMPLATE_H
00032 
00033 #define MEANWALKER_HISTOGRAMBUILDERTEMPLATE_H
00034 
00035 #include "meanwalkerconfig.h"
00036 #include "randomnumbergenerator.h"
00037 #include "histogram.h"
00038 #include <vector>
00039 
00040 namespace meanwalker
00041 {
00042 
00048 template <class WalkerBase>
00049 class MEANWALKER_IMPORTEXPORT HistogramBuilderTemplate : public WalkerBase
00050 {
00051 public:
00052         HistogramBuilderTemplate(RandomNumberGenerator *pRng) : WalkerBase(pRng)
00053         {
00054                 m_init = false;
00055         }
00056 
00057         ~HistogramBuilderTemplate()
00058         {
00059         }
00060 
00061         bool initNormal(const double *pMinCoords, const double *pMaxCoords, size_t dim, size_t numBins)
00062         {
00063                 if (m_init)
00064                 {
00065                         errut::ErrorBase::setErrorString("Already initialized");
00066                         return false;
00067                 }
00068 
00069                 m_delayed = false;
00070                 m_init = true;
00071                 m_numBins = numBins;
00072 
00073                 m_histograms.resize(dim);
00074 
00075                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00076                 {
00077                         m_histograms[i].setIndex(i);
00078                         m_histograms[i].setInstance(this);
00079                 }
00080 
00081                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00082                         m_histograms[i].init(pMinCoords[i], pMaxCoords[i], m_numBins);
00083 
00084                 return true;
00085         }
00086 
00087         bool initDelayed(size_t dim, size_t numBins)
00088         {
00089                 if (m_init)
00090                 {
00091                         errut::ErrorBase::setErrorString("Already initialized");
00092                         return false;
00093                 }
00094                 m_delayed = true;
00095                 m_init = true;
00096                 m_numBins = numBins;
00097 
00098                 m_histograms.resize(dim);
00099 
00100                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00101                 {
00102                         m_histograms[i].setIndex(i);
00103                         m_histograms[i].setInstance(this);
00104                 }
00105 
00106                 return true;
00107         }
00108 
00109         void destroy()
00110         {
00111                 m_histograms.resize(0);
00112                 m_init = false;
00113         }
00114 
00125         bool run(size_t initSteps, size_t finalSteps)
00126         {
00127                 if (!m_init)
00128                 {
00129                         errut::ErrorBase::setErrorString("Not initialized");
00130                         return false;
00131                 }
00132 
00133                 m_record = false; 
00134                 WalkerBase::walk(initSteps); 
00135 
00136                 if (finalSteps > 0)
00137                 {
00138                         if (m_delayed)
00139                         {
00140                                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00141                                         m_histograms[i].reserveMemory(finalSteps);
00142                         }
00143 
00144                         m_record = true; 
00145                         WalkerBase::walk(finalSteps); 
00146 
00147                         if (m_delayed)
00148                         {
00149                                 m_delayed = false;
00150                                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00151                                         m_histograms[i].processRecordedValues(m_numBins);
00152 
00153                                 std::vector<const double *> recordedValues(m_histograms.size());
00154 
00155                                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00156                                         recordedValues[i] = m_histograms[i].getRecordedValues();
00157 
00158                                 onRecordedValuesProcessed(&(recordedValues[0]), m_histograms.size(), m_histograms[0].recordSize());
00159 
00160                                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00161                                         m_histograms[i].clearRecordedValues();
00162                         }
00163                 }
00164 
00165                 return true;
00166         }
00167 
00169         void clearHistograms()                                                                                  
00170         { 
00171                 for (size_t i = 0 ; i < m_histograms.size() ; i++) 
00172                         m_histograms[i].clear(); 
00173         }
00174 
00176         const Histogram &getHistogram(size_t i) const                                                           
00177         { 
00178                 return m_histograms[i]; 
00179         }
00180 
00182         const std::vector<const Histogram *> getHistograms() const
00183         {
00184                 std::vector<const Histogram *> h(m_histograms.size());
00185 
00186                 for (size_t i = 0 ; i < h.size() ; i++)
00187                         h[i] = &(m_histograms[i]);
00188 
00189                 return h;
00190         }
00191 
00193         void printHistograms() const
00194         {
00195                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00196                 {
00197                         std::cout << "# Histogram " << i << std::endl;
00198                         m_histograms[i].print();
00199                 }
00200         }
00201 protected:
00205         void onNewSample(double functionValue, const double *pCoords)
00206         {
00207                 if (m_record)
00208                 {
00209                         if (!m_delayed)
00210                         {
00211                                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00212                                         m_histograms[i].process(pCoords[i]);
00213                         }
00214                         else
00215                         {
00216                                 for (size_t i = 0 ; i < m_histograms.size() ; i++)
00217                                         m_histograms[i].record(pCoords[i]);
00218                         }
00219                 }
00220         }
00221 
00233         virtual void getHistogramMinMax(size_t idx, const double *pValues, size_t numValues, double &minValue, double &maxValue)
00234         {
00235                 DelayedHistogram::getDefaultHistogramMinMax(pValues, numValues, minValue, maxValue);
00236         }
00237 
00243         virtual void onRecordedValuesProcessed(const double **pValues, size_t numHistograms, size_t numRecords)
00244         { 
00245         }
00246 private:
00247         class InternalDelayedHistogram : public DelayedHistogram
00248         {
00249         public:
00250                 InternalDelayedHistogram()                                                                      { m_pInstance = 0; m_idx = 0; }
00251 
00252                 void setInstance(HistogramBuilderTemplate *pInst)                                               { m_pInstance = pInst; }
00253                 void setIndex(size_t idx)                                                                       { m_idx = idx; }
00254         protected:
00255                 void getHistogramMinMax(const double *pValues, size_t numValues, double &minValue, double &maxValue) { m_pInstance->getHistogramMinMax(m_idx, pValues, numValues, minValue, maxValue); }
00256         private:
00257                 size_t m_idx;
00258                 HistogramBuilderTemplate *m_pInstance;
00259         };
00260 
00261         std::vector<InternalDelayedHistogram> m_histograms;
00262         bool m_record, m_delayed, m_init;
00263         size_t m_numBins;
00264 };
00265 
00266 } // end namespace   
00267 
00268 #endif // MEANWALKER_HISTOGRAMBUILDERTEMPLATE_H