OpenGM  2.3.x
Discrete Graphical Model Library
trws_base.hxx
Go to the documentation of this file.
1 #ifndef TRWS_BASE_HXX_
2 #define TRWS_BASE_HXX_
3 #include <iostream>
4 #include <time.h>
7 #include <functional>
11 
12 namespace opengm {
13 namespace trws_base{
14 
15 template<class GM>
17 {
18 public:
19  typedef GM GraphicalModelType;
21  typedef typename GM::ValueType ValueType;
22  typedef typename GM::IndexType IndexType;
23  typedef typename GM::LabelType LabelType;
27  //typedef enum {GRIDSTRUCTURE, GENERALSTRUCTURE} StructureType;
29 
30  static StructureType getStructureType(const std::string& structName)
31  {
32  if (structName.compare("GRID")==0) return GRIDSTRUCTURE;
33  else if (structName.compare("EDGE")==0) return EDGESTRUCTURE;
34  else return GENERALSTRUCTURE;
35  }
36 
37  static std::string getString(StructureType structure)
38  {
39  switch (structure)
40  {
41  case GENERALSTRUCTURE: return std::string("GENERAL");
42  case GRIDSTRUCTURE : return std::string("GRID");
43  case EDGESTRUCTURE : return std::string("EDGE BASED");
44  default: return std::string("UNKNOWN");
45  }
46  }
47 
48 
50 
51  typedef std::vector<typename GM::ValueType> DDVectorType;
52 
53  DecompositionStorage(const GM& gm,StructureType structureType=GENERALSTRUCTURE, const DDVectorType* pddvector=0);
55 
56  const GM& masterModel()const{return _gm;}
57  LabelType numberOfLabels(IndexType varId)const{return _gm.numberOfLabels(varId);}
58  IndexType numberOfModels()const{return (IndexType)_subModels.size();}
59  IndexType numberOfSharedVariables()const{return (IndexType)_variableDecomposition.size();}
60  SubModel& subModel(IndexType modelId){return *_subModels[modelId];}
61  const SubModel& subModel(IndexType modelId)const{return *_subModels[modelId];}
62  IndexType size(IndexType subModelId)const{return (IndexType)_subModels[subModelId]->size();}
63 
64  const SubVariableListType& getSubVariableList(IndexType varId)const{return _variableDecomposition[varId];}
65  StructureType getStructureType()const{return _structureType;}
66 #ifdef TRWS_DEBUG_OUTPUT
67  void PrintTestData(std::ostream& fout)const;
68  void PrintVariableDecompositionConsistency(std::ostream& fout)const;
69 #endif
70 
71  void getDDVector(DDVectorType* ddvector)const;
72  size_t getDDVectorSize()const;
73  void addDDvector(const DDVectorType& ddvector);
74 private:
75  void _InitSubModels(const DDVectorType* pddvector=0);
76  //void _addDDvector(const DDVectorType& ddvector);
77  const GM& _gm;
78  StructureType _structureType;
79  std::vector<SubModel*> _subModels;
80  std::vector<SubVariableListType> _variableDecomposition;
81  VariableToFactorMap _var2FactorMap;
82 };
83 
84 
85 template<class ValueType>
87 {
89  ValueType precision_;
90  bool absolutePrecision_;//true for absolute precision, false for relative w.r.t. dual value
94 
95  TRWSPrototype_Parameters(size_t maxIternum,
96  ValueType precision=1.0,
97  bool absolutePrecision=true,
98  ValueType minRelativeDualImprovement=-1.0,
99  bool fastComputations=true,
100  bool canonicalNormalization=false):
101  maxNumberOfIterations_(maxIternum),
102  precision_(precision),
103  absolutePrecision_(absolutePrecision),
104  minRelativeDualImprovement_(minRelativeDualImprovement),
105  fastComputations_(fastComputations),
106  canonicalNormalization_(canonicalNormalization)
107  {};
108 };
109 
110 template<class GM>
112 {
113 public:
114  typedef typename GM::IndexType IndexType;
117  struct FactorVarID
118  {
120  FactorVarID(IndexType fID,IndexType vID,IndexType lID):
121  factorId(fID),varId(vID),localId(lID){};
122 
123 #ifdef TRWS_DEBUG_OUTPUT
124  void print(std::ostream& out)const{out <<"("<<factorId<<","<<varId<<","<<localId<<"),";}
125 #endif
126 
127  IndexType factorId;
128  IndexType varId;
129  IndexType localId;//local index of varId
130  };
131  typedef std::vector<FactorVarID> FactorList;
132  typedef typename FactorList::const_iterator const_iterator;
133 
134  PreviousFactorTable(const GM& gm);
135  const_iterator begin(IndexType varId,MoveDirection md)const{return (md==Storage::Direct ? _forwardFactors[varId].begin() : _backwardFactors[varId].begin());}
136  const_iterator end(IndexType varId,MoveDirection md)const{return (md==Storage::Direct ? _forwardFactors[varId].end() : _backwardFactors[varId].end());}
137 #ifdef TRWS_DEBUG_OUTPUT
138  void PrintTestData(std::ostream& fout);
139 #endif
140 private:
141  std::vector<FactorList> _forwardFactors;
142  std::vector<FactorList> _backwardFactors;
143 };
144 
145 template<class GM>
147 _forwardFactors(gm.numberOfVariables()),
148 _backwardFactors(gm.numberOfVariables())
149 {
150  std::vector<IndexType> varIDs(2);
151  for (IndexType factorId=0;factorId<gm.numberOfFactors();++factorId)
152  {
153  switch (gm[factorId].numberOfVariables())
154  {
155  case 1: break;
156  case 2:
157  gm[factorId].variableIndices(varIDs.begin());
158  if (varIDs[0] < varIDs[1])
159  {
160  _forwardFactors[varIDs[1]].push_back(FactorVarID(factorId,varIDs[0],0));
161  _backwardFactors[varIDs[0]].push_back(FactorVarID(factorId,varIDs[1],1));
162  }
163  else
164  {
165  _forwardFactors[varIDs[0]].push_back(FactorVarID(factorId,varIDs[1],1));
166  _backwardFactors[varIDs[1]].push_back(FactorVarID(factorId,varIDs[0],0));
167  }
168  break;
169  default: throw std::runtime_error("PreviousFactor::PreviousFactor: only the factors of order <=2 are supported!");
170  }
171  }
172 }
173 
174 #ifdef TRWS_DEBUG_OUTPUT
175 template<class GM>
176 void PreviousFactorTable<GM>::PrintTestData(std::ostream& fout)
177 {
178  fout << "Forward factors:"<<std::endl;
179  for (size_t varId=0;varId<_forwardFactors.size();++varId)
180  {
181  fout << "varId="<<varId<<", ";
182  for (size_t i=0;i<_forwardFactors[varId].size();++i)
183  _forwardFactors[varId][i].print(fout);
184  fout <<std::endl;
185  }
186 
187  fout << "Backward factors:"<<std::endl;
188  for (size_t varId=0;varId<_backwardFactors.size();++varId)
189  {
190  fout << "varId="<<varId<<", ";
191  for (size_t i=0;i<_backwardFactors[varId].size();++i)
192  _backwardFactors[varId][i].print(fout);
193  fout <<std::endl;
194  }
195 }
196 #endif
197 
198 template <class SubSolver>
200 {
201 public:
202  typedef typename SubSolver::GMType GM;//TODO: remove me
203  typedef GM GraphicalModelType;
204  typedef typename SubSolver::ACCType ACC;//TODO: remove me
205  typedef ACC AccumulationType;
206  typedef SubSolver SubSolverType;
208  //typedef visitors::ExplicitEmptyVisitor< TRWSPrototype<SubSolverType> > EmptyVisitorParent;
211 
212  typedef typename SubSolver::const_iterators_pair const_marginals_iterators_pair;
213  typedef typename GM::ValueType ValueType;
214  typedef typename GM::IndexType IndexType;
215  typedef typename GM::LabelType LabelType;
217  typedef typename std::vector<ValueType> OutputContainerType;
218  typedef typename OutputContainerType::iterator OutputIteratorType;//TODO: make a template parameter
219 
221 
224  typedef typename Storage::UnaryFactor UnaryFactor;
225 
226  TRWSPrototype(Storage& storage,const Parameters& params
227 #ifdef TRWS_DEBUG_OUTPUT
228  ,std::ostream& fout=std::cout
229 #endif
230  );
231  virtual ~TRWSPrototype();
232 
233  virtual ValueType GetBestIntegerBound()const{return _bestIntegerBound;};
234  virtual ValueType value()const{return _bestIntegerBound;}
235  virtual ValueType bound()const{return _dualBound;}
236  virtual const std::vector<LabelType>& arg()const{return _bestIntegerLabeling;}
237 
238 #ifdef TRWS_DEBUG_OUTPUT
239  virtual void PrintTestData(std::ostream& fout)const;
240 #endif
241 
242  bool CheckDualityGap(ValueType primalBound,ValueType dualBound);
243  virtual std::pair<ValueType,ValueType> GetMarginals(IndexType variable, OutputIteratorType begin){return std::make_pair((ValueType)0,(ValueType)0);};
244 
245  /*
246  * returns marginals of a subsolver for a given variable
247  * Index of the variable is local - for the given subsolver
248  */
249 
250  void GetMarginalsMove();
251  void BackwardMove();//optimization move, also estimates a primal bound
252 
253  ValueType getBound(size_t i)const{return _subSolvers[i]->GetObjectiveValue();}
254  virtual InferenceTermination infer(){EmptyVisitorParent vis; EmptyVisitorType visitor(&vis,this); return infer(visitor);};
255  template<class VISITOR> InferenceTermination infer(VISITOR&);
256  void ForwardMove();
259  }
260 
261  ValueType lastDualUpdate()const{return _lastDualUpdate;}
262 
263  template<class VISITOR> InferenceTermination infer_visitor_updates(VISITOR& visitor, size_t* pinterCounter=0);
264  InferenceTermination core_infer(size_t* piterCounter=0){EmptyVisitorParent vis; EmptyVisitorType visitor(&vis,this); return _core_infer(visitor,piterCounter);};
265  const FactorProperties& getFactorProperties()const{return _factorProperties;}
266 
267  /*
268  * typedef TRWS_Reparametrizer<Storage,ACC> ReparametrizerType;
269  */
270 // template<class ReparametrizerType>
271 // ReparametrizerType * getReparametrizer(const typename ReparametrizerType::Parameter& params=typename ReparametrizerType::Parameter())const
272 // {return new ReparametrizerType(_storage,_factorProperties,params);}
273 
274 
275 protected:
277  template <class VISITOR> InferenceTermination _core_infer(VISITOR& visitor, size_t* piterCounter=0);
279  virtual void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)=0;
280  virtual void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver)=0;
281  void _EvaluateIntegerBounds();
282 
283  /*
284  * Integer labeling computation functions
285  */
286  virtual void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)=0;
287  void _EstimateIntegerLabel(IndexType varId,const std::vector<ValueType>& sumMarginal)
288  {_integerLabeling[varId]=std::max_element(sumMarginal.begin(),sumMarginal.end(),ACC::template ibop<ValueType>)-sumMarginal.begin();}
289 
290  void _InitSubSolvers();
291  void _ForwardMove();
292  void _FinalizeMove();
293  ValueType _GetObjectiveValue();
294  IndexType _order(IndexType i);
295  IndexType _core_order(IndexType i,IndexType totalSize);
296  virtual bool _CheckConvergence(ValueType relativeThreshold);
297 
298  virtual bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
299  virtual void _EstimateTRWSBound(){};
300 
301  virtual void _InitMove()=0;
302 
303  Storage& _storage;
304  FactorProperties _factorProperties;
306  Parameters _parameters;
307 
308 #ifdef TRWS_DEBUG_OUTPUT
309  std::ostream& _fout;
310 #endif
311 
312  ValueType _dualBound;
313  ValueType _oldDualBound;
314  ValueType _lastDualUpdate;
315 
317  std::vector<SubSolver*> _subSolvers;
318 
319  std::vector<std::vector<ValueType> > _marginals;
320 
321  ValueType _integerBound;
322  ValueType _bestIntegerBound;
323 
324  std::vector<LabelType> _integerLabeling;
325  std::vector<LabelType> _bestIntegerLabeling;
326 
327  /* Computation optimization */
328  std::vector<ValueType> _sumMarginal;
329  mutable typename FactorProperties::ParameterStorageType _factorParameters;
330 
331 private:
333  TRWSPrototype& operator =(TRWSPrototype&);
334 };
335 
336 template<class ValueType>
338 {
340  ValueType smoothingValue_;
341  SumProdTRWS_Parameters(size_t maxIternum,
342  ValueType smValue,
343  ValueType precision=1.0,
344  bool absolutePrecision=true,
345  ValueType minRelativeDualImprovement=2*std::numeric_limits<ValueType>::epsilon(),
346  bool fastComputations=true,
347  bool canonicalNormalization=false)
348  :parent(maxIternum,precision,absolutePrecision,minRelativeDualImprovement,fastComputations,canonicalNormalization),
349  smoothingValue_(smValue){};
350 };
351 
352 template<class GM,class ACC>
353 class SumProdTRWS : public TRWSPrototype<SumProdSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> >
354 {
355 public:
361  typedef typename parent::ValueType ValueType;
362  typedef typename parent::IndexType IndexType;
363  typedef typename parent::LabelType LabelType;
368  typedef typename OutputContainerType::iterator OutputIteratorType;
369 
371 
372  SumProdTRWS(Storage& storage,const Parameters& params
373 #ifdef TRWS_DEBUG_OUTPUT
374  ,std::ostream& fout=std::cout
375 #endif
376  ):
377  parent(storage,params
378 #ifdef TRWS_DEBUG_OUTPUT
379  ,fout
380 #endif
381  ),
382  _bDualConverged(false),
383  _smoothingValue(params.smoothingValue_)
384  {};
386 
387 #ifdef TRWS_DEBUG_OUTPUT
388  void PrintTestData(std::ostream& fout)const;
389 #endif
390 
391  void SetSmoothing(ValueType smoothingValue){_bDualConverged=false; _smoothingValue=smoothingValue;_InitMove();}
392  ValueType GetSmoothing()const{return _smoothingValue;}
393  /*
394  * returns "averaged" over subsolvers marginals
395  * and pair of (ell_2 norm,ell_infty norm)
396  */
397  bool ConvergenceFlag()const{return _bDualConverged;}
398  std::pair<ValueType,ValueType> GetMarginals(IndexType variable, OutputIteratorType begin);
399  ValueType GetMarginalsAndDerivativeMove();
400  ValueType getDerivative(size_t i)const{return parent::_subSolvers[i]->getDerivative();}
401 
402 // template<class ITERATOR>
403 // void GetMarginalsForSubModel(IndexType modelId,IndexType localVarId,ITERATOR begin)
404 // { OPENGM_ASSERT(modelId < parent::_subSolvers.size());
405 // const_marginals_iterators_pair it=parent::_subSolvers[modelId]->GetMarginals(localVarId);
406 // ITERATOR end=begin+(it.second-it.first);
407 // std::copy(it.first,it.second,begin);
408 // _normalizeMarginals(begin,end,parent::_subSolvers[modelId]);
409 // ValueType mul; ACC::op(1.0,-1.0,mul);
410 // transform_inplace(begin,end,mulAndExp<ValueType>(mul));
411 // }
412 
413  template<class ITERATOR>
414  void GetMarginalsForSubModel(IndexType modelId,IndexType localVarId,ITERATOR begin)
415  { OPENGM_ASSERT(modelId < parent::_subSolvers.size());
416  const_marginals_iterators_pair it=parent::_subSolvers[modelId]->GetMarginals(localVarId);
417  ITERATOR end=begin+(it.second-it.first);
418  std::copy(it.first,it.second,begin);
419  ValueType mul; ACC::op(1.0,-1.0,mul);
420  _MaxNormalize_inplace(begin,end,(ValueType)0.0,ACC::template ibop<ValueType>);
421  transform_inplace(begin,end,mulAndExp<ValueType>(mul));
422  _MulNormalize(begin,end,(ValueType)0);
423  }
424 
425 protected:
426  void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair);
427  void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end);
428  void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver);
429  void _InitMove();
430  bool _CheckConvergence(ValueType relativeThreshold){return _bDualConverged=parent::_CheckConvergence(relativeThreshold);};
431  bool _bDualConverged;
432  //bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
433  ValueType _smoothingValue;
434 };
435 
436 //typedef TRWSPrototype_Parameters<ValueType> MaxSumTRWS_Parameters;
437 
438 template<class ValueType>
440 {
442 
443  MaxSumTRWS_Parameters(size_t maxIternum,
444  ValueType precision=1.0,
445  bool absolutePrecision=true,
446  ValueType minRelativeDualImprovement=-1.0,
447  bool fastComputations=true,
448  bool canonicalNormalization=false,
449  size_t treeAgreeMaxStableIter=0):
450  parent(maxIternum,precision,absolutePrecision,minRelativeDualImprovement,fastComputations,canonicalNormalization),
451  treeAgreeMaxStableIter_(treeAgreeMaxStableIter)
452  {
453 // if (treeAgreeMaxStableIter_==0)
454 // treeAgreeMaxStableIter_=maxIternum;
455 
456  };
457 
458  size_t treeAgreeMaxStableIter()const{return (treeAgreeMaxStableIter_==0 ? parent::maxNumberOfIterations_ : treeAgreeMaxStableIter_);}
459  void setTreeAgreeMaxStableIter(size_t val){treeAgreeMaxStableIter_=val;}
460 
461  private:
462  size_t treeAgreeMaxStableIter_;
463 };
464 
465 template<class GM,class ACC>
466 class MaxSumTRWS : public TRWSPrototype<MaxSumSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> >
467 {
468 public:
470  //typedef typename parent::Parameters Parameters;
473  typedef typename parent::ValueType ValueType;
474  typedef typename parent::IndexType IndexType;
475  typedef typename parent::LabelType LabelType;
482  // typedef typename parent::ReparametrizerType ReparametrizerType;
483 
486 
488 
489  MaxSumTRWS(Storage& storage,const Parameters& params
490 #ifdef TRWS_DEBUG_OUTPUT
491  ,std::ostream& fout=std::cout
492 #endif
493  ):
494  parent(storage,params
495 #ifdef TRWS_DEBUG_OUTPUT
496  ,fout
497 #endif
498  ),
499  _parameters(params),
500  _pseudoBoundValue(0.0),
502  _agree_count(0),
504  {}
506 
507  void getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling=0,std::vector<std::vector<LabelType> >* ptreeLabelings=0);
508  bool CheckTreeAgreement(InferenceTermination* pterminationCode);
509 protected:
510  void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair);
511  void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end);
512  void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver);
513  void _InitMove();
514  void _EstimateTRWSBound();
515  bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
516 
517  Parameters _parameters;
518 
519  ValueType _pseudoBoundValue;
521  /*
522  * computation optimization
523  */
524  std::vector<bool> _treeAgree;
525  std::vector<bool> _mask;
526  std::vector<bool> _nodeMask;
527 
528  size_t _agree_count;
530 };
531 //============ TRWSPrototype IMPLEMENTATION ======================================
532 
533 template <class SubSolver>
535 #ifdef TRWS_DEBUG_OUTPUT
536  ,std::ostream& fout
537 #endif
538 ):
539 _storage(storage),
540 _factorProperties(storage.masterModel()),
541 _ftable(storage.masterModel()),
542 _parameters(params),
543 #ifdef TRWS_DEBUG_OUTPUT
544 _fout(fout),
545 #endif
546 _dualBound(ACC::template ineutral<ValueType>()),
547 _oldDualBound(ACC::template ineutral<ValueType>()),
548 _lastDualUpdate(0),
549 _moveDirection(SubModel::Direct),
550 _subSolvers(),
551 _marginals(),
552 _integerBound(ACC::template neutral<ValueType>()),
553 _bestIntegerBound(ACC::template neutral<ValueType>()),
554 _integerLabeling(storage.masterModel().numberOfVariables(),0),
555 _bestIntegerLabeling(storage.masterModel().numberOfVariables(),0),
556 _sumMarginal()
557 {
558 #ifdef TRWS_DEBUG_OUTPUT
559  _fout.precision(16);
560 #endif
561  _InitSubSolvers();
562  _marginals.resize(_storage.numberOfModels());
563 #ifdef TRWS_DEBUG_OUTPUT
564  _factorProperties.PrintStatusData(fout);
565 #endif
566 }
567 
568 template <class SubSolver>
570 {
571  for_each(_subSolvers.begin(),_subSolvers.end(),DeallocatePointer<SubSolver>);
572 };
573 
574 template <class SubSolver>
576 {
577  _subSolvers.resize(_storage.numberOfModels());
578  for (size_t modelId=0;modelId<_subSolvers.size();++modelId)
579  _subSolvers[modelId]= new SubSolver(_storage.subModel(modelId),_factorProperties,_parameters.fastComputations_);
580 }
581 
582 template <class SubSolver>
584 {
585  OPENGM_ASSERT((ACC::bop(-1,1) ? 1 : -1 )*(primalBound-dualBound) > -dualBound*std::numeric_limits<ValueType>::epsilon());
586 
587 // _fout << "(ACC::bop(-1,1) ? 1 : -1 )*(primalBound-dualBound)=" << (ACC::bop(-1,1) ? 1 : -1 )*(primalBound-dualBound)
588 // << ", -dualBound*std::numeric_limits<ValueType>::epsilon()=" << -dualBound*std::numeric_limits<ValueType>::epsilon()<<std::endl;
589 
590  ValueType endPrecision=std::max((ValueType)fabs(dualBound)*std::numeric_limits<ValueType>::epsilon(),_parameters.precision_);
591 
592 // _fout << "endPrecision="<<endPrecision<<", std::numeric_limits<ValueType>::epsilon()="
593 // <<std::numeric_limits<ValueType>::epsilon() <<", _parameters.precision_="<<_parameters.precision_<<std::endl;
594 
595  if (_parameters.absolutePrecision_)
596  {
597  if (fabs(primalBound-dualBound) <= endPrecision)
598  {
599  return true;
600  }
601  }
602  else
603  {
604  if (fabs((primalBound-dualBound))<= fabs(dualBound)*endPrecision )
605  return true;
606  }
607  return false;
608 }
609 
610 //template <class SubSolver>
611 //bool TRWSPrototype<SubSolver>::CheckDualityGap(ValueType primalBound,ValueType dualBound)
612 //{
613 // //TODO: check that primal bound > dualBound if (bop(primalBound,dualBound)
614 //
615 // OPENGM_ASSERT((ACC::bop(-1,1) ? 1 : -1 )*(primalBound-dualBound) > -dualBound*std::numeric_limits<ValueType>::epsilon());
616 //
617 //
618 // if (_parameters.absolutePrecision_)
619 // {
620 // if (fabs(primalBound-dualBound) <= _parameters.precision_)
621 // {
622 // return true;
623 // }
624 // }
625 // else
626 // {
627 // if (fabs((primalBound-dualBound)/dualBound)<= _parameters.precision_)
628 // return true;
629 // }
630 // return false;
631 //}
632 
633 
634 template <class SubSolver>
636 {
637  if (relativeThreshold >=0.0)
638  {
639  ValueType mul; ACC::iop(-1.0,1.0,mul);
640  if (ACC::bop(_dualBound, (_oldDualBound + static_cast<ValueType>(fabs(_dualBound))*mul*relativeThreshold)))
641  return true;
642  }
643  return false;
644 }
645 
646 template <class SubSolver>
648 {
649  _lastDualUpdate=fabs(_dualBound-_oldDualBound);
650 
651  if (CheckDualityGap(_bestIntegerBound,_dualBound))
652  {
653 #ifdef TRWS_DEBUG_OUTPUT
654  _fout << "TRWSPrototype::_CheckStoppingCondition(): duality gap <= specified precision!" <<std::endl;
655 #endif
656  *pterminationCode=opengm::CONVERGENCE;
657  return true;
658  }
659 
660  if (_CheckConvergence(_parameters.minRelativeDualImprovement_))
661  {
662 #ifdef TRWS_DEBUG_OUTPUT
663  _fout << "TRWSPrototype::_CheckStoppingCondition(): Dual update is smaller than the specified threshold. Stopping"<<std::endl;
664 #endif
665  *pterminationCode=opengm::NORMAL;
666  return true;
667  }
668 
669  _oldDualBound=_dualBound;
670 
671  return false;
672 }
673 
674 template <class SubSolver>
675 template <class VISITOR>
677 {
678  for (size_t iterationCounter=0;iterationCounter<_parameters.maxNumberOfIterations_;++iterationCounter)
679  {
680 #ifdef TRWS_DEBUG_OUTPUT
681  _fout <<"Iteration Nr."<<iterationCounter<<"-------------------------------------"<<std::endl;
682 #endif
683 
684  BackwardMove();
685 
686 #ifdef TRWS_DEBUG_OUTPUT
687  _fout << "dualBound=" << _dualBound <<", primalBound="<<_GetPrimalBound() <<std::endl;
688 #endif
689  _EstimateTRWSBound();
690  const size_t visitorReturn = visitor();
691 
692  if (piterCounter!=0) *piterCounter=iterationCounter+1;
693 
694  InferenceTermination returncode;
695  if (_CheckStoppingCondition(&returncode))
696  return returncode;
697 
698  if( visitorReturn != visitors::VisitorReturnFlag::ContinueInf ){
700  return opengm::CONVERGENCE;
701  } else {
702  return opengm::TIMEOUT;
703  }
704  }
705  }
706  return opengm::TIMEOUT;
707 }
708 
709 template <class SubSolver>
711 {
712  ValueType dualBound=0;
713  for (size_t i=0;i<_subSolvers.size();++i)
714  dualBound+=_subSolvers[i]->GetObjectiveValue();
715 
716  return dualBound;
717 }
718 
719 template <class SubSolver>
721 {
722  std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::Move));
723  _moveDirection=SubModel::ReverseDirection(_moveDirection);
724  _dualBound=_GetObjectiveValue();
725 }
726 
727 template <class SubSolver>
729 {
730  std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::MoveBack));
731  _moveDirection=SubModel::ReverseDirection(_moveDirection);
732 }
733 
734 template <class SubSolver>
736 {
737  return (_moveDirection==SubModel::Direct ? i : totalSize-i-1);
738 }
739 
740 template <class SubSolver>
742 {
743  return _core_order(i,_storage.numberOfSharedVariables());
744 }
745 
746 template <class SubSolver>
748 {
749  std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::FinalizeMove));
750  _moveDirection=SubModel::ReverseDirection(_moveDirection);
751  _EstimateIntegerLabeling();
752 }
753 
754 #ifdef TRWS_DEBUG_OUTPUT
755 template <class SubSolver>
756 void TRWSPrototype<SubSolver>::PrintTestData(std::ostream& fout)const
757 {
758  fout << "_dualBound:" << _dualBound <<std::endl;
759  fout << "_oldDualBound:" << _oldDualBound <<std::endl;
760  fout << "_lastDualUpdate=" << _lastDualUpdate << std::endl;
761  fout << "_moveDirection:" << _moveDirection <<std::endl;
762  fout << "_integerBound=" << _integerBound << std::endl;
763  fout << "_bestIntegerBound=" << _bestIntegerBound << std::endl;
764  fout << "_integerLabeling=" << _integerLabeling;
765  fout << "_bestIntegerLabeling=" << _bestIntegerLabeling;
766 }
767 #endif
768 
769 template <class SubSolver>
770 template <class VISITOR>
772 {
773  visitor.begin();
774  InferenceTermination returncode=infer_visitor_updates(visitor);
775  visitor.end();
776  return returncode;
777 }
778 
779 template <class SubSolver>
780 template <class VISITOR>
782 {
783  _InitMove();
784  _ForwardMove();
785  _oldDualBound=_dualBound;
786 #ifdef TRWS_DEBUG_OUTPUT
787  _fout << "ForwardMove: dualBound=" << _dualBound <<std::endl;
788 #endif
789 
790  const size_t visitorReturn = visitor();
791  if( visitorReturn != visitors::VisitorReturnFlag::ContinueInf ){
793  return opengm::CONVERGENCE;
794  } else {
795  return opengm::TIMEOUT;
796  }
797  }
798 
799  InferenceTermination returncode;
800  returncode=_core_infer(visitor,piterCounter);
801  if (piterCounter!=0) ++(*piterCounter);
802  return returncode;
803 }
804 
805 template <class SubSolver>
807 {
808  _InitMove();
809  _ForwardMove();
810  _dualBound=_GetObjectiveValue();
811 }
812 
813 
814 template <class SubSolver>
816 {
817  std::vector<ValueType> averageMarginal;
818 
819  for (IndexType i=0;i<_storage.numberOfSharedVariables();++i)
820  {
821  IndexType varId=_order(i);
822  const typename Storage::SubVariableListType& varList=_storage.getSubVariableList(varId);
823  averageMarginal.assign(_storage.numberOfLabels(varId),0.0);
824 
825  //<!computing average marginals
826  for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
827  {
828  SubSolver& subSolver=*_subSolvers[modelIt->subModelId_];
829  std::vector<ValueType>& marginals=_marginals[modelIt->subModelId_];
830  marginals.resize(_storage.numberOfLabels(varId));
831 
832  IndexType startNodeIndex=_core_order(0,_storage.size(modelIt->subModelId_));
833 
834  if (modelIt->subVariableId_!=startNodeIndex)
835  subSolver.PushBack();
836 
837  typename SubSolver::const_iterators_pair marginalsit=subSolver.GetMarginals();
838 
839  std::copy(marginalsit.first,marginalsit.second,marginals.begin());
840  if (_parameters.canonicalNormalization_)
841  _normalizeMarginals(marginals.begin(),marginals.end(),&subSolver);
842  std::transform(marginals.begin(),marginals.end(),averageMarginal.begin(),averageMarginal.begin(),std::plus<ValueType>());
843  }
844  transform_inplace(averageMarginal.begin(),averageMarginal.end(),std::bind1st(std::multiplies<ValueType>(),-1.0/varList.size()));
845 
846 
847  //<!reweighting submodels
848 
849  for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
850  {
851  SubSolver& subSolver=*_subSolvers[modelIt->subModelId_];
852  std::vector<ValueType>& marginals=_marginals[modelIt->subModelId_];
853 
854  std::transform(marginals.begin(),marginals.end(),averageMarginal.begin(),marginals.begin(),std::plus<ValueType>());
855 
856  _postprocessMarginals(marginals.begin(),marginals.end());
857 
858  subSolver.IncreaseUnaryWeights(marginals.begin(),marginals.end());
859 
860  IndexType startNodeIndex=_core_order(0,_storage.size(modelIt->subModelId_));
861 
862  if (modelIt->subVariableId_!=startNodeIndex)
863  subSolver.UpdateMarginals();
864  else subSolver.InitReverseMove();
865  }
866  }
867 
868  _FinalizeMove();
869  _EvaluateIntegerBounds();
870  _dualBound=_GetObjectiveValue();
871 }
872 
873 template <class SubSolver>
875 {
876  for (IndexType i=0;i<_storage.numberOfSharedVariables();++i)
877  {
878  IndexType varId=_order(i);
879 
880  const typename Storage::SubVariableListType& varList=_storage.getSubVariableList(varId);
881  _sumMarginal.assign(_storage.masterModel().numberOfLabels(varId),0.0);
882  for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
883  {
884  const_marginals_iterators_pair itpair=_subSolvers[modelIt->subModelId_]->GetMarginals(modelIt->subVariableId_);
885  _SumUpForwardMarginals(&_sumMarginal,itpair);
886  }
887 
888  typename PreviousFactorTable<GM>::const_iterator begin=_ftable.begin(varId,_moveDirection);
889  typename PreviousFactorTable<GM>::const_iterator end=_ftable.end(varId,_moveDirection);
890  for (;begin!=end;++begin)
891  {
892  LabelType fixedLabel=_integerLabeling[begin->varId];
893  if ((_factorProperties.getFunctionType(begin->factorId)==FunctionParameters<GM>::POTTS) && _parameters.fastComputations_)
894  {
895  if (_sumMarginal.size() > fixedLabel)
896  _sumMarginal[_integerLabeling[begin->varId]]-=_factorProperties.getFunctionParameters(begin->factorId)[0];//instead of adding everywhere the same we just subtract the difference
897  }else
898  {
899  const typename GM::FactorType& pwfactor=_storage.masterModel()[begin->factorId];
900  IndexType localVarIndx = begin->localId;
901  //LabelType fixedLabel=_integerLabeling[begin->varId];
902  opengm::ViewFixVariablesFunction<GM> pencil(pwfactor,
903  std::vector<opengm::PositionAndLabel<IndexType,LabelType> >(1,
904  opengm::PositionAndLabel<IndexType,LabelType>(localVarIndx,
905  fixedLabel)));
906 
907  for (LabelType j=0;j<_sumMarginal.size();++j)
908  _sumMarginal[j]+=pencil(&j);
909  }
910  }
911  _EstimateIntegerLabel(varId,_sumMarginal);
912  }
913 }
914 
915 template <class SubSolver>
917 {
918  _integerBound=_storage.masterModel().evaluate(_integerLabeling.begin());
919  if (ACC::bop(_integerBound,_bestIntegerBound))
920  {
921  _bestIntegerLabeling=_integerLabeling;
922  _bestIntegerBound=_integerBound;
923  }
924 }
925 
926 //================================= DecompositionStorage IMPLEMENTATION =================================================
927 template<class GM>
928 DecompositionStorage<GM>::DecompositionStorage(const GM& gm,StructureType structureType, const DDVectorType* pddvector):
929 _gm(gm),
930 _structureType(structureType),
931 _subModels(),
932 _variableDecomposition(),
933 _var2FactorMap(gm)
934 {
935  _InitSubModels(pddvector);
936 }
937 
938 template<class GM>
940 {
941  for_each(_subModels.begin(),_subModels.end(),DeallocatePointer<SubModel>);
942 }
943 
944 template<class GM>
945 void DecompositionStorage<GM>::_InitSubModels(const DDVectorType* pddvector)
946 {
947  std::auto_ptr<Decomposition<GM> > pdecomposition;
948 
949  switch (_structureType)
950  {
951  case GRIDSTRUCTURE:
952  {
953  pdecomposition=std::auto_ptr<Decomposition<GM> >(new GridDecomposition<GM>(_gm));
954  break;
955  }
956  case EDGESTRUCTURE:
957  {
958  pdecomposition=std::auto_ptr<Decomposition<GM> >(new EdgeDecomposition<GM>(_gm));
959  break;
960  }
961  case GENERALSTRUCTURE:
962  {
963  pdecomposition=std::auto_ptr<Decomposition<GM> >(new MonotoneChainsDecomposition<GM>(_gm));
964  break;
965  }
966  default:
967  throw std::runtime_error("DecompositionStorage::_InitSubModels: Unknown decomposition type!");
968  }
969 // if (_structureType==GRIDSTRUCTURE)
970 // pdecomposition=std::auto_ptr<Decomposition<GM> >(new GridDecomposition<GM>(_gm));
971 // else (_structureType==EDGESTRUCTURE)
972 // pdecomposition=std::auto_ptr<Decomposition<GM> >(new EdgeDecomposition<GM>(_gm));
973 // else
974 // pdecomposition=std::auto_ptr<Decomposition<GM> >(new MonotoneChainsDecomposition<GM>(_gm));
975 
976  try{
977  pdecomposition->ComputeVariableDecomposition(&_variableDecomposition);
978  size_t numberOfModels=pdecomposition->getNumberOfSubModels();
979  _subModels.resize(numberOfModels);
980  for (size_t modelId=0;modelId<numberOfModels;++modelId)
981  {
982  const typename SubModel::IndexList& varList=pdecomposition->getVariableList(modelId);
983  typename SubModel::IndexList numOfSubModelsPerVar(varList.size());
984  // Initialize numOfSubModelsPerVar
985  for (size_t varIndx=0;varIndx<varList.size();++varIndx)
986  numOfSubModelsPerVar[varIndx]=_variableDecomposition[varList[varIndx]].size();
987 
988  _subModels[modelId]= new SubModel(_gm,_var2FactorMap,varList,pdecomposition->getFactorList(modelId),numOfSubModelsPerVar);
989  };
990 
991  if (pddvector!=0)
992  addDDvector(*pddvector);
993 
994  }catch(std::runtime_error& err)
995  {
996  throw err;
997  }
998 };
999 
1000 
1001 template<class GM>
1003 {
1004  if (delta.size()!=getDDVectorSize())
1005  throw std::runtime_error("DecompositionStorage<GM>::addDDvector(): Error: size of the input vector does not match the size of the graphical model.");
1006 
1007  typename DDVectorType::const_iterator deltaIt=delta.begin();
1008  for (IndexType varId=0;varId<masterModel().numberOfVariables();++varId)// all variables
1009  { const SubVariableListType& varList=getSubVariableList(varId);
1010 
1011  if (varList.size()==1) continue;
1012  typename SubVariableListType::const_iterator modelIt=varList.begin();
1013  IndexType firstModelId=modelIt->subModelId_;
1014  IndexType firstModelVariableId=modelIt->subVariableId_;
1015  ++modelIt;
1016  for(;modelIt!=varList.end();++modelIt) //all related models
1017  {
1018  std::transform(subModel(modelIt->subModelId_).ufBegin(modelIt->subVariableId_),
1019  subModel(modelIt->subModelId_).ufEnd(modelIt->subVariableId_),
1020  deltaIt,subModel(modelIt->subModelId_).ufBegin(modelIt->subVariableId_),
1021  std::plus<ValueType>());
1022 
1023  std::transform(subModel(firstModelId).ufBegin(firstModelVariableId),
1024  subModel(firstModelId).ufEnd(firstModelVariableId),
1025  deltaIt,subModel(firstModelId).ufBegin(firstModelVariableId),
1026  std::minus<ValueType>());
1027  deltaIt+=masterModel().numberOfLabels(varId);
1028  }
1029  }
1030 }
1031 
1032 template<class GM>
1034 {
1035  pddvector->resize(getDDVectorSize());
1036  typename DDVectorType::iterator gradientIt=pddvector->begin();
1037  UnaryFactor uf;
1038  for (IndexType varId=0;varId<_gm.numberOfVariables();++varId)// all variables
1039  {
1040  const SubVariableListType& varList=getSubVariableList(varId);
1041 
1042  if (varList.size()==1) continue;
1043  typename SubVariableListType::const_iterator modelIt=varList.begin();
1044  uf.resize(_gm.numberOfLabels(varId));
1045  _gm[_var2FactorMap(varId)].copyValues(uf.begin());
1046  transform_inplace(uf.begin(),uf.end(),std::bind2nd(std::multiplies<ValueType>(),1.0/varList.size()));
1047  ++modelIt;
1048  for(;modelIt!=varList.end();++modelIt) //all related models
1049  {
1050  const std::vector<ValueType>& buffer=subModel(modelIt->subModelId_).unaryFactors(modelIt->subVariableId_);
1051  gradientIt=std::transform(buffer.begin(),buffer.end(),uf.begin(),gradientIt,std::minus<ValueType>());
1052  }
1053  }
1054 }
1055 
1056 
1057 template<class GM>
1059 {
1060  size_t varsize=0;
1061  for (IndexType varId=0;varId<_gm.numberOfVariables();++varId)// all variables
1062  varsize+=(getSubVariableList(varId).size()-1)*_gm.numberOfLabels(varId);
1063  return varsize;
1064 }
1065 
1066 #ifdef TRWS_DEBUG_OUTPUT
1067 template<class GM>
1068 void DecompositionStorage<GM>::PrintTestData(std::ostream& fout)const
1069 {
1070  fout << "_variableDecomposition: "<<std::endl;
1071  for (size_t variableId=0;variableId<_variableDecomposition.size();++variableId)
1072  {
1073  std::for_each(_variableDecomposition[variableId].begin(),_variableDecomposition[variableId].end(),printSubVariable<typename MonotoneChainsDecomposition<GM>::SubVariable>(fout));
1074  fout << std::endl;
1075  }
1076 }
1077 
1078 template<class GM>
1079 void DecompositionStorage<GM>::PrintVariableDecompositionConsistency(std::ostream& fout)const
1080 {
1081  fout << "Variable decomposition consistency:" <<std::endl;
1082  for (size_t varId=0;varId<_gm.numberOfVariables();++varId)
1083  {
1084  fout << varId<<": ";
1085  const SubVariableListType& varList=_variableDecomposition[varId];
1086  typename SubVariableListType::const_iterator modelIt=varList.begin();
1087  std::vector<ValueType> sum(_gm.numberOfLabels(varId),0.0);
1088  while (modelIt!=varList.end())
1089  {
1090  const SubModel& subModel=*_subModels[modelIt->subModelId_];
1091  std::transform(subModel.unaryFactors(modelIt->subVariableId_).begin(),subModel.unaryFactors(modelIt->subVariableId_).end(),
1092  sum.begin(),sum.begin(),std::plus<ValueType>());
1093  ++modelIt;
1094  }
1095  std::vector<ValueType> originalFactor(_gm.numberOfLabels(varId),0.0);
1096  _gm[varId].copyValues(originalFactor.begin());
1097 
1098  std::transform(sum.begin(),sum.end(),originalFactor.begin(),sum.begin(),std::minus<ValueType>());
1099  fout << std::accumulate(sum.begin(),sum.end(),(ValueType)0.0)<<std::endl;
1100  }
1101 
1102 }
1103 #endif
1104 //================================= MaxSumTRWS IMPLEMENTATION =================================================
1105 
1106 template<class GM,class ACC>
1108 {
1109  parent::_moveDirection=SubModel::Direct;
1110  std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::mem_fun_t<void,SubSolver>(&SubSolver::InitMove));
1111 }
1112 
1113 template<class GM,class ACC>
1114 void MaxSumTRWS<GM,ACC>::_postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)
1115 {
1116  transform_inplace(begin,end,std::bind1st(std::multiplies<ValueType>(),-1.0));
1117 }
1118 
1119 template<class GM,class ACC>
1121 {
1122  std::transform(itpair.first,itpair.second,pout->begin(),pout->begin(),std::plus<ValueType>());
1123 }
1124 
1125 template<class GM,class ACC>
1127 {
1128  if (parent::_parameters.canonicalNormalization_) return;
1129  std::vector<ValueType> bounds(parent::_storage.numberOfModels());
1130  for (size_t i=0;i<bounds.size();++i)
1131  bounds[i]=parent::_subSolvers[i]->GetObjectiveValue();
1132 
1133  ValueType min=*std::min_element(bounds.begin(),bounds.end());
1134  ValueType max=*std::max_element(bounds.begin(),bounds.end());
1135  ValueType eps; ACC::iop(max-min,min-max,eps);
1136  ACC::iop(min,max,_pseudoBoundValue);
1137 #ifdef TRWS_DEBUG_OUTPUT
1138  parent::_fout <<"min="<<min<<", max="<<max<<", eps="<<eps<<", pseudo bound="<<bounds.size()*_pseudoBoundValue<<std::endl;
1139 #endif
1140 }
1141 
1142 
1143 template<class GM,class ACC>
1144 void MaxSumTRWS<GM,ACC>::_normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver)
1145 {
1146  //if (!parent::_parameters.canonicalNormalization_) return;
1147  ValueType maxVal=*std::max_element(begin,end,ACC::template bop<ValueType>);
1148  transform_inplace(begin,end,std::bind2nd(std::plus<ValueType>(),-maxVal));
1149 }
1150 
1151 template<class GM,class ACC>
1152 void MaxSumTRWS<GM,ACC>::getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling,std::vector<std::vector<LabelType> >* ptreeLabelings)
1153 {
1154  if (plabeling!=0)
1155  plabeling->resize(parent::_storage.masterModel().numberOfVariables());
1156  if (ptreeLabelings!=0)
1157  ptreeLabelings->assign(parent::_storage.masterModel().numberOfVariables(),std::vector<LabelType>());
1158 
1159  out.assign(parent::_storage.masterModel().numberOfVariables(),true);
1160  for (size_t varId=0;varId<parent::_storage.masterModel().numberOfVariables();++varId)
1161  {
1162  const typename Storage::SubVariableListType& varList=parent::_storage.getSubVariableList(varId);
1163  size_t label=0;
1164  for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin()
1165  ;modelIt!=varList.end();++modelIt)
1166  {
1167  size_t check_label=parent::_subSolvers[modelIt->subModelId_]->arg()[modelIt->subVariableId_];
1168 
1169  if (plabeling!=0) (*plabeling)[varId]=check_label;
1170  if (ptreeLabelings!=0) (*ptreeLabelings)[varId].push_back(check_label);
1171 
1172  if (modelIt==varList.begin())
1173  {
1174  label=check_label;
1175  }else if (check_label!=label)
1176  {
1177  out[varId]=false;
1178  break;
1179  }
1180  }
1181 
1182  }
1183 }
1184 
1185 
1186 template<class GM,class ACC>
1188 {
1189  getTreeAgreement(_treeAgree);
1190  size_t agree_count=count(_treeAgree.begin(),_treeAgree.end(),true);
1191  if (agree_count > _agree_count)
1192  {
1193  _treeAgree_iterationCounter=0;
1194  _agree_count=agree_count;
1195  }
1196  else
1197  ++_treeAgree_iterationCounter;
1198 
1199 #ifdef TRWS_DEBUG_OUTPUT
1200  parent::_fout << "tree-agreement: " << agree_count <<" out of "<<_treeAgree.size() <<", ="<<100*(double)agree_count/_treeAgree.size()<<"%"<<std::endl;
1201 #endif
1202 
1203  if (_treeAgree.size()==agree_count)
1204  {
1205 #ifdef TRWS_DEBUG_OUTPUT
1206  parent::_fout <<"Problem solved."<<std::endl;
1207 #endif
1208  *pterminationCode=opengm::CONVERGENCE;
1209  return true;
1210  }else
1211  return false;
1212 }
1213 
1214 template<class GM,class ACC>
1216 {
1217  if (CheckTreeAgreement(pterminationCode)) return true;
1218 
1219  if (_treeAgree_iterationCounter > _parameters.treeAgreeMaxStableIter())
1220  {
1221 #ifdef TRWS_DEBUG_OUTPUT
1222  parent::_fout <<"There were no improvement of tree agreement during last "<<_treeAgree_iterationCounter <<" steps. Aborting."<<std::endl;
1223 #endif
1224  *pterminationCode=NORMAL;
1225  return true;
1226  }
1227 
1228  return parent::_CheckStoppingCondition(pterminationCode);
1229 }
1230 
1231 //================================= SumProdTRWS IMPLEMENTATION =================================================
1232 #ifdef TRWS_DEBUG_OUTPUT
1233 template<class GM,class ACC>
1234 void SumProdTRWS<GM,ACC>::PrintTestData(std::ostream& fout)const
1235 {
1236  fout << "_smoothingValue:"<<_smoothingValue <<std::endl;
1237  parent::PrintTestData(fout);
1238 }
1239 #endif
1240 
1241 template<class GM,class ACC>
1242 void SumProdTRWS<GM,ACC>::_InitMove()//(ValueType smoothingValue)
1243 {
1244  parent::_moveDirection=SubModel::Direct;
1245  std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::bind2nd(std::mem_fun(&SubSolver::InitMove),_smoothingValue));
1246 }
1247 
1248 template<class GM,class ACC>
1249 void SumProdTRWS<GM,ACC>::_normalizeMarginals(typename std::vector<ValueType>::iterator begin,
1250  typename std::vector<ValueType>::iterator end,SubSolver* subSolver)
1251 {
1252  //if (!parent::_parameters.canonicalNormalization_) return;
1253  ValueType logPartition=subSolver->ComputeObjectiveValue();
1254  //normalizing marginals - subtracting log-partition function value/smoothing
1255  transform_inplace(begin,end,std::bind2nd(std::plus<ValueType>(),-logPartition/_smoothingValue));
1256 }
1257 
1258 template<class GM,class ACC>
1259 void SumProdTRWS<GM,ACC>::_postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)
1260 {
1261  transform_inplace(begin,end,std::bind1st(std::multiplies<ValueType>(),-_smoothingValue));
1262 }
1263 
1264 template<class GM,class ACC>
1266 {
1267  std::transform(pout->begin(),pout->end(),itpair.first,pout->begin(),plus2ndMul<ValueType>(_smoothingValue));
1268 }
1269 
1270 template<class GM,class ACC>
1271 std::pair<typename SumProdTRWS<GM,ACC>::ValueType,typename SumProdTRWS<GM,ACC>::ValueType>
1273 {
1274  std::fill_n(begin,parent::_storage.numberOfLabels(varId),0.0);
1275  const typename Storage::SubVariableListType& varList=parent::_storage.getSubVariableList(varId);
1276 
1277  OPENGM_ASSERT(varList.size()>0);
1278 
1279  for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
1280  {
1281  std::vector<ValueType>& normMarginals=parent::_marginals[modelIt->subModelId_];
1282  normMarginals.resize(parent::_storage.numberOfLabels(varId));
1283  GetMarginalsForSubModel(modelIt->subModelId_,modelIt->subVariableId_,normMarginals.begin());
1284  std::transform(normMarginals.begin(),normMarginals.end(),begin,begin,std::plus<ValueType>());
1285  }
1286  transform_inplace(begin,begin+parent::_storage.numberOfLabels(varId),std::bind1st(std::multiplies<ValueType>(),1.0/varList.size()));
1287 
1288  ValueType ell2Norm=0, ellInftyNorm=0;
1289  for (typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
1290  {
1291  std::vector<ValueType>& normMarginals=parent::_marginals[modelIt->subModelId_];
1292  OutputIteratorType begin0=begin;
1293  for (typename std::vector<ValueType>::const_iterator bm=normMarginals.begin(); bm!=normMarginals.end();++bm)
1294  {
1295  //ValueType diff=(*bm-*begin0); ++begin0;
1296  ValueType diff=std::min((*bm-*begin0),*begin0); ++begin0;
1297  ell2Norm+=diff*diff;
1298  ellInftyNorm=std::max((ValueType)fabs(diff),ellInftyNorm);
1299  }
1300  }
1301 
1302  return std::make_pair(sqrt(ell2Norm),ellInftyNorm);
1303 }
1304 
1305 
1306 template<class GM,class ACC>
1309 {
1310  ValueType derivativeValue=0.0;
1311  //std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::(&SubSolver::MoveBackGetDerivative()));
1312  for (size_t i=0;i<parent::_subSolvers.size();++i)
1313  derivativeValue+=parent::_subSolvers[i]->MoveBackGetDerivative();
1314 
1315  parent::_moveDirection=SubModel::ReverseDirection(parent::_moveDirection);
1316  return derivativeValue;
1317 }
1318 
1319 };//namespace trws_base
1320 }//namespace opengm
1321 #endif /* TRWS_BASE_HXX_ */
TRWSPrototype_Parameters< ValueType > parent
Definition: trws_base.hxx:339
const SubModel & subModel(IndexType modelId) const
Definition: trws_base.hxx:61
ValueType getBound(size_t i) const
Definition: trws_base.hxx:253
StructureType getStructureType() const
Definition: trws_base.hxx:65
virtual ValueType value() const
Definition: trws_base.hxx:234
std::vector< LabelType > _integerLabeling
Definition: trws_base.hxx:324
The OpenGM namespace.
Definition: config.hxx:43
void _postprocessMarginals(typename std::vector< ValueType >::iterator begin, typename std::vector< ValueType >::iterator end)
Definition: trws_base.hxx:1114
virtual InferenceTermination infer()
Definition: trws_base.hxx:254
OutputContainerType::iterator OutputIteratorType
Definition: trws_base.hxx:368
parent::LabelType LabelType
Definition: trws_base.hxx:363
parent::ValueType ValueType
Definition: trws_base.hxx:473
const SubVariableListType & getSubVariableList(IndexType varId) const
Definition: trws_base.hxx:64
std::vector< bool > _treeAgree
Definition: trws_base.hxx:524
SubSolver::const_iterators_pair const_marginals_iterators_pair
Definition: trws_base.hxx:212
virtual void _SumUpForwardMarginals(std::vector< ValueType > *pout, const_marginals_iterators_pair itpair)=0
IndexType size(IndexType subModelId) const
Definition: trws_base.hxx:62
bool CheckDualityGap(ValueType primalBound, ValueType dualBound)
Definition: trws_base.hxx:583
std::vector< bool > _nodeMask
Definition: trws_base.hxx:526
parent::IndexType IndexType
Definition: trws_base.hxx:362
DecompositionStorage(const GM &gm, StructureType structureType=GENERALSTRUCTURE, const DDVectorType *pddvector=0)
Definition: trws_base.hxx:928
std::pair< ValueType, ValueType > GetMarginals(IndexType variable, OutputIteratorType begin)
Definition: trws_base.hxx:1272
parent::IndexType IndexType
Definition: trws_base.hxx:474
PreviousFactorTable< GM > _ftable
Definition: trws_base.hxx:305
TRWSPrototype< SumProdSolver< GM, ACC, typename std::vector< typename GM::ValueType >::const_iterator > > parent
Definition: trws_base.hxx:356
virtual ValueType _GetPrimalBound()
Definition: trws_base.hxx:278
parent::const_marginals_iterators_pair const_marginals_iterators_pair
Definition: trws_base.hxx:360
FactorVarID(IndexType fID, IndexType vID, IndexType lID)
Definition: trws_base.hxx:120
std::vector< std::vector< ValueType > > _marginals
computation optimization
Definition: trws_base.hxx:319
std::vector< ValueType > _sumMarginal
Definition: trws_base.hxx:328
DecompositionStorage< GM > Storage
Definition: trws_base.hxx:223
virtual std::pair< ValueType, ValueType > GetMarginals(IndexType variable, OutputIteratorType begin)
Definition: trws_base.hxx:243
void _normalizeMarginals(typename std::vector< ValueType >::iterator begin, typename std::vector< ValueType >::iterator end, SubSolver *subSolver)
Definition: trws_base.hxx:1249
void SetSmoothing(ValueType smoothingValue)
Definition: trws_base.hxx:391
parent::LabelType LabelType
Definition: trws_base.hxx:475
bool CheckTreeAgreement(InferenceTermination *pterminationCode)
Definition: trws_base.hxx:1187
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
TRWSPrototype_Parameters(size_t maxIternum, ValueType precision=1.0, bool absolutePrecision=true, ValueType minRelativeDualImprovement=-1.0, bool fastComputations=true, bool canonicalNormalization=false)
Definition: trws_base.hxx:95
Storage::UnaryFactor UnaryFactor
Definition: trws_base.hxx:224
const FactorProperties & getFactorProperties() const
Definition: trws_base.hxx:265
parent::UnaryFactor UnaryFactor
Definition: trws_base.hxx:478
void _normalizeMarginals(typename std::vector< ValueType >::iterator begin, typename std::vector< ValueType >::iterator end, SubSolver *subSolver)
Definition: trws_base.hxx:1144
parent::ValueType ValueType
Definition: trws_base.hxx:361
virtual const std::vector< LabelType > & arg() const
Definition: trws_base.hxx:236
ValueType GetMarginalsAndDerivativeMove()
Definition: trws_base.hxx:1308
DecompositionStorage< GM > Storage
Definition: trws_base.hxx:366
void getTreeAgreement(std::vector< bool > &out, std::vector< LabelType > *plabeling=0, std::vector< std::vector< LabelType > > *ptreeLabelings=0)
Definition: trws_base.hxx:1152
InputIterator transform_inplace(InputIterator first, InputIterator last, UnaryOperator op)
Definition: utilities2.hxx:79
std::vector< SubSolver * > _subSolvers
Definition: trws_base.hxx:317
T _MulNormalize(Iterator begin, Iterator end, T initialValue)
LabelType numberOfLabels(IndexType varId) const
Definition: trws_base.hxx:57
ValueType lastDualUpdate() const
Definition: trws_base.hxx:261
MonotoneChainsDecomposition< GM >::SubVariableListType SubVariableListType
Definition: trws_base.hxx:25
SubModel::MoveDirection _moveDirection
Definition: trws_base.hxx:316
parent::OutputContainerType OutputContainerType
Definition: trws_base.hxx:481
IndexType _order(IndexType i)
Definition: trws_base.hxx:741
VariableToFactorMapping< GM > VariableToFactorMap
Definition: trws_base.hxx:49
SumProdTRWS_Parameters(size_t maxIternum, ValueType smValue, ValueType precision=1.0, bool absolutePrecision=true, ValueType minRelativeDualImprovement=2 *std::numeric_limits< ValueType >::epsilon(), bool fastComputations=true, bool canonicalNormalization=false)
Definition: trws_base.hxx:341
static StructureType getStructureType(const std::string &structName)
Definition: trws_base.hxx:30
void _SumUpForwardMarginals(std::vector< ValueType > *pout, const_marginals_iterators_pair itpair)
Definition: trws_base.hxx:1120
SumProdTRWS_Parameters< ValueType > Parameters
Definition: trws_base.hxx:370
Storage::MoveDirection MoveDirection
Definition: trws_base.hxx:116
virtual ValueType GetBestIntegerBound() const
Definition: trws_base.hxx:233
SequenceStorage< GM > SubModel
Definition: trws_base.hxx:365
MaxSumTRWS_Parameters< ValueType > Parameters
Definition: trws_base.hxx:487
void GetMarginalsForSubModel(IndexType modelId, IndexType localVarId, ITERATOR begin)
Definition: trws_base.hxx:414
virtual void _normalizeMarginals(typename std::vector< ValueType >::iterator begin, typename std::vector< ValueType >::iterator end, SubSolver *subSolver)=0
parent::SubSolverType SubSolver
Definition: trws_base.hxx:359
bool _CheckStoppingCondition(InferenceTermination *pterminationCode)
Definition: trws_base.hxx:1215
IndexType numberOfSharedVariables() const
Definition: trws_base.hxx:59
std::vector< bool > _mask
Definition: trws_base.hxx:525
std::vector< ValueType > OutputContainerType
Definition: trws_base.hxx:217
FactorList::const_iterator const_iterator
Definition: trws_base.hxx:132
SequenceStorage< GM > SubModel
Definition: trws_base.hxx:222
TRWSPrototype_Parameters< ValueType > parent
Definition: trws_base.hxx:441
void addDDvector(const DDVectorType &ddvector)
Definition: trws_base.hxx:1002
virtual void _postprocessMarginals(typename std::vector< ValueType >::iterator begin, typename std::vector< ValueType >::iterator end)=0
FactorProperties _factorProperties
Definition: trws_base.hxx:304
virtual bool _CheckStoppingCondition(InferenceTermination *pterminationCode)
Definition: trws_base.hxx:647
std::vector< typename GM::ValueType > DDVectorType
Definition: trws_base.hxx:51
std::vector< LabelType > _bestIntegerLabeling
Definition: trws_base.hxx:325
ValueType _oldDualBound
>current dual bound (it is improved monotonically)
Definition: trws_base.hxx:313
void GetMarginalsMove()
>returns "averaged" over subsolvers marginals
Definition: trws_base.hxx:728
parent::EmptyVisitorType EmptyVisitorType
Definition: trws_base.hxx:477
void _SumUpForwardMarginals(std::vector< ValueType > *pout, const_marginals_iterators_pair itpair)
Definition: trws_base.hxx:1265
std::vector< FactorVarID > FactorList
Definition: trws_base.hxx:131
bool _CheckConvergence(ValueType relativeThreshold)
Definition: trws_base.hxx:430
MaxSumTRWS_Parameters(size_t maxIternum, ValueType precision=1.0, bool absolutePrecision=true, ValueType minRelativeDualImprovement=-1.0, bool fastComputations=true, bool canonicalNormalization=false, size_t treeAgreeMaxStableIter=0)
Definition: trws_base.hxx:443
void _postprocessMarginals(typename std::vector< ValueType >::iterator begin, typename std::vector< ValueType >::iterator end)
Definition: trws_base.hxx:1259
InferenceTermination infer_visitor_updates(VISITOR &visitor, size_t *pinterCounter=0)
Definition: trws_base.hxx:781
OutputContainerType::iterator OutputIteratorType
Definition: trws_base.hxx:218
parent::OutputContainerType OutputContainerType
Definition: trws_base.hxx:367
virtual ValueType bound() const
Definition: trws_base.hxx:235
TRWSPrototype_Parameters< ValueType > Parameters
Definition: trws_base.hxx:220
TRWSPrototype< MaxSumSolver< GM, ACC, typename std::vector< typename GM::ValueType >::const_iterator > > parent
Definition: trws_base.hxx:469
ValueType GetSmoothing() const
Definition: trws_base.hxx:392
TRWSPrototype(Storage &storage, const Parameters &params)
Definition: trws_base.hxx:534
MonotoneChainsDecomposition< GM >::SubVariable SubVariable
Definition: trws_base.hxx:24
parent::InferenceTermination InferenceTermination
Definition: trws_base.hxx:364
const_iterator end(IndexType varId, MoveDirection md) const
Definition: trws_base.hxx:136
const_iterator begin(IndexType varId, MoveDirection md) const
Definition: trws_base.hxx:135
void _EstimateIntegerLabel(IndexType varId, const std::vector< ValueType > &sumMarginal)
Definition: trws_base.hxx:287
Function that refers to a factor of another GraphicalModel in which some variables are fixed...
SubModel & subModel(IndexType modelId)
Definition: trws_base.hxx:60
parent::InferenceTermination InferenceTermination
Definition: trws_base.hxx:476
InferenceTermination _core_infer(VISITOR &visitor, size_t *piterCounter=0)
Definition: trws_base.hxx:676
opengm::InferenceTermination InferenceTermination
Definition: trws_base.hxx:216
DecompositionStorage< GM > Storage
Definition: trws_base.hxx:485
InferenceTermination core_infer(size_t *piterCounter=0)
Definition: trws_base.hxx:264
FactorProperties::ParameterStorageType _factorParameters
Definition: trws_base.hxx:329
visitors::EmptyVisitor< TRWSPrototype< SubSolverType > > EmptyVisitorParent
Definition: trws_base.hxx:209
void getDDVector(DDVectorType *ddvector) const
Definition: trws_base.hxx:1033
ValueType getDerivative(size_t i) const
besides computation of marginals returns derivative w.r.t. _smoothingValue
Definition: trws_base.hxx:400
void _InitSubSolvers()
>best label index
Definition: trws_base.hxx:575
SequenceStorage< GM > SubModel
Definition: trws_base.hxx:484
parent::SubSolverType SubSolver
Definition: trws_base.hxx:471
static std::string getString(StructureType structure)
Definition: trws_base.hxx:37
visitors::VisitorWrapper< EmptyVisitorParent, TRWSPrototype< SubSolver > > EmptyVisitorType
Definition: trws_base.hxx:210
FunctionParameters< GM > FactorProperties
Definition: trws_base.hxx:207
MaxSumTRWS(Storage &storage, const Parameters &params)
Definition: trws_base.hxx:489
ValueType _lastDualUpdate
previous dual bound (it is improved monotonically)
Definition: trws_base.hxx:314
parent::const_marginals_iterators_pair const_marginals_iterators_pair
Definition: trws_base.hxx:472
virtual bool _CheckConvergence(ValueType relativeThreshold)
Definition: trws_base.hxx:635
IndexType _core_order(IndexType i, IndexType totalSize)
Definition: trws_base.hxx:735
T _MaxNormalize_inplace(Iterator begin, Iterator end, T init, Comp comp)
InferenceTermination
Definition: inference.hxx:24
SumProdTRWS(Storage &storage, const Parameters &params)
Definition: trws_base.hxx:372