OpenGM  2.3.x
Discrete Graphical Model Library
external/libdai/inference.hxx
Go to the documentation of this file.
1 #ifndef OPENGM_LIBDAI_HXX
2 #define OPENGM_LIBDAI_HXX
3 
4 #include <vector>
5 #include <string>
6 #include <iostream>
7 #include <typeinfo>
8 #include <cmath>
9 #include <sstream>
10 
11 #include "opengm/opengm.hxx"
24 
25 #include <dai/alldai.h>
26 #include <dai/exceptions.h>
27 
29 
30 namespace opengm{
31 namespace external{
32 namespace libdai{
33 
// Common wrapper around a libDAI inference algorithm for an OpenGM graphical model.
//
// Template parameters:
//   GM     - graphical model type; a const reference to the model is stored in gm_.
//   ACC    - accumulation type; together with GM::OperatorType it selects the value
//            transformation applied when converting to and from libDAI.
//   SOLVER - only used in this class to parameterise the visitor typedefs.
//
// NOTE(review): listing line 40 is absent from this extract. LabelType, ValueType and
// IndependentFactorType are used below without a visible declaration, so they
// presumably come from that line - confirm against the original header.
// NOTE(review): the class declares virtual members but its destructor is not virtual,
// and it owns factorGraph_ / ia_ through raw pointers without declaring any copy
// control - verify that instances are never copied or deleted through a base pointer.
34  template<class GM, class ACC ,class SOLVER>
35  class LibDaiInference
36  {
37  public:
38  typedef ACC AccumulationType;
39  typedef GM GraphicalModelType;
41  typedef opengm::visitors::VerboseVisitor< SOLVER > VerboseVisitorType;
42  typedef opengm::visitors::TimingVisitor< SOLVER > TimingVisitorType;
43  typedef opengm::visitors::EmptyVisitor< SOLVER > EmptyVisitorType;
44  ~LibDaiInference();
// Arguments: the graphical model and the libDAI algorithm description string
// (stored in stringAlgParam_ and passed to dai::newInfAlgFromString).
45  LibDaiInference(const GM & ,const std::string & );
46 
// Returns the graphical model reference stored in gm_.
47  virtual const GraphicalModelType& graphicalModel_impl() const;
// Rebuilds the factor graph and the inference algorithm from gm_ and stringAlgParam_.
48  virtual void reset_impl();
// Runs the libDAI algorithm.
49  virtual InferenceTermination infer_impl();
50  //template<class VISITOR>
51  //InferenceTermination infer(VISITOR&);
// Writes the result of ia_->findMaximum() into v.
52  virtual InferenceTermination arg_impl(std::vector<LabelType>& v, const size_t= 1)const;
// Belief of a single variable, mapped back to the OpenGM value domain.
53  virtual InferenceTermination marginal_impl(const size_t, IndependentFactorType&) const;
// Belief over the variables of a single factor, mapped back to the OpenGM value domain.
54  virtual InferenceTermination factorMarginal_impl(const size_t, IndependentFactorType&) const;
55  protected:
// Builds a heap-allocated libDAI factor graph from the model; the caller takes ownership.
56  ::dai::FactorGraph * convert(const GM &);
// libDAI factor graph produced by convert().
57  ::dai::FactorGraph * factorGraph_;
// libDAI inference algorithm created from stringAlgParam_ and *factorGraph_.
58  ::dai::InfAlg * ia_;
// Model being solved; held by reference, so it must outlive this object.
59  const GM & gm_;
// Algorithm description string handed to dai::newInfAlgFromString.
60  std::string stringAlgParam_;
// Initialised to 0 in the constructor; only read in an assertion in the visible code.
61  size_t numberOfExtraFactors_;
62  };
63 
// Releases the inference algorithm and then the factor graph held by this object.
// NOTE(review): listing line 65 (the line carrying the qualified function name) is
// absent from this extract; given the `~LibDaiInference();` declaration in the class
// this is presumably the destructor definition - confirm against the original header.
64  template<class GM, class ACC ,class SOLVER>
66  delete ia_;
67  delete factorGraph_;
68  }
69 
70  template<class GM, class ACC ,class SOLVER>
72  (
73  const GM & gm,
74  const std::string & string_param
75  ): gm_(gm),
76  stringAlgParam_(string_param),
77  numberOfExtraFactors_(0) {
78  factorGraph_=convert(gm_);
79  ia_=dai::newInfAlgFromString(stringAlgParam_,*factorGraph_);
80  ia_->init();
81  }
82 
// Returns the graphical model reference stored in gm_.
// NOTE(review): listing line 85 (the line carrying the qualified function name) is
// absent from this extract; the `const GM &` return type and the body match the
// graphicalModel_impl() declaration in the class - confirm against the original header.
83  template<class GM, class ACC ,class SOLVER>
84  inline const GM &
86  return gm_;
87  }
88 
89  template<class GM, class ACC ,class SOLVER>
90  inline void
92  delete ia_;
93  delete factorGraph_;
94  factorGraph_=convert(gm_);
95  ia_=dai::newInfAlgFromString(stringAlgParam_,*factorGraph_);
96  ia_->init();
97  };
98 
// Runs the libDAI inference algorithm held in ia_.
// Returns opengm::NORMAL once ia_->run() has completed. A dai::Exception is rethrown
// as ::opengm::RuntimeError whose text is assembled from the libDAI code message,
// the message and the detailed message.
// NOTE(review): listing line 101 (the qualified function name; presumably infer_impl,
// matching the class declaration) and listing line 112 (the body of catch(...)) are
// absent from this extract - confirm both against the original header.
// NOTE(review): unlike the handlers in the sibling members, the text built here has
// no " " separator between e.message(...) and e.getMsg().
99  template<class GM, class ACC ,class SOLVER>
100  inline InferenceTermination
102  try{
103  ia_->run();
104  return opengm::NORMAL;
105  }
106  catch(const dai::Exception & e) {
107  std::stringstream ss;
108  ss<<"libdai Error: "<<e.message(e.getCode())<<e.getMsg()<<"\n"<<e.getDetailedMsg();
109  throw ::opengm::RuntimeError(ss.str());
110  }
111  catch(...) {
113  }
114  return opengm::NORMAL;
115  }
116 
117 
118 
119 
// Fetches the libDAI belief of one variable and writes it, mapped back to the OpenGM
// value domain, into marginalFactor.
//
// variableIndex  - index of the variable in gm_ (asserted to be in range for both the
//                  model and the libDAI factor graph)
// marginalFactor - output; re-assigned to the single variable and filled per state
//
// Returns opengm::NORMAL on success and opengm::UNKNOWN when catch(...) fires.
// A dai::Exception is rethrown as ::opengm::RuntimeError.
// NOTE(review): the opengm::RuntimeError objects thrown inside the try block are not
// dai::Exception, so they are caught by catch(...) below and surface only as the
// console message plus opengm::UNKNOWN.
// NOTE(review): listing line 122 (the qualified function name; presumably
// marginal_impl, matching the class declaration) is absent from this extract.
120  template<class GM, class ACC ,class SOLVER>
121  inline InferenceTermination
123  (
124  const size_t variableIndex,
125  IndependentFactorType & marginalFactor
126  ) const{
127  try{
128  OPENGM_ASSERT(variableIndex<this->gm_.numberOfVariables());
129  OPENGM_ASSERT(variableIndex<this->factorGraph_->nrVars());
130  ::dai::Factor mf=this->ia_->belief(this->factorGraph_->var(variableIndex));
131  OPENGM_ASSERT(mf.nrStates()==gm_.numberOfLabels(variableIndex));
// One-element index range [varIndex, varIndex+1) used to shape the output factor.
132  const size_t varIndex[]={variableIndex};
133  marginalFactor.assign(gm_,varIndex,varIndex +1);
134  for(size_t i=0;i<mf.nrStates();++i) {
// This value is overwritten by whichever branch below is taken (or a throw follows).
135  marginalFactor(i)=mf.get(i);
// The branch conditions depend only on the template arguments, not on i.
136  if( opengm::meta::Compare<typename GM::OperatorType,opengm::Adder>::value && opengm::meta::Compare<ACC,opengm::Minimizer>::value) {
137  //back-transformation of f(x)=exp(-x);
138  marginalFactor(i)=static_cast<ValueType>(-1.0*std::log(mf.get(i)));
139  }
140  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Adder>::value && opengm::meta::Compare<ACC,opengm::Maximizer>::value) {
141  //back-transformation of f(x)=exp(x);
142  marginalFactor(i)=static_cast<ValueType>(std::log(mf.get(i)));
143  }
144  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Maximizer>::value) {
145  //back-transformation of f(x)=x;
146  marginalFactor(i)=static_cast<ValueType>(mf.get(i));
147  }
148  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Minimizer>::value) {
// 1/x is undefined for a zero belief, so this combination is rejected.
149  if(mf.get(i)==0.0) {
150  throw opengm::RuntimeError("zero marginal Values with OP=opengm::Multiplier with ACC=Minimizer are not supported in the opengm- libdai interface ");
151  }
152  //back-transformation of f(x)=1/x;
153  marginalFactor(i)=static_cast<ValueType>(1.0/mf.get(i));
154  }
155  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Integrator>::value) {
156  marginalFactor(i)=static_cast<ValueType>(mf.get(i));
157  }
158  else{
159  throw opengm::RuntimeError("OP/ACC not supported in the opengm-libdai interface ");
160  }
161  }
162  return opengm::NORMAL;
163  }
164  catch(const dai::Exception & e) {
165  std::stringstream ss;
166  ss<<"libdai Error: "<<e.message(e.getCode())<<" "<<e.getMsg()<<"\n"<<e.getDetailedMsg();
167  throw ::opengm::RuntimeError(ss.str());
168  }
169  catch(...) {
170  std::cout << "ERROR: Stopp with exception!"<<std::endl;
171  return opengm::UNKNOWN;
172  }
173  }
174  template<class GM, class ACC ,class SOLVER>
175  inline InferenceTermination
177  (
178  const size_t factorIndex,
179  IndependentFactorType & marginalFactor
180  ) const{
181  try{
182  OPENGM_ASSERT(factorIndex<this->gm_.numberOfFactors());
183  OPENGM_ASSERT(factorIndex<this->factorGraph_->nrFactors()-numberOfExtraFactors_);
184 
185  ::dai::VarSet varset;
186  for(size_t v=0;v<gm_[factorIndex].numberOfVariables();++v) {
187  varset.insert( ::dai::Var(gm_[factorIndex].variableIndex(v), gm_[factorIndex].numberOfLabels(v)) );
188  }
189  ::dai::Factor mf=this->ia_->belief(varset);
190  marginalFactor.assign(gm_,gm_[factorIndex].variableIndicesBegin(),gm_[factorIndex].variableIndicesEnd());
191  OPENGM_ASSERT(mf.nrStates()==marginalFactor.size());
192  for(size_t i=0;i<mf.nrStates();++i) {
193  marginalFactor(i)=mf.get(i);
194  if( opengm::meta::Compare<typename GM::OperatorType,opengm::Adder>::value && opengm::meta::Compare<ACC,opengm::Minimizer>::value) {
195  //back-transformation of f(x)=exp(-x);
196  marginalFactor(i)=static_cast<ValueType>(-1.0*std::log(mf.get(i)));
197  }
198  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Adder>::value && opengm::meta::Compare<ACC,opengm::Maximizer>::value) {
199  //back-transformation of f(x)=exp(x);
200  marginalFactor(i)=static_cast<ValueType>(std::log(mf.get(i)));
201  }
202  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Maximizer>::value) {
203  //back-transformation of f(x)=x;
204  marginalFactor(i)=static_cast<ValueType>(mf.get(i));
205  }
206  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Integrator>::value) {
207  //back-transformation of f(x)=x;
208  marginalFactor(i)=static_cast<ValueType>(mf.get(i));
209  }
210  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Minimizer>::value) {
211  if(mf.get(i)==0.0) {
212  throw opengm::RuntimeError("zero marginal Values with OP=opengm::Multiplier with ACC=Minimizer are not supported in the opengm- libdai interface ");
213  }
214  //back-transformation of f(x)=1/x;
215  marginalFactor(i)=static_cast<ValueType>(1.0/mf.get(i));
216  }
217  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value && opengm::meta::Compare<ACC,opengm::Integrator>::value) {
218  marginalFactor(i)=static_cast<ValueType>(mf.get(i));
219  }
220  else{
221  throw opengm::RuntimeError("OP/ACC not supported in the opengm-libdai interface ");
222  }
223  }
224  return opengm::NORMAL;
225  }
226  catch(const dai::Exception & e) {
227  std::stringstream ss;
228  ss<<"libdai Error: "<<e.message(e.getCode())<<" "<<e.getMsg()<<"\n"<<e.getDetailedMsg();
229  throw ::opengm::RuntimeError(ss.str());
230  }
231  catch(...) {
232  std::cout << "ERROR: Stopp with exception!"<<std::endl;
233  return opengm::UNKNOWN;
234  }
235  }
236 
// Copies the state vector returned by ia_->findMaximum() into v.
//
// v - output labelling; its previous content is replaced
// n - not used in the visible body
//
// Returns opengm::NORMAL on success. A dai::Exception is rethrown as
// ::opengm::RuntimeError.
// NOTE(review): listing line 239 (the qualified function name; presumably arg_impl,
// matching the class declaration) and listing line 256 (the body of catch(...)) are
// absent from this extract - confirm both against the original header.
237  template<class GM, class ACC ,class SOLVER>
238  inline InferenceTermination
240  (
241  std::vector<typename LibDaiInference<GM,ACC,SOLVER>::LabelType>& v,
242  const size_t n
243  )const{
244  //std::cout <<"LIBDAI ARG"<<std::endl;
245  try{
246  std::vector<size_t> states=ia_->findMaximum();
247  v.assign(states.begin(),states.end());
248  return opengm::NORMAL;
249  }
250  catch(const dai::Exception & e) {
251  std::stringstream ss;
252  ss<<"libdai Error: "<<e.message(e.getCode())<<" "<<e.getMsg()<<"\n"<<e.getDetailedMsg();
253  throw ::opengm::RuntimeError(ss.str());
254  }
255  catch(...) {
257  }
258  return opengm::NORMAL;
259  }
260 
// Builds a heap-allocated libDAI factor graph that mirrors the OpenGM model gm.
//
// One dai::Var is created per model variable and one dai::Factor per model factor.
// Factor values are transformed according to GM::OperatorType and ACC:
//   Adder / Minimizer                    -> exp(-x)
//   Adder / Maximizer                    -> exp(x)
//   Multiplier / Maximizer or Integrator -> x
//   Multiplier / Minimizer               -> 1/x (a zero value throws)
// Any other combination throws opengm::RuntimeError.
//
// Returns a pointer obtained with `new`; the caller is responsible for deleting it.
//
// NOTE(review): factorData and varSet are raw new[] buffers that are only released at
// the end of the function; the throw statements inside the loop (and any exception
// from the libDAI constructors) skip those delete[] calls.
// NOTE(review): listing lines 302 and 307-308 are absent from this extract, so the
// declared type of viewToFactorData and its trailing constructor arguments are not
// visible here - presumably a view over factorData shaped like the factor; confirm
// against the original header.
261  template<class GM, class ACC ,class SOLVER>
262  ::dai::FactorGraph * LibDaiInference<GM,ACC,SOLVER>::convert
263  (
264  const GM & gm
265  ) {
266  const size_t nrOfFactors=gm.numberOfFactors();
267  const size_t nrOfVariables=gm.numberOfVariables();
268  typedef typename GM::ValueType ValueType;
269  typedef double DaiValueType;
270 
271  std::vector< ::dai::Factor > factors;
272  factors.reserve(nrOfFactors);
// libDAI variable i carries label i and as many states as model variable i has labels.
273  std::vector<dai::Var> vars(nrOfVariables);
274  for (size_t i = 0; i < nrOfVariables; ++i) {
275  vars[i] = ::dai::Var(i, gm.numberOfLabels(i));
276  }
// First pass: find the largest factor table and the largest arity so that the two
// scratch buffers below can be allocated once and reused for every factor.
277  size_t maxFactorSize=0;
278  size_t maxNumberOfVariables=0;
279  for(size_t f=0;f<nrOfFactors;++f) {
280  const size_t factorSize=gm[f].size();
281  const size_t numberOfVariables=gm[f].numberOfVariables();
282  if(factorSize>maxFactorSize) maxFactorSize=factorSize;
283  if(numberOfVariables>maxNumberOfVariables) maxNumberOfVariables=numberOfVariables;
284  }
285  //buffer array for factor values
286  DaiValueType * factorData= new DaiValueType[maxFactorSize];
287  //buffer array for variables of a factor
288  ::dai::Var * varSet = new ::dai::Var[maxNumberOfVariables];
// Second pass: convert every factor.
289  for(size_t f=0;f<nrOfFactors;++f) {
290  //factor information
291  const size_t factorSize=gm[f].size();
292  const size_t numberOfVariables=gm[f].numberOfVariables();
// A factor without variables only triggers this console warning; it is still converted.
293  if(numberOfVariables==0) {
294  std::cout<<"\n\n WARNING \n\n";
295  }
296  //copy the variables of a factor into the varset
297  for(size_t v=0;v<numberOfVariables;++v) {
298  varSet[v]=vars[gm[f].variableIndex(v)];
299  }
300  dai::VarSet varset(varSet, varSet + numberOfVariables);
301  //marray view to the data for easy access
303  viewToFactorData(
304  gm[f].shapeBegin(),
305  gm[f].shapeEnd(),
306  factorData,
309  );
310  //fill factorData array with the data from the opengm factors
// The walker enumerates the coordinate tuples of the factor; entry i of the view
// receives the transformed value at the walker's current coordinate.
// NOTE(review): assumes the walker order matches the linear index order used by the
// view and by dai::Factor - TODO confirm.
311  opengm::ShapeWalker<typename GM::FactorType::ShapeIteratorType> walker(gm[f].shapeBegin(),numberOfVariables);
312  for(size_t i=0;i<factorSize;++i) {
313  //viewToFactorData(walker.coordinateTuple().begin())=
314  if( opengm::meta::Compare<typename GM::OperatorType,opengm::Adder>::value &&
315  opengm::meta::Compare<ACC,opengm::Minimizer>::value) {
316  viewToFactorData(i)=std::exp(-1.0*
317  static_cast<DaiValueType>(gm[f](walker.coordinateTuple().begin())));
318  }
319  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Adder>::value &&
320  opengm::meta::Compare<ACC,opengm::Maximizer>::value) {
321  viewToFactorData(i)=std::exp(1.0*
322  static_cast<DaiValueType>(gm[f](walker.coordinateTuple().begin())));
323  }
324  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value &&
325  opengm::meta::Compare<ACC,opengm::Maximizer>::value) {
326  viewToFactorData(i)=
327  static_cast<DaiValueType>(gm[f](walker.coordinateTuple().begin()));
328  }
329  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value &&
330  opengm::meta::Compare<ACC,opengm::Integrator>::value) {
331  viewToFactorData(i)=
332  static_cast<DaiValueType>(gm[f](walker.coordinateTuple().begin()));
333  }
334  else if( opengm::meta::Compare<typename GM::OperatorType,opengm::Multiplier>::value &&
335  opengm::meta::Compare<ACC,opengm::Minimizer>::value) {
// 1/x is undefined for a zero value, so it is rejected before the division.
336  if(gm[f](walker.coordinateTuple().begin())==static_cast<ValueType>(0.0)) {
337  throw opengm::RuntimeError("zero Values with OP=opengm::Multiplier with ACC=Minimizer are not supported in the opengm- libdai interface ");
338  }
339  viewToFactorData(i)=static_cast<DaiValueType>(1.0)/
340  static_cast<DaiValueType>(gm[f](walker.coordinateTuple().begin()));
341  }
342  else {
343  throw opengm::RuntimeError("only build in OpenGM Operators and Accumulators are supported in the opengm- libdai interface ");
344  }
345  ++walker;
346  }
347  //add factor to the factor vector
348  dai::Factor factor(varset, factorData);
349  OPENGM_ASSERT(factor.nrStates()==gm[f].size());
350  factors.push_back(factor);
351  }
352  dai::FactorGraph * factorGraph = new dai::FactorGraph(factors.begin(), factors.end(), vars.begin(), vars.end());
// The scratch buffers are released only after the factor graph has been constructed.
353  delete [] factorData;
354  delete [] varSet;
355  OPENGM_ASSERT(factorGraph->nrFactors()==gm.numberOfFactors());
356  OPENGM_ASSERT(factorGraph->nrVars()==gm.numberOfVariables());
357  return factorGraph;
358  }
359 } // end namespace libdai
360 } // end namespace external
361 } //end namespace opengm
362 
364 
365 #endif // OPENGM_LIBDAI_HXX
The OpenGM namespace.
Definition: config.hxx:43
Array-Interface to an interval of memory.
Definition: marray.hxx:44
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
#define OPENGM_GM_TYPE_TYPEDEFS
Definition: inference.hxx:13
OpenGM runtime error.
Definition: opengm.hxx:100
InferenceTermination
Definition: inference.hxx:24