OpenGM  2.3.x
Discrete Graphical Model Library
trws.hxx
Go to the documentation of this file.
1 
5 #pragma once
6 #ifndef OPENGM_EXTERNAL_TRWS_HXX
7 #define OPENGM_EXTERNAL_TRWS_HXX
8 
15 
16 #include "typeView.h"
17 #include "MRFEnergy.h"
18 #include "instances.h"
19 #include "MRFEnergy.cpp"
20 #include "minimize.cpp"
21 #include "treeProbabilities.cpp"
22 #include "ordering.cpp"
23 
24 namespace opengm {
25  namespace external {
// TRWS: OpenGM inference wrapper around the external MRFEnergy library
// (MRFEnergy.h). By default the optimisation loop calls Minimize_TRW_S; with
// Parameter::doBPS_ set it calls Minimize_BP instead (see inferImpl).
// Minimizer-only: derives from Inference<GM, opengm::Minimizer>.
// NOTE(review): the class owns raw heap pointers (mrf*_, nodes*_,
// minMarginals_, minMarginalsOffsets_) that the destructor deletes, but no
// copy constructor / copy assignment is visible here; copying a TRWS object
// would double-delete them.
38  template<class GM>
39  class TRWS : public Inference<GM, opengm::Minimizer> {
40  public:
41  typedef GM GraphicalModelType;
47  typedef size_t VariableIndex;
// Options controlling how the MRF is built and when inference stops.
49  struct Parameter {
// VIEW:    unary/pairwise factors are read through TypeView<GM>.
// TABLES:  pairwise factors are copied into dense TypeGeneral tables.
// TL1/TL2: pairwise factors are passed as truncated linear / truncated
//          quadratic terms; both require every variable to have the same
//          number of labels (checked in the constructor).
51  enum EnergyType {VIEW, TABLES, TL1, TL2/*, WEIGHTEDTABLE*/};
59  bool doBPS_;
63  double tolerance_;
// default values
70  numberOfIterations_ = 1000;
71  useRandomStart_ = false;
72  useZeroStart_ = false;
73  doBPS_ = false;
74  energyType_ = VIEW;
75  tolerance_ = 0.0;
76  minDualChange_ = 0.00001;
77  calculateMinMarginals_ = false;
78  };
79  };
80  // construction
81  TRWS(const GraphicalModelType& gm, const Parameter para = Parameter());
82  // destruction
83  ~TRWS();
84  // query
85  std::string name() const;
86  const GraphicalModelType& graphicalModel() const;
87  // inference
88  template<class VISITOR>
89  InferenceTermination infer(VISITOR & visitor);
91  InferenceTermination arg(std::vector<LabelType>&, const size_t& = 1) const;
92  InferenceTermination marginal(const size_t variableIndex, IndependentFactorType& out) const;
93  typename GM::ValueType bound() const;
94  typename GM::ValueType value() const;
95  private:
96  const GraphicalModelType& gm_;
97  Parameter parameter_;
// sum of all order-0 (constant) factors; added back in bound() and value()
98  ValueType constTerm_;
99 
// Exactly one MRF instance and its node-id array is allocated by the
// constructor, selected by parameter_.energyType_; the others stay NULL.
100  MRFEnergy<TypeView<GM> >* mrfView_;
101  typename MRFEnergy<TypeView<GM> >::NodeId* nodesView_;
102  MRFEnergy<TypeGeneral>* mrfGeneral_;
103  MRFEnergy<TypeGeneral>::NodeId* nodesGeneral_;
// Min-marginal buffer: one entry per (variable, label); entries of variable i
// start at minMarginalsOffsets_[i]. Allocated only if calculateMinMarginals_.
108  TypeGeneral::REAL* minMarginals_;
109  size_t* minMarginalsOffsets_;
110 
111 
// NOTE(review): runTime_ and state_ are not used in the visible code.
112  double runTime_;
113  ValueType lowerBound_;
114  ValueType value_;
115  std::vector<LabelType> state_;
116  const IndexType numNodes_;
// largest label count over all variables (computed by checkLabelNumber)
117  IndexType maxNumLabels_;
118  bool hasSameLabelNumber_;
119  void checkLabelNumber();
120 
121  void generateMRFView();
122  void generateMRFTables();
123  void generateMRFTL1();
124  void generateMRFTL2();
125  //void generateMRFWeightedTable();
126 
// ratio f(0, maxNumLabels_-1) / f(0, 1) of a pairwise factor (TL1/TL2 truncation)
127  ValueType getT(IndexType factor) const;
128 
129  // required for energy type tl1
130  bool truncatedAbsoluteDifferenceFactors() const;
131 
132  // required for energy type tl2
133  bool truncatedSquaredDifferenceFactors() const;
134 
// Allocates mrf and nodes, adding one node per variable whose data term is
// the sum of its unary factors; D is scratch space of maxNumLabels_ entries.
135  template <class ENERGYTYPE>
136  void addNodes(MRFEnergy<ENERGYTYPE>*& mrf, typename MRFEnergy<ENERGYTYPE>::NodeId*& nodes, typename ENERGYTYPE::REAL* D);
137 
// optimisation loop shared by all energy types
138  template<class VISITOR, class ENERGYTYPE>
139  InferenceTermination inferImpl(VISITOR & visitor, MRFEnergy<ENERGYTYPE>* mrf);
140  };
141 
// Factory returning a heap-allocated MRFEnergy<ENERGYTYPE> as void*.
// The generic template is not meant to be used; the specializations below
// cover the energy types TRWS supports.
142  template<class GM, class ENERGYTYPE>
144  static void* create(typename GM::IndexType numLabels);
145  };
146 
147  template<class GM>
148  struct createMRFEnergy<GM, TypeView<GM> >{
149  static void* create(typename GM::IndexType numLabels);
150  };
151 
152  template<class GM>
153  struct createMRFEnergy<GM, TypeGeneral>{
154  static void* create(typename GM::IndexType numLabels);
155  };
156 
157  template<class GM>
158  struct createMRFEnergy<GM, TypeTruncatedLinear>{
159  static void* create(typename GM::IndexType numLabels);
160  };
161 
162  template<class GM>
163  struct createMRFEnergy<GM, TypeTruncatedQuadratic>{
164  static void* create(typename GM::IndexType numLabels);
165  };
166 
167  template<class GM, class ENERGYTYPE>
168  struct addMRFNode{
169  static typename MRFEnergy<ENERGYTYPE>::NodeId add(MRFEnergy<ENERGYTYPE>* mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL* D);
170  };
171 
172  template<class GM>
173  struct addMRFNode<GM, TypeView<GM> >{
174  static typename MRFEnergy<TypeView<GM> >::NodeId add(MRFEnergy<TypeView<GM> >* mrf, typename GM::IndexType numLabels, typename TypeView<GM>::REAL* D);
175  };
176 
177  template<class GM>
178  struct addMRFNode<GM, TypeGeneral>{
179  static typename MRFEnergy<TypeGeneral>::NodeId add(MRFEnergy<TypeGeneral>* mrf, typename GM::IndexType numLabels, typename TypeGeneral::REAL* D);
180  };
181 
182  template<class GM>
183  struct addMRFNode<GM, TypeTruncatedLinear>{
184  static typename MRFEnergy<TypeTruncatedLinear>::NodeId add(MRFEnergy<TypeTruncatedLinear>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedLinear::REAL* D);
185  };
186 
187  template<class GM>
188  struct addMRFNode<GM, TypeTruncatedQuadratic>{
189  static typename MRFEnergy<TypeTruncatedQuadratic>::NodeId add(MRFEnergy<TypeTruncatedQuadratic>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedQuadratic::REAL* D);
190  };
191 
// Constructor: stores the model reference and parameters, computes label
// statistics, optionally allocates the min-marginal buffer and builds the
// MRF for parameter_.energyType_.
// Throws RuntimeError for TL1/TL2 on models with differing label counts and
// for an unknown energy type.
192  template<class GM>
194  const typename TRWS::GraphicalModelType& gm,
195  const Parameter para
196  )
197  : gm_(gm), parameter_(para), mrfView_(NULL), nodesView_(NULL), mrfGeneral_(NULL), nodesGeneral_(NULL),
198  mrfTL1_(NULL), nodesTL1_(NULL), mrfTL2_(NULL), nodesTL2_(NULL), minMarginals_(NULL), minMarginalsOffsets_(NULL),
// NOTE(review): numberOfLabels(0) assumes the model has at least one variable.
199  numNodes_(gm_.numberOfVariables()), maxNumLabels_(gm_.numberOfLabels(0)) {
200  // check label number
201  checkLabelNumber();
// One buffer slot per (variable, label); minMarginalsOffsets_[i] is the index
// of the first slot of variable i.
202  if(parameter_.calculateMinMarginals_){
203  size_t count = 0;
204  minMarginalsOffsets_ = new size_t[gm_.numberOfVariables()];
205  for(size_t i=0; i<gm_.numberOfVariables(); ++i){
206  minMarginalsOffsets_[i] = count;
207  count += gm_.numberOfLabels(i);
208  }
209  minMarginals_ = new TypeGeneral::REAL[count];
210  }
211 
// NOTE(review): if any branch below throws, the destructor is not run, so
// minMarginals_/minMarginalsOffsets_ and any partially built MRF are leaked.
212  // generate mrf model
213  switch(parameter_.energyType_) {
214  case Parameter::VIEW: {
215  generateMRFView();
216  break;
217  }
218  case Parameter::TABLES: {
219  generateMRFTables();
220  break;
221  }
222  case Parameter::TL1: {
223  if(!hasSameLabelNumber_) {
224  throw(RuntimeError("TRWS TL1 only supports graphical models where each variable has the same number of states."));
225  }
226  generateMRFTL1();
227  break;
228  }
229  case Parameter::TL2: {
230  if(!hasSameLabelNumber_) {
231  throw(RuntimeError("TRWS TL2 only supports graphical models where each variable has the same number of states."));
232  }
233  generateMRFTL2();
234  break;
235  }
236  /*case Parameter::WEIGHTEDTABLE: {
237  generateMRFWeightedTable();
238  break;
239  }*/
240  default: {
241  throw(RuntimeError("Unknown energy type."));
242  }
243  }
244 
245  // set initial value and lower bound
247  AccumulationType::ineutral(lowerBound_);
248  }
249 
// Destructor: frees whichever MRF instance and node-id array the constructor
// allocated, plus the optional min-marginal buffers. Deleting NULL is a no-op
// in C++, so the if-guards are redundant but harmless.
250  template<class GM>
252  if(mrfView_) {
253  delete mrfView_;
254  }
255  if(nodesView_) {
256  delete[] nodesView_;
257  }
258 
259  if(mrfGeneral_) {
260  delete mrfGeneral_;
261  }
262  if(nodesGeneral_) {
263  delete[] nodesGeneral_;
264  }
265 
266  if(mrfTL1_) {
267  delete mrfTL1_;
268  }
269  if(nodesTL1_) {
270  delete[] nodesTL1_;
271  }
272 
273  if(mrfTL2_) {
274  delete mrfTL2_;
275  }
276  if(nodesTL2_) {
277  delete[] nodesTL2_;
278  }
// minMarginalsOffsets_ is allocated together with minMarginals_ in the
// constructor, hence the shared guard.
279  if(minMarginals_) {
280  delete[] minMarginals_;
281  delete[] minMarginalsOffsets_;
282  }
283  }
284 
285  template<class GM>
286  inline std::string
287  TRWS<GM>
288  ::name() const {
289  return "TRWS";
290  }
291 
// Returns the graphical model this inference object was constructed with.
292  template<class GM>
293  inline const typename TRWS<GM>::GraphicalModelType&
294  TRWS<GM>
296  return gm_;
297  }
298 
// Runs inference without observation: forwards to infer(VISITOR&) with an
// EmptyVisitorType instance.
299  template<class GM>
300  inline InferenceTermination
302  (
303  ) {
304  EmptyVisitorType visitor;
305  return this->infer(visitor);
306  }
307 
// Runs inference with a visitor by dispatching to inferImpl with the MRF
// instance built for parameter_.energyType_. Throws RuntimeError for an
// unknown energy type. The break statements after each return are unreachable.
308  template<class GM>
309  template<class VISITOR>
310  inline InferenceTermination
312  (
313  VISITOR & visitor
314  ) {
315  switch(parameter_.energyType_) {
316  case Parameter::VIEW: {
317  return inferImpl(visitor, mrfView_);
318  break;
319  }
320  case Parameter::TABLES: {
321  return inferImpl(visitor, mrfGeneral_);
322  break;
323  }
324  case Parameter::TL1: {
325  return inferImpl(visitor, mrfTL1_);
326  break;
327  }
328  case Parameter::TL2: {
329  return inferImpl(visitor, mrfTL2_);
330  break;
331  }
332 /* case Parameter::WEIGHTEDTABLE: {
333  return inferImpl(visitor, mrf);
334  break;
335  }*/
336  default: {
337  throw(RuntimeError("Unknown energy type."));
338  }
339  }
340  }
341 
// Writes the labeling currently held by the MRF (GetSolution for every node)
// into arg, resized to numNodes_. Only the best labeling is available:
// n > 1 returns UNKNOWN without touching arg. Throws RuntimeError for an
// unknown energy type.
342  template<class GM>
343  inline InferenceTermination
344  TRWS<GM>
346  std::vector<LabelType>& arg,
347  const size_t& n
348  ) const {
349 
350  if(n > 1) {
351  return UNKNOWN;
352  }
353  else {
354  arg.resize(numNodes_);
355  switch(parameter_.energyType_) {
356  case Parameter::VIEW: {
357  for(IndexType i = 0; i < numNodes_; i++) {
358  arg[i] = mrfView_->GetSolution(nodesView_[i]);
359  }
360  return NORMAL;
361  break;
362  }
363  case Parameter::TABLES: {
364  for(IndexType i = 0; i < numNodes_; i++) {
365  arg[i] = mrfGeneral_->GetSolution(nodesGeneral_[i]);
366  }
367  return NORMAL;
368  break;
369  }
370  case Parameter::TL1: {
371  for(IndexType i = 0; i < numNodes_; i++) {
372  arg[i] = mrfTL1_->GetSolution(nodesTL1_[i]);
373  }
374  return NORMAL;
375  break;
376  }
377  case Parameter::TL2: {
378  for(IndexType i = 0; i < numNodes_; i++) {
379  arg[i] = mrfTL2_->GetSolution(nodesTL2_[i]);
380  }
381  return NORMAL;
382  break;
383  }
384 /* case Parameter::WEIGHTEDTABLE: {
385  for(IndexType i = 0; i < numNodes_; i++) {
386  arg[i] = mrfGeneral_->GetSolution(nodesGeneral_[i]);
387  }
388  return NORMAL;
389  break;
390  }*/
391  default: {
392  throw(RuntimeError("Unknown energy type."));
393  }
394  }
395  }
396  }
397 
// Copies the min-marginals of variableIndex from minMarginals_, which the
// solver calls in inferImpl write into, into out. Returns UNKNOWN when
// calculateMinMarginals_ was not enabled.
// NOTE(review): minMarginals_ is allocated with new[] without initialization,
// so calling this before infer() reads indeterminate values.
401  template<class GM>
402  inline InferenceTermination
404  const size_t variableIndex,
406  ) const
407  {
408 
409  if(parameter_.calculateMinMarginals_){
410  out.assign(gm_, &variableIndex, &variableIndex+1, 0);
411  for(size_t i=0; i<gm_.numberOfLabels(variableIndex); ++i){
412  out(i) = minMarginals_[i+minMarginalsOffsets_[variableIndex]];
413  }
414  return NORMAL;
415  }else{
416  return UNKNOWN;
417  }
418  }
419 
420  template<class GM>
421  inline typename GM::ValueType
422  TRWS<GM>::bound() const {
423  return lowerBound_+constTerm_;
424  }
425  template<class GM>
426  inline typename GM::ValueType
427  TRWS<GM>::value() const {
428  return value_+constTerm_;
429  }
430 
431  template<class GM>
432  inline void TRWS<GM>::checkLabelNumber() {
433  hasSameLabelNumber_ = true;
434  for(IndexType i = 1; i < gm_.numberOfVariables(); i++) {
435  if(gm_.numberOfLabels(i) != maxNumLabels_) {
436  hasSameLabelNumber_ = false;
437  }
438  if(gm_.numberOfLabels(i) > maxNumLabels_) {
439  maxNumLabels_ = gm_.numberOfLabels(i);
440  }
441  }
442  }
443 
444  template<class GM>
445  inline void TRWS<GM>::generateMRFView() {
446  mrfView_ = new MRFEnergy<TypeView<GM> >(typename TypeView<GM>::GlobalSize());
447  nodesView_ = new typename MRFEnergy<TypeView<GM> >::NodeId[numNodes_];
448 
449  // add nodes
450  for(IndexType i = 0; i < numNodes_; i++) {
451  std::vector<typename GM::IndexType> factors;
452  for(typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
453  if(gm_[*iter].numberOfVariables() == 1) {
454  factors.push_back(*iter);
455  }
456  }
457  nodesView_[i] = mrfView_->AddNode(typename TypeView<GM>::LocalSize(gm_.numberOfLabels(i)), typename TypeView<GM>::NodeData(gm_, factors));
458  }
459 
460  // add edges
461  constTerm_ = 0;
462  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
463  if(gm_[i].numberOfVariables() == 0){
464  LabelType l = 0;
465  constTerm_ += gm_[i](&l);
466  }
467  if(gm_[i].numberOfVariables() == 2) {
468  IndexType a = gm_[i].variableIndex(0);
469  IndexType b = gm_[i].variableIndex(1);
470  mrfView_->AddEdge(nodesView_[a], nodesView_[b], typename TypeView<GM>::EdgeData(gm_, i));
471  }
472  }
473  // set random start message
474  if(parameter_.useRandomStart_) {
475  mrfView_->AddRandomMessages(1, 0.0, 1.0);
476  } else if(parameter_.useZeroStart_) {
477  mrfView_->ZeroMessages();
478  }
479  }
480 
481  template<class GM>
482  inline void TRWS<GM>::generateMRFTables() {
483  // add nodes
484  typename TypeGeneral::REAL* D = new typename TypeGeneral::REAL[maxNumLabels_];
485  addNodes(mrfGeneral_, nodesGeneral_, D);
486  delete[] D;
487 
488  // add edges
489  IndexType index[2];
490  constTerm_ = 0;
491  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
492  if(gm_[i].numberOfVariables() == 0){
493  LabelType l = 0;
494  constTerm_ += gm_[i](&l);
495  }
496  if(gm_[i].numberOfVariables() == 2) {
497  IndexType a = gm_[i].variableIndex(0);
498  IndexType b = gm_[i].variableIndex(1);
499  IndexType numLabels_a = gm_.numberOfLabels(a);
500  IndexType numLabels_b = gm_.numberOfLabels(b);
501  typename TypeGeneral::REAL* V = new typename TypeGeneral::REAL[numLabels_a * numLabels_b];
502  for(size_t j = 0; j < numLabels_a; j++) {
503  for(size_t k = 0; k < numLabels_b; k++) {
504  index[0] = j;
505  index[1] = k;
506  V[j + k * numLabels_a] = gm_[i](index);
507  }
508  }
509  mrfGeneral_->AddEdge(nodesGeneral_[a], nodesGeneral_[b], TypeGeneral::EdgeData(TypeGeneral::GENERAL, V));
510  delete[] V;
511  }
512  }
513 
514  // set random start message
515  if(parameter_.useRandomStart_) {
516  mrfGeneral_->AddRandomMessages(1, 0.0, 1.0);
517  } else if(parameter_.useZeroStart_) {
518  mrfGeneral_->ZeroMessages();
519  }
520  }
521 
522  template<class GM>
523  inline void TRWS<GM>::generateMRFTL1() {
524  OPENGM_ASSERT(truncatedAbsoluteDifferenceFactors());
525 
526  // add nodes
527  typename TypeTruncatedLinear::REAL* D = new typename TypeTruncatedLinear::REAL[maxNumLabels_];
528  addNodes(mrfTL1_, nodesTL1_, D);
529  delete[] D;
530 
531  // add edges
532  constTerm_=0;
533  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
534  if(gm_[i].numberOfVariables() == 0){
535  LabelType l = 0;
536  constTerm_ += gm_[i](&l);
537  }
538  if(gm_[i].numberOfVariables() == 2) {
539  // truncation
540  ValueType t = getT(i);
541  //std::cout << "t: " << t << std::endl;
542 
543  // weight
544  IndexType index[] = {0, 1};
545  ValueType w = gm_[i](index);
546  //std::cout << "w: " << w << std::endl;
547 
548  // corresponding node IDs
549  IndexType a = gm_[i].variableIndex(0);
550  IndexType b = gm_[i].variableIndex(1);
551  mrfTL1_->AddEdge(nodesTL1_[a], nodesTL1_[b], TypeTruncatedLinear::EdgeData(w, w * t));
552  }
553  }
554 
555  // set random start message
556  if(parameter_.useRandomStart_) {
557  mrfTL1_->AddRandomMessages(1, 0.0, 1.0);
558  } else if(parameter_.useZeroStart_) {
559  mrfTL1_->ZeroMessages();
560  }
561  }
562 
563  template<class GM>
564  inline void TRWS<GM>::generateMRFTL2() {
565  OPENGM_ASSERT(truncatedSquaredDifferenceFactors());
566 
567  // add nodes
568  typename TypeTruncatedQuadratic::REAL* D = new typename TypeTruncatedQuadratic::REAL[maxNumLabels_];
569  addNodes(mrfTL2_, nodesTL2_, D);
570  delete[] D;
571 
572  // add edges
573  constTerm_=0;
574  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
575  if(gm_[i].numberOfVariables() == 0){
576  LabelType l = 0;
577  constTerm_ += gm_[i](&l);
578  }
579  if(gm_[i].numberOfVariables() == 2) {
580  // truncation
581  ValueType t = getT(i);
582  //std::cout << "t: " << t << std::endl;
583 
584  // weight
585  IndexType index[] = {0, 1};
586  ValueType w = gm_[i](index);
587  //std::cout << "w: " << w << std::endl;
588 
589  // corresponding node IDs
590  IndexType a = gm_[i].variableIndex(0);
591  IndexType b = gm_[i].variableIndex(1);
592  mrfTL2_->AddEdge(nodesTL2_[a], nodesTL2_[b], TypeTruncatedQuadratic::EdgeData(w, w * t));
593  }
594  }
595 
596  //mrfTL2_->SetAutomaticOrdering();
597 
598  // set random start message
599  if(parameter_.useRandomStart_) {
600  mrfTL2_->AddRandomMessages(1, 0.0, 1.0);
601  } else if(parameter_.useZeroStart_) {
602  mrfTL2_->ZeroMessages();
603  }
604  }
605 
606 /* template<class GM>
607  inline void TRWS<GM>::generateMRFWeightedTable() {
608 
609  }*/
610 
611  template<class GM>
612  inline typename GM::ValueType TRWS<GM>::getT(IndexType factor) const {
613  OPENGM_ASSERT(gm_.numberOfVariables(factor) == 2);
614 
615  IndexType index1[] = {0, 1};
616  IndexType index0[] = {0, maxNumLabels_-1};
617 
618  return gm_[factor](index0)/gm_[factor](index1);
619  }
620 
621  template<class GM>
622  inline bool TRWS<GM>::truncatedAbsoluteDifferenceFactors() const {
623  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
624  if(gm_.numberOfVariables(i) == 2) {
625  if(gm_[i].isTruncatedAbsoluteDifference() == false) {
626  return false;
627  }
628  }
629  }
630  return true;
631  }
632 
633  template<class GM>
634  inline bool TRWS<GM>::truncatedSquaredDifferenceFactors() const {
635  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
636  if(gm_.numberOfVariables(i) == 2) {
637  if(gm_[i].isTruncatedSquaredDifference() == false) {
638  return false;
639  }
640  }
641  }
642  return true;
643  }
644 
645  template<class GM>
646  template <class ENERGYTYPE>
647  inline void TRWS<GM>::addNodes(MRFEnergy<ENERGYTYPE>*& mrf, typename MRFEnergy<ENERGYTYPE>::NodeId*& nodes, typename ENERGYTYPE::REAL* D) {
648 
649  mrf = reinterpret_cast<MRFEnergy<ENERGYTYPE>*>(createMRFEnergy<GM, ENERGYTYPE>::create(maxNumLabels_));
650 
651  nodes = new typename MRFEnergy<ENERGYTYPE>::NodeId[numNodes_];
652  for(IndexType i = 0; i < numNodes_; i++) {
653  for(IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
654  D[j] = 0.0;
655  }
656  for(typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
657  if(gm_[*iter].numberOfVariables() == 1) {
658  for(IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
659  D[j] += gm_[*iter](&j);
660  }
661  }
662  }
663  nodes[i] = addMRFNode<GM, ENERGYTYPE>::add(mrf, gm_.numberOfLabels(i), D);
664  }
665  }
666 
667  template<class GM, class ENERGYTYPE>
668  inline void* createMRFEnergy<GM, ENERGYTYPE>::create(typename GM::IndexType numLabels) {
669  RuntimeError("Unsupported Energy Type!");
670  return NULL;
671  }
672 
673  template<class GM>
674  inline void* createMRFEnergy<GM, TypeView<GM> >::create(typename GM::IndexType numLabels) {
675  return reinterpret_cast<void*>(new MRFEnergy<TypeView<GM> >(typename TypeView<GM>::GlobalSize()));
676  }
677 
678  template<class GM>
679  inline void* createMRFEnergy<GM, TypeGeneral>::create(typename GM::IndexType numLabels) {
680  return reinterpret_cast<void*>(new MRFEnergy<TypeGeneral>(typename TypeGeneral::GlobalSize()));
681  }
682 
683  template<class GM>
684  inline void* createMRFEnergy<GM, TypeTruncatedLinear>::create(typename GM::IndexType numLabels) {
685  return reinterpret_cast<void*>(new MRFEnergy<TypeTruncatedLinear>(typename TypeTruncatedLinear::GlobalSize(numLabels)));
686  }
687 
688  template<class GM>
689  inline void* createMRFEnergy<GM, TypeTruncatedQuadratic>::create(typename GM::IndexType numLabels) {
690  return reinterpret_cast<void*>(new MRFEnergy<TypeTruncatedQuadratic>(typename TypeTruncatedQuadratic::GlobalSize(numLabels)));
691  }
692 
693  template<class GM, class ENERGYTYPE>
694  inline typename MRFEnergy<ENERGYTYPE>::NodeId addMRFNode<GM, ENERGYTYPE>::add(MRFEnergy<ENERGYTYPE>* mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL* D) {
695  RuntimeError("Unsupported Energy Type!");
696  return NULL;
697  }
698 
699  template<class GM>
700  inline typename MRFEnergy<TypeView<GM> >::NodeId addMRFNode<GM, TypeView<GM> >::add(MRFEnergy<TypeView<GM> >* mrf, typename GM::IndexType numLabels, typename TypeView<GM>::REAL* D) {
701  return mrf->AddNode(typename TypeView<GM>::LocalSize(numLabels), typename TypeView<GM>::NodeData(D));
702  }
703 
704  template<class GM>
705  inline typename MRFEnergy<TypeGeneral>::NodeId addMRFNode<GM, TypeGeneral>::add(MRFEnergy<TypeGeneral>* mrf, typename GM::IndexType numLabels, typename TypeGeneral::REAL* D) {
706  return mrf->AddNode(typename TypeGeneral::LocalSize(numLabels), typename TypeGeneral::NodeData(D));
707  }
708 
709  template<class GM>
710  inline typename MRFEnergy<TypeTruncatedLinear>::NodeId addMRFNode<GM, TypeTruncatedLinear>::add(MRFEnergy<TypeTruncatedLinear>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedLinear::REAL* D) {
711  return mrf->AddNode(typename TypeTruncatedLinear::LocalSize(), typename TypeTruncatedLinear::NodeData(D));
712  }
713 
714  template<class GM>
715  inline typename MRFEnergy<TypeTruncatedQuadratic>::NodeId addMRFNode<GM, TypeTruncatedQuadratic>::add(MRFEnergy<TypeTruncatedQuadratic>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedQuadratic::REAL* D) {
716  return mrf->AddNode(typename TypeTruncatedQuadratic::LocalSize(), typename TypeTruncatedQuadratic::NodeData(D));
717  }
718 
719  template<class GM>
720  template<class VISITOR, class ENERGYTYPE>
721  inline InferenceTermination TRWS<GM>::inferImpl(VISITOR & visitor, MRFEnergy<ENERGYTYPE>* mrf) {
722  typename MRFEnergy<ENERGYTYPE>::Options options;
723  options.m_iterMax = 1; // maximum number of iterations
724  options.m_printIter = 2 * parameter_.numberOfIterations_;
725  visitor.begin(*this);
726 
727 
728  if(parameter_.doBPS_) {
729  typename ENERGYTYPE::REAL v;
730  for(size_t i = 0; i < parameter_.numberOfIterations_; ++i) {
731  mrf->Minimize_BP(options, v, minMarginals_);
732  value_ = v;
733  if( visitor(*this) != visitors::VisitorReturnFlag::ContinueInf ) {
734  break;
735  }
736  }
737  } else {
738  typename ENERGYTYPE::REAL v;
739  typename ENERGYTYPE::REAL b;
740  typename ENERGYTYPE::REAL d;
741  for(size_t i = 0; i < parameter_.numberOfIterations_; ++i) {
742  mrf->Minimize_TRW_S(options, b, v, minMarginals_);
743  d = b-lowerBound_;
744  lowerBound_ = b;
745  value_ = v;
746  if( visitor(*this) != visitors::VisitorReturnFlag::ContinueInf ) {
747  break;
748  }
749  if(fabs(value_ - lowerBound_) / opengmMax(static_cast<double>(fabs(value_)), 1.0) < parameter_.tolerance_) {
750  break;
751  }
752  if(d<parameter_.minDualChange_){
753  break;
754  }
755  }
756  }
757  //Copy MinMarginals
758 
759  visitor.end(*this);
760  return NORMAL;
761  }
762 
763  } // namespace external
764 } // namespace opengm
765 
766 #endif // #ifndef OPENGM_EXTERNAL_TRWS_HXX
The OpenGM namespace.
Definition: config.hxx:43
const GraphicalModelType & graphicalModel() const
Definition: trws.hxx:295
std::string name() const
Definition: trws.hxx:288
bool calculateMinMarginals_
Calculate MinMarginals.
Definition: trws.hxx:67
visitors::EmptyVisitor< TRWS< GM > > EmptyVisitorType
Definition: trws.hxx:45
bool useZeroStart_
zero starting message
Definition: trws.hxx:57
InferenceTermination marginal(const size_t variableIndex, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
Definition: trws.hxx:403
size_t VariableIndex
Definition: trws.hxx:47
T opengmMax(const T &x, const T &y)
Definition: opengm.hxx:116
InferenceTermination infer()
Definition: trws.hxx:302
bool useRandomStart_
random starting message
Definition: trws.hxx:55
void create(const hid_t &, const std::string &, ShapeIterator, ShapeIterator, CoordinateOrder)
Create and close an HDF5 dataset to store Marray data.
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
Definition: trws.hxx:345
GM::ValueType bound() const
return a bound on the solution
Definition: trws.hxx:422
EnergyType
possible energy types for TRWS
Definition: trws.hxx:51
static MRFEnergy< ENERGYTYPE >::NodeId add(MRFEnergy< ENERGYTYPE > *mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL *D)
Definition: trws.hxx:694
double minDualChange_
TRWS terminates if bound(t+1) - bound(t) < minDualChange_, i.e. when the (signed) change of the lower bound between two iterations falls below minDualChange_.
Definition: trws.hxx:65
GraphicalModelType::IndexType IndexType
Definition: inference.hxx:40
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
static T ineutral()
inverse neutral element (with return)
Definition: minimizer.hxx:25
Inference algorithm interface.
Definition: inference.hxx:34
double tolerance_
TRWS terminates if fabs(value - bound) / max(fabs(value), 1) < tolerance_.
Definition: trws.hxx:63
size_t numberOfIterations_
number of iterations
Definition: trws.hxx:53
visitors::VerboseVisitor< TRWS< GM > > VerboseVisitorType
Definition: trws.hxx:44
static T neutral()
neutral element (with return)
Definition: minimizer.hxx:16
TRWS(const GraphicalModelType &gm, const Parameter para=Parameter())
Definition: trws.hxx:193
double REAL
Definition: typeView.h:22
GM::ValueType value() const
return the solution (value)
Definition: trws.hxx:427
Minimization as a unary accumulation.
Definition: minimizer.hxx:12
static void * create(typename GM::IndexType numLabels)
Definition: trws.hxx:668
message passing (BPS, TRWS): [?]
Definition: trws.hxx:39
bool doBPS_
use normal LBP
Definition: trws.hxx:59
EnergyType energyType_
selected energy type
Definition: trws.hxx:61
opengm::Minimizer AccumulationType
Definition: trws.hxx:42
OpenGM runtime error.
Definition: opengm.hxx:100
visitors::TimingVisitor< TRWS< GM > > TimingVisitorType
Definition: trws.hxx:46
InferenceTermination
Definition: inference.hxx:24
GraphicalModelType::IndependentFactorType IndependentFactorType
Definition: inference.hxx:44