OpenGM  2.3.x
Discrete Graphical Model Library
ad3.hxx
Go to the documentation of this file.
1 #pragma once
2 #ifndef OPENGM_EXTERNAL_AD3_HXX
3 #define OPENGM_EXTERNAL_AD3_HXX
4 
8 
9 
10 
11 #include "ad3/FactorGraph.h"
12 //#include "FactorSequence.h"
13 
14 
15 namespace opengm {
16  namespace external {
17 
21 
// AD3Inf: inference class that derives from Inference<GM, ACC> and drives an
// AD3::FactorGraph (member factor_graph_).
// NOTE: this listing drops several original lines (see the gaps in the listing
// numbers): some typedefs, the SolverType enumerators, the head of the Parameter
// constructor and the signature lines of valueToMaxSum / valueFromMaxSum.
22  template<class GM,class ACC>
23  class AD3Inf : public Inference<GM, ACC> {
24 
25  public:
26  typedef GM GraphicalModelType;
27  typedef ACC AccumulationType;
32 
// Solver selection. The enumerators are not visible in this listing; the rest
// of the file refers to AD3_LP, AD3_ILP and PSDD_LP.
33  enum SolverType{
37  };
38 
// Solver settings. Defaults (from the constructor below): solverType = AD3_ILP,
// eta = 0.1, adaptEta = true, steps = 1000, residualThreshold = 1e-6,
// verbosity = 0. infer() forwards eta_, adaptEta_, steps_ and
// residualThreshold_ to the factor graph; the constructors forward verbosity_.
39  struct Parameter {
41  const SolverType solverType = AD3_ILP,
42  const double eta = 0.1,
43  const bool adaptEta = true,
44  UInt64Type steps = 1000,
45  const double residualThreshold = 1e-6,
46  const int verbosity = 0
47  ) :
48  solverType_(solverType),
49  eta_(eta),
50  adaptEta_(adaptEta),
51  steps_(steps),
52  residualThreshold_(residualThreshold),
53  verbosity_(verbosity)
54  {
55  }
56 
58 
// Only two of the member declarations are visible here; solverType_, steps_,
// residualThreshold_ and verbosity_ are initialized above but declared on
// lines missing from this listing.
59  double eta_;
60  bool adaptEta_;
64  };
65 
66  // construction
67  AD3Inf(const GraphicalModelType& gm, const Parameter para = Parameter());
68  ~AD3Inf();
69 
70  // query
71  std::string name() const;
72  const GraphicalModelType& graphicalModel() const;
73  // inference
75  template<class VisitorType>
76  InferenceTermination infer(VisitorType&);
// Copies the stored labeling into the first argument; any n > 1 yields UNKNOWN.
77  InferenceTermination arg(std::vector<LabelType>&, const size_t& = 1) const;
78 
// Evaluates gm_ at the stored labeling arg_.
// NOTE(review): in the two "without a gm" constructors gm_ is initialized from a
// temporary GM(); calling value() on such an object looks unsafe - confirm.
79  ValueType value()const{
80  return gm_.evaluate(arg_);
81  }
82 
// Returns bound_. NOTE(review): both branches return the same member, so the
// inferenceDone_ / AD3_ILP test currently has no effect.
83  ValueType bound()const{
84  if(inferenceDone_ && parameter_.solverType_==AD3_ILP ){
85  return bound_;
86  }
87  else{
88  return bound_;
89  }
90  }
91 
92 
// Body of a conversion helper whose signature line (93) is missing from this
// listing; presumably valueToMaxSum(val), which the constructors and addFactor
// call. Adder+Minimizer negates val, Adder+Maximizer returns it unchanged.
// NOTE(review): no return statement is reached for any other combination.
94  if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Minimizer>::value){
95  return val*(-1.0);
96  }
97  else if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Maximizer>::value){
98  return val;
99  }
100  }
101 
// Body of a second conversion helper whose signature line (102) is missing;
// presumably valueFromMaxSum(val), which infer() applies to bound_. Same
// mapping and same missing fall-through return as the helper above.
103  if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Minimizer>::value){
104  return val*(-1.0);
105  }
106  else if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Maximizer>::value){
107  return val;
108  }
109  }
110 
111 
112  // interface to create an AD3 factor graph without an OpenGM model
113 
// Builds the variables from a range holding one label count per variable.
114  template<class N_LABELS_ITER>
115  AD3Inf(N_LABELS_ITER nLabelsBegin,N_LABELS_ITER nLabelsEnd, const Parameter para = Parameter());
116 
117 
// Builds nVar variables that all have nLabels labels. The definition further
// down never reads 'foo'; presumably it only disambiguates this overload from
// the iterator-range constructor - confirm.
118  AD3Inf(const UInt64Type nVar,const UInt64Type nLabels, const Parameter para,const bool foo);
119 
120 
// Adds 'function' over the variables [viBegin, viEnd) to the factor graph.
121  template<class VI_ITERATOR,class FUNCTION>
122  void addFactor(VI_ITERATOR viBegin,VI_ITERATOR viEnd,const FUNCTION & function);
123 
124 
// Read access to the vector infer() hands to the solver as 'posteriors'.
125  const std::vector<double> & posteriors()const{
126  return posteriors_;
127  }
128 
// Read access to the vector infer() hands to the solver as 'additional posteriors'.
129  const std::vector<double> & higherOrderPosteriors()const{
130  return additional_posteriors_;
131  }
132 
133 
134  private:
135  const GraphicalModelType& gm_;
136  Parameter parameter_;
137  IndexType numVar_;
138 
139  // AD3Inf MEMBERS
140  AD3::FactorGraph factor_graph_;
// One AD3 multi-variable per model variable, created in the constructors.
141  std::vector<AD3::MultiVariable*> multi_variables_;
142 
143  std::vector<double> posteriors_;
144  std::vector<double> additional_posteriors_;
// Initialized to ACC::ineutral in the constructors, overwritten by infer().
145  double bound_;
146 
// Labeling decoded from posteriors_ at the end of infer(); all zeros before.
147  std::vector<LabelType> arg_;
148 
149  bool inferenceDone_;
150 
151  std::vector<LabelType> space_; // only used if setup without gm
152 
153  };
154  // public interface
158 
// Constructor from an OpenGM model (listing line 160, which carries the
// qualified class name, is missing here).
// Steps: reject unsupported operator/accumulator types, create one AD3
// multi-variable per model variable with all log-potentials at 0, then walk
// the factors: first-order factors are added onto the variable log-potentials,
// higher-order factors become dense AD3 factors. All values go through
// valueToMaxSum() first.
// NOTE(review): the two RuntimeError texts say "does not only support", which
// reads as the opposite of the check being made - confirm the intended wording.
159  template<class GM,class ACC>
161  ::AD3Inf(
162  const typename AD3Inf::GraphicalModelType& gm,
163  const Parameter para
164  ) :
165  gm_(gm),
166  parameter_(para),
167  numVar_(gm.numberOfVariables()),
168  factor_graph_(),
169  multi_variables_(gm.numberOfVariables()),
170  posteriors_(),
171  additional_posteriors_(),
172  bound_(),
173  arg_(gm.numberOfVariables(),static_cast<LabelType>(0)),
174  inferenceDone_(false),
175  space_(0)
176  {
177 
// Only Adder as operator is accepted.
178  if(meta::Compare<OperatorType,Adder>::value==false){
179  throw RuntimeError("AD3 does not only support opengm::Adder as Operator");
180  }
181 
// Only Minimizer or Maximizer as accumulator is accepted.
182  if(meta::Compare<AccumulationType,Minimizer>::value==false and meta::Compare<AccumulationType,Maximizer>::value==false ){
183  throw RuntimeError("AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
184  }
185 
186 
187  bound_ = ACC::template ineutral<ValueType>();
188 
189 
190 
191  factor_graph_.SetVerbosity(parameter_.verbosity_);
// Size the scratch buffer for the largest factor so it can be reused below.
192  UInt64Type maxFactorSize = 0 ;
193  for(IndexType fi=0;fi<gm_.numberOfFactors();++fi){
194  maxFactorSize=std::max(static_cast<UInt64Type>(gm_[fi].size()),maxFactorSize);
195  }
196 
// NOTE(review): raw new[]; the matching delete[] at the end is skipped if
// anything in between throws (e.g. if OPENGM_CHECK throws) - confirm and
// consider a std::vector buffer.
197  ValueType * facVal = new ValueType[maxFactorSize];
198 
199 
200  // fill space :
201  // - Create a multi-valued variable for variable of gm
202  // and initialize unaries with 0
203  for(IndexType vi=0;vi<gm_.numberOfVariables();++vi){
204  multi_variables_[vi] = factor_graph_.CreateMultiVariable(gm_.numberOfLabels(vi));
205  for(LabelType l=0;l<gm_.numberOfLabels(vi);++l){
206  multi_variables_[vi]->SetLogPotential(l,0.0);
207  }
208  }
209 
210 
211  // - add higher order factors
212  // - setup values for 1. order and higher order factors
213  for(IndexType fi=0;fi<gm_.numberOfFactors();++fi){
214  //gm_[fi].copyValues(facVal);
215  gm_[fi].copyValuesSwitchedOrder(facVal);
216  const IndexType nVar=gm_[fi].numberOfVariables();
217 
218  if(nVar==1){
219  const IndexType vi0 = gm_[fi].variableIndex(0);
220  const IndexType nl0 = gm_.numberOfLabels(vi0);
221 
// Accumulate: several first-order factors on one variable add up.
222  for(LabelType l=0;l<nl0;++l){
223  const ValueType logP = multi_variables_[vi0]->GetLogPotential(l);
224  const ValueType val = this->valueToMaxSum(facVal[l]);
225  multi_variables_[vi0]->SetLogPotential(l,logP+val);
226  }
227  }
228  else if (nVar>1){
229  // std::cout<<"factor size "<<gm_[fi].size()<<"\n";
230  // create higher order factor function
231  std::vector<double> additional_log_potentials(gm_[fi].size());
232  for(IndexType i=0;i<gm_[fi].size();++i){
233  additional_log_potentials[i]=this->valueToMaxSum(facVal[i]);
234  }
235 
236  // create high order factor vi
237  std::vector<AD3::MultiVariable*> multi_variables_local(nVar);
238  for(IndexType v=0;v<nVar;++v){
239  multi_variables_local[v]=multi_variables_[gm_[fi].variableIndex(v)];
240  }
241 
242  // create higher order factor
243  factor_graph_.CreateFactorDense(multi_variables_local,additional_log_potentials);
244  }
// Factors over zero variables are rejected.
245  else{
246  OPENGM_CHECK(false,"const factors are not yet implemented");
247  }
248 
249  }
250 
251  // delete buffer
252  delete[] facVal;
253  }
254 
// Constructor without an OpenGM model (listing line 257, which carries the
// qualified constructor name, is missing here). The range
// [nLabelsBegin, nLabelsEnd) supplies one label count per variable and is
// copied into space_. One AD3 multi-variable is created per entry with all
// log-potentials at 0; factors are added afterwards through addFactor().
// NOTE(review): inferenceDone_ is not in this initializer list, unlike in the
// model-based constructor, so it stays uninitialized until infer() sets it.
255  template<class GM,class ACC>
256  template<class N_LABELS_ITER>
258  N_LABELS_ITER nLabelsBegin,
259  N_LABELS_ITER nLabelsEnd,
260  const Parameter para
261  ) :
// NOTE(review): gm_ is a const reference member initialized from the temporary
// GM(); it does not refer to a live model afterwards, so gm_ must not be used
// on objects built this way - confirm.
262  gm_(GM()), // DIRTY
263  parameter_(para),
264  numVar_(std::distance(nLabelsBegin,nLabelsEnd)),
265  factor_graph_(),
266  multi_variables_(std::distance(nLabelsBegin,nLabelsEnd)),
267  posteriors_(),
268  additional_posteriors_(),
269  bound_(),
270  arg_(std::distance(nLabelsBegin,nLabelsEnd),static_cast<LabelType>(0)),
271  space_(nLabelsBegin,nLabelsEnd)
272  {
273 
// Same operator/accumulator restrictions as the model-based constructor.
274  if(meta::Compare<OperatorType,Adder>::value==false){
275  throw RuntimeError("AD3 does not only support opengm::Adder as Operator");
276  }
277  if(meta::Compare<AccumulationType,Minimizer>::value==false and meta::Compare<AccumulationType,Maximizer>::value==false ){
278  throw RuntimeError("AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
279  }
280  bound_ = ACC::template ineutral<ValueType>();
281  factor_graph_.SetVerbosity(parameter_.verbosity_);
282 
283  // and initialize unaries with 0
284  for(IndexType vi=0;vi<numVar_;++vi){
285  multi_variables_[vi] = factor_graph_.CreateMultiVariable(space_[vi]);
286  for(LabelType l=0;l<space_[vi];++l){
287  multi_variables_[vi]->SetLogPotential(l,0.0);
288  }
289  }
290  }
291 
// Constructor without an OpenGM model for a homogeneous label space (listing
// line 293, which carries the qualified constructor name, is missing here).
// Creates nVar AD3 multi-variables with nLabels labels each and all
// log-potentials at 0; factors are added afterwards through addFactor().
// 'foo' is never read in this body.
// NOTE(review): inferenceDone_ is not in this initializer list, unlike in the
// model-based constructor, so it stays uninitialized until infer() sets it.
292  template<class GM,class ACC>
294  const UInt64Type nVar,
295  const UInt64Type nLabels,
296  const Parameter para,
297  const bool foo
298  ) :
// NOTE(review): gm_ is a const reference member initialized from the temporary
// GM(); it does not refer to a live model afterwards, so gm_ must not be used
// on objects built this way - confirm.
299  gm_(GM()), // DIRTY
300  parameter_(para),
301  numVar_(nVar),
302  factor_graph_(),
303  multi_variables_(nVar),
304  posteriors_(),
305  additional_posteriors_(),
306  bound_(),
307  arg_(nVar,static_cast<LabelType>(0)),
308  space_(nVar,nLabels)
309  {
310 
// Same operator/accumulator restrictions as the model-based constructor.
311  if(meta::Compare<OperatorType,Adder>::value==false){
312  throw RuntimeError("AD3 does not only support opengm::Adder as Operator");
313  }
314  if(meta::Compare<AccumulationType,Minimizer>::value==false and meta::Compare<AccumulationType,Maximizer>::value==false ){
315  throw RuntimeError("AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
316  }
317  bound_ = ACC::template ineutral<ValueType>();
318  factor_graph_.SetVerbosity(parameter_.verbosity_);
// One multi-variable per entry of space_, unaries initialized with 0.
319  for(IndexType vi=0;vi<numVar_;++vi){
320  multi_variables_[vi] = factor_graph_.CreateMultiVariable(space_[vi]);
321  for(LabelType l=0;l<space_[vi];++l){
322  multi_variables_[vi]->SetLogPotential(l,0.0);
323  }
324  }
325  }
326 
327 
// addFactor: adds 'function' over the variables [visBegin, visEnd) to the AD3
// factor graph (listing line 331, which carries the qualified function name,
// is missing here). Label counts are read from space_, i.e. the label space
// filled by the "without a gm" constructors.
//  - The function's dimension must equal the number of variable indices and
//    each shape entry must equal the stored label count (OPENGM_CHECK_OP).
//  - One variable: the converted values are added onto that variable's
//    log-potentials.
//  - Two to ten variables: a dense table is filled by nested loops with the
//    first variable outermost and the last variable innermost, then passed to
//    CreateFactorDense.
//  - More than ten variables: RuntimeError.
//  - Zero variables: neither branch runs, so nothing is added and no error is
//    raised here.
// Every value passes through valueToMaxSum() before it reaches AD3.
328  template<class GM,class ACC>
329  template<class VI_ITERATOR,class FUNCTION>
330  void
332  VI_ITERATOR visBegin,
333  VI_ITERATOR visEnd,
334  const FUNCTION & function
335  ){
336  const IndexType nVis = std::distance(visBegin,visEnd);
337  OPENGM_CHECK_OP(nVis,==,function.dimension(),"functions dimension does not match number of variabole indices");
338 
// visBegin is indexed with operator[] from here on, so VI_ITERATOR has to
// support random access.
339  for(IndexType v=0;v<nVis;++v){
340  OPENGM_CHECK_OP(space_[visBegin[v]],==,function.shape(v),"functions shape does not match space");
341  }
342 
343 
344  if(nVis==1){
// First order: accumulate onto the existing log-potential of the variable.
345  LabelType l[1];
346  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0]){
347  const ValueType logP = multi_variables_[visBegin[0]]->GetLogPotential(l[0]);
348  const ValueType val = this->valueToMaxSum(function(l));
349  multi_variables_[visBegin[0]]->SetLogPotential(l[0],logP+val);
350  }
351  }
352  else if(nVis>=2){
353 
354 
355  // create high order factor vi
356  std::vector<AD3::MultiVariable*> multi_variables_local(nVis);
357  for(IndexType v=0;v<nVis;++v){
358  multi_variables_local[v]=multi_variables_[visBegin[v]];
359  }
360 
361  // create higher order function (for dense factor)
362  std::vector<double> additional_log_potentials(function.size());
363 
364  // FILL THE FUNCTION
// One hand-unrolled loop nest per order; 'c' is the running position in the
// dense table, and 'l' is the label array passed to function(...).
365 
366  if(nVis==2){
367  LabelType l[2];
368  UInt64Type c=0;
369  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
370  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1]){
371  additional_log_potentials[c]=this->valueToMaxSum(function(l));
372  ++c;
373  }
374  }
375  else if(nVis==3){
376  LabelType l[3];
377  UInt64Type c=0;
378  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
379  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
380  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2]){
381  additional_log_potentials[c]=this->valueToMaxSum(function(l));
382  ++c;
383  }
384  }
385  else if(nVis==4){
386  LabelType l[4];
387  UInt64Type c=0;
388  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
389  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
390  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
391  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3]){
392  additional_log_potentials[c]=this->valueToMaxSum(function(l));
393  ++c;
394  }
395  }
396  else if(nVis==5){
397  LabelType l[5];
398  UInt64Type c=0;
399  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
400  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
401  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
402  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
403  for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4]){
404  additional_log_potentials[c]=this->valueToMaxSum(function(l));
405  ++c;
406  }
407  }
408  else if(nVis==6){
409  LabelType l[6];
410  UInt64Type c=0;
411  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
412  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
413  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
414  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
415  for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
416  for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5]){
417  additional_log_potentials[c]=this->valueToMaxSum(function(l));
418  ++c;
419  }
420  }
421  else if(nVis==7){
422  LabelType l[7];
423  UInt64Type c=0;
424  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
425  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
426  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
427  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
428  for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
429  for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
430  for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6]){
431  additional_log_potentials[c]=this->valueToMaxSum(function(l));
432  ++c;
433  }
434  }
435  else if(nVis==8){
436  LabelType l[8];
437  UInt64Type c=0;
438  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
439  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
440  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
441  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
442  for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
443  for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
444  for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
445  for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
446  {
447  additional_log_potentials[c]=this->valueToMaxSum(function(l));
448  ++c;
449  }
450  }
451  else if(nVis==9){
452  LabelType l[9];
453  UInt64Type c=0;
454  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
455  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
456  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
457  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
458  for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
459  for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
460  for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
461  for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
462  for(l[8]=0; l[8]<space_[visBegin[8]]; ++l[8])
463  {
464  additional_log_potentials[c]=this->valueToMaxSum(function(l));
465  ++c;
466  }
467  }
468  else if(nVis==10){
469  LabelType l[10];
470  UInt64Type c=0;
471  for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
472  for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
473  for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
474  for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
475  for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
476  for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
477  for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
478  for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
479  for(l[8]=0; l[8]<space_[visBegin[8]]; ++l[8])
480  for(l[9]=0; l[9]<space_[visBegin[9]]; ++l[9])
481  {
482  additional_log_potentials[c]=this->valueToMaxSum(function(l));
483  ++c;
484  }
485  }
// Orders above 10 have no unrolled loop nest and are rejected.
486  else{
487  throw RuntimeError("order must be <=10 for inplace building of Ad3Inf (call us if you need higher order)");
488  }
489 
490 
491 
492 
493 
494  // create higher order factor
495  factor_graph_.CreateFactorDense(multi_variables_local,additional_log_potentials);
496 
497  }
498 
499  }
500 
501 
// Out-of-class definition whose signature lines (503-504) are missing from
// this listing; presumably the destructor ~AD3Inf() declared in the class -
// confirm. The visible body is empty.
502  template<class GM,class ACC>
505 
506  }
507 
// name(): returns the fixed identifier string "AD3Inf" (listing line 510,
// which carries the qualified class name, is missing here).
508  template<class GM,class ACC>
509  inline std::string
511  ::name() const {
512  return "AD3Inf";
513  }
514 
// graphicalModel(): returns the stored model reference gm_ (listing lines
// 517-518 with the qualified function name are missing here).
// NOTE(review): for objects built with the "without a gm" constructors gm_ was
// initialized from a temporary, so the returned reference looks unusable in
// that case - confirm.
515  template<class GM,class ACC>
516  inline const typename AD3Inf<GM,ACC>::GraphicalModelType&
519  return gm_;
520  }
521 
// infer() without arguments (listing lines 524-525 with the qualified function
// name are missing here): runs the visitor-taking overload with a
// default-constructed EmptyVisitorType and returns its result.
522  template<class GM,class ACC>
523  inline InferenceTermination
526  EmptyVisitorType v;
527  return infer(v);
528  }
529 
// infer(visitor): configures the AD3 factor graph from parameter_, runs the
// selected solver, converts the returned bound with valueFromMaxSum() and
// decodes one label per variable from posteriors_ into arg_ (listing line 532,
// presumably the return-type line, is missing here).
// The visitor is only notified through begin(*this) and end(*this).
// Always returns NORMAL.
530  template<class GM,class ACC>
531  template<class VisitorType>
533  AD3Inf<GM,ACC>::infer(VisitorType& visitor)
534  {
535  visitor.begin(*this);
536 
537  // set parameters
538  if(parameter_.solverType_ == AD3_LP || parameter_.solverType_ == AD3_ILP){
539  factor_graph_.SetEtaAD3(parameter_.eta_);
540  factor_graph_.AdaptEtaAD3(parameter_.adaptEta_);
541  factor_graph_.SetMaxIterationsAD3(parameter_.steps_);
542  factor_graph_.SetResidualThresholdAD3(parameter_.residualThreshold_);
543  }
// For PSDD only eta and the iteration limit are forwarded.
544  if(parameter_.solverType_ == PSDD_LP){
545  factor_graph_.SetEtaPSDD(parameter_.eta_);
546  factor_graph_.SetMaxIterationsPSDD(parameter_.steps_);
547  }
548 
549 
550  // solve
// 'value' is filled by the solver call but not read again in this function.
551  double value;
552  if ( parameter_.solverType_ == AD3_LP){
553  //std::cout<<"ad3 lp\n";
554  factor_graph_.SolveLPMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
555  }
556  if ( parameter_.solverType_ == AD3_ILP){
557  //std::cout<<"ad3 ilp\n";
558  factor_graph_.SolveExactMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
559  }
// NOTE(review): this branch issues the same SolveExactMAPWithAD3 call as the
// AD3_ILP branch, although the PSDD settings were configured above; presumably
// a PSDD solve was intended - confirm against the AD3 FactorGraph API.
560  if (parameter_.solverType_ == PSDD_LP){
561  //std::cout<<"ad3 psdd lp\n";
562  factor_graph_.SolveExactMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
563  }
564 
565  // transform bound
566  bound_ =this->valueFromMaxSum(bound_);
567 
568  // make gm arg
// posteriors_ is consumed sequentially: 'c' advances once per (variable, label)
// pair, variables in index order and labels in ascending order.
569  UInt64Type c=0;
570  for(IndexType vi = 0; vi < numVar_; ++vi) {
571  LabelType bestLabel = 0 ;
572  double bestVal = -100000;
// Label count comes from the model when space_ is empty, else from space_.
573  const LabelType nLabels = (space_.size()==0 ? gm_.numberOfLabels(vi) : space_[vi] );
574  for(LabelType l=0;l< nLabels;++l){
575  const double val = posteriors_[c];
576  //std::cout<<"vi= "<<vi<<" l= "<<l<<" val= "<<val<<"\n";
577 
// 'bestVal<0' makes the first label always win; a later label replaces it if
// its value is larger or while bestVal is still negative.
// NOTE(review): a negative posterior would keep being overwritten by whatever
// label comes next, even a smaller one - confirm posteriors are non-negative.
578  if(bestVal<0 || val>bestVal){
579  bestVal=val;
580  bestLabel=l;
581  }
582  ++c;
583  }
584  arg_[vi]=bestLabel;
585  }
586  inferenceDone_=true;
587 
588 
589  visitor.end(*this);
590  return NORMAL;
591  }
592 
// arg(arg, n): copies the stored labeling arg_ into 'arg' (listing line 595,
// which carries the qualified class name, is missing here).
// Returns UNKNOWN and leaves 'arg' untouched when n > 1. Otherwise resizes
// 'arg' to numVar_, copies arg_ into it and returns NORMAL.
// n == 0 takes the same path as n == 1.
593  template<class GM,class ACC>
594  inline InferenceTermination
596  ::arg(std::vector<LabelType>& arg, const size_t& n) const {
597  if(n > 1) {
598  return UNKNOWN;
599  }
600  else {
601  arg.resize(numVar_);
602  std::copy(arg_.begin(),arg_.end(),arg.begin());
603  return NORMAL;
604  }
605  }
606 
607 
608  } // namespace external
609 } // namespace opengm
610 
611 #endif // #ifndef OPENGM_EXTERNAL_AD3_HXX
612 
The OpenGM namespace.
Definition: config.hxx:43
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
Definition: ad3.hxx:596
const GraphicalModelType & graphicalModel() const
Definition: ad3.hxx:518
AD3Inf(const GraphicalModelType &gm, const Parameter para=Parameter())
void addFactor(VI_ITERATOR viBegin, VI_ITERATOR viEnd, const FUNCTION &function)
Definition: ad3.hxx:331
ValueType bound() const
return a bound on the solution
Definition: ad3.hxx:83
Parameter(const SolverType solverType=AD3_ILP, const double eta=0.1, const bool adaptEta=true, UInt64Type steps=1000, const double residualThreshold=1e-6, const int verbosity=0)
Definition: ad3.hxx:40
STL namespace.
ValueType valueToMaxSum(const ValueType val) const
Definition: ad3.hxx:93
detail_types::UInt64Type UInt64Type
uint64
Definition: config.hxx:300
const std::vector< double > & posteriors() const
Definition: ad3.hxx:125
GraphicalModelType::IndexType IndexType
Definition: inference.hxx:40
ValueType value() const
return the solution (value)
Definition: ad3.hxx:79
visitors::EmptyVisitor< AD3Inf< GM, ACC > > EmptyVisitorType
Definition: ad3.hxx:30
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
const std::vector< double > & higherOrderPosteriors() const
Definition: ad3.hxx:129
Inference algorithm interface.
Definition: inference.hxx:34
#define OPENGM_CHECK_OP(A, OP, B, TXT)
Definition: submodel2.hxx:24
std::string name() const
Definition: ad3.hxx:511
visitors::VerboseVisitor< AD3Inf< GM, ACC > > VerboseVisitorType
Definition: ad3.hxx:29
InferenceTermination infer()
Definition: ad3.hxx:525
#define OPENGM_CHECK(B, TXT)
Definition: submodel2.hxx:28
visitors::TimingVisitor< AD3Inf< GM, ACC > > TimingVisitorType
Definition: ad3.hxx:31
GraphicalModelType::LabelType LabelType
Definition: inference.hxx:39
ValueType valueFromMaxSum(const ValueType val) const
Definition: ad3.hxx:102
OpenGM runtime error.
Definition: opengm.hxx:100
InferenceTermination
Definition: inference.hxx:24