OpenGM  2.3.x
Discrete Graphical Model Library
grante.hxx
Go to the documentation of this file.
1 #ifndef GRANTE_HXX_
2 #define GRANTE_HXX_
3 
4 #include <sstream>
5 
11 
12 // grante includes
13 #include "FactorGraph.h"
14 #include "BruteForceExactInference.h"
15 #include "BeliefPropagation.h"
16 #include "DiffusionInference.h"
17 #include "SimulatedAnnealingInference.h"
18 
19 namespace opengm {
20  namespace external {
21 
27  // GRANTE
33  template<class GM>
34  class GRANTE : public Inference<GM, opengm::Minimizer> {
35  public:
36  typedef GM GraphicalModelType;
42 
44  struct Parameter {
49  // Used to define the threshold for stopping condition for Belief Propagation method
50  double tolerance_;
51  // Print iteration statistics for Belief Propagation method
52  bool verbose_;
53 
54  // Select MessageSchedule type for Belief Propagation method
55  Grante::BeliefPropagation::MessageSchedule BPSchedule_;
56 
57  // Number of simulated annealing distributions
58  unsigned int SASteps_;
59  // Initial Boltzmann temperature for simulated annealing.
60  double SAT0_;
61  // Final Boltzmann temperature for simulated annealing.
62  double SATfinal_;
63 
65  Parameter() : inferenceType_(BRUTEFORCE), numberOfIterations_(100),
66  tolerance_(1.0e-6), verbose_(false),
67  BPSchedule_(Grante::BeliefPropagation::Sequential), SASteps_(100),
68  SAT0_(10.0), SATfinal_(0.05) {
69  }
70  };
71  // construction
72  GRANTE(const GraphicalModelType& gm, const Parameter& para);
73  // destruction
74  ~GRANTE();
75  // query
76  std::string name() const;
77  const GraphicalModelType& graphicalModel() const;
78  // inference
79  template<class VISITOR>
80  InferenceTermination infer(VISITOR & visitor);
82  InferenceTermination arg(std::vector<LabelType>&, const size_t& = 1) const;
83  typename GM::ValueType bound() const;
84  typename GM::ValueType value() const;
85 
86  protected:
87  const GraphicalModelType& gm_;
91 
92  Grante::FactorGraphModel* granteModel_;
93  Grante::FactorGraph* granteGraph_;
94  Grante::InferenceMethod* granteInferenceMethod_;
95  std::vector<unsigned int> granteState_;
96  std::vector<Grante::FactorDataSource*> granteDataSourceCollector_;
97 
98  bool sanityCheck(ValueType value) const;
99 
100  void groupFactors(std::vector<std::vector<IndexType> >& groupedFactors) const;
101  void groupFactorTypes(const std::vector<std::vector<IndexType> >& groupedFactors, std::vector<std::vector<IndexType> >& groupedFactorTypes) const;
102 
103  template<class T, class OBJECT>
104  struct InsertFunctor {
105  void operator()(const T v) {
106  (*object_)[index_] = static_cast<double>(v);
107  index_++;
108  }
109  int index_;
110  OBJECT* object_;
111  };
112  };
113 
114  template<class GM>
116  : gm_(gm), parameter_(para), granteModel_(new Grante::FactorGraphModel()), granteGraph_(NULL),
117  granteInferenceMethod_(NULL) {
118 
119  // group factors
120  std::vector<std::vector<IndexType> > groupedFactors;
121  groupFactors(groupedFactors);
122 
123  // group grante factor types
124  std::vector<std::vector<IndexType> > groupedFactorTypes;
125  groupFactorTypes(groupedFactors, groupedFactorTypes);
126 
127  // add factor types
128  for(size_t i = 0; i < groupedFactorTypes.size(); i++) {
129  // create unique factor type name
130  std::stringstream ss;
131  ss << i;
132  std::string name = ss.str();
133 
134  // select representative factor
135  IndexType currentFactor = groupedFactors[groupedFactorTypes[i][0]][0];
136 
137  // set number of labels for each variable
138  std::vector<unsigned int> cardinalities;
139  for(IndexType j = 0; j < gm_[currentFactor].numberOfVariables(); j++) {
140  cardinalities.push_back(static_cast<unsigned int>(gm_.numberOfLabels(gm_[currentFactor].variableIndex(j))));
141  }
142 
143  // add factor type to model
144  granteModel_->AddFactorType(new Grante::FactorType(name, cardinalities, std::vector<double>()));
145  }
146 
147  // get number of labels for all variables
148  std::vector<unsigned int> cardinalities;
149  for(IndexType i = 0; i < gm_.numberOfVariables(); i++) {
150  cardinalities.push_back(gm_.numberOfLabels(i));
151  }
152  // create factor graph from model
153  granteGraph_ = new Grante::FactorGraph(granteModel_, cardinalities);
154 
155  // add factors to graph
156  for(size_t i = 0; i < groupedFactorTypes.size(); i++) {
157  // create unique factor type name
158  std::stringstream ss;
159  ss << i;
160  std::string name = ss.str();
161  // get factor type by name
162  Grante::FactorType* currentFactorType = granteModel_->FindFactorType(name);
163  // add all factors with same factor type
164  OPENGM_ASSERT(groupedFactorTypes[i].size() > 0);
165  for(size_t j = 0; j < groupedFactorTypes[i].size(); j++) {
166  OPENGM_ASSERT(groupedFactors[groupedFactorTypes[i][j]].size() > 0);
167  if(groupedFactors[groupedFactorTypes[i][j]].size() == 1) {
168  // single factor, no shared data
169  IndexType currentFactor = groupedFactors[groupedFactorTypes[i][j]][0];
170  // determine connected variables
171  std::vector<unsigned int> var_index;
172  for(IndexType k = 0; k < gm_[currentFactor].numberOfVariables(); k++) {
173  var_index.push_back(static_cast<unsigned int>(gm_[currentFactor].variableIndex(k)));
174  }
175  // copy data
176  std::vector<double> data(currentFactorType->ProdCardinalities());
177  ViewFunction<GM> function = gm_[currentFactor];
179  inserter.index_ = 0;
180  inserter.object_ = &data;
181  function.forAllValuesInOrder(inserter);
182 
183  // create factor
184  Grante::Factor* factor = new Grante::Factor(currentFactorType, var_index, data);
185  // add factor to graph (graph takes ownership)
186  granteGraph_->AddFactor(factor);
187  } else {
188  // multiple factors with shared data
189  IndexType currentFactor = groupedFactors[groupedFactorTypes[i][j]][0];
190  // create shared factor data
191  std::vector<double> data(currentFactorType->ProdCardinalities());
192  ViewFunction<GM> function = gm_[currentFactor];
194  inserter.index_ = 0;
195  inserter.object_ = &data;
196  function.forAllValuesInOrder(inserter);
197  Grante::FactorDataSource* currentDataSource = new Grante::FactorDataSource(data);
198  granteDataSourceCollector_.push_back(currentDataSource);
199  // add all factors with shared data
200  for(size_t k = 0; k < groupedFactors[groupedFactorTypes[i][j]].size(); k++) {
201  currentFactor = groupedFactors[groupedFactorTypes[i][j]][k];
202  // determine connected variables
203  std::vector<unsigned int> var_index;
204  for(IndexType l = 0; l < gm_[currentFactor].numberOfVariables(); l++) {
205  var_index.push_back(static_cast<unsigned int>(gm_[currentFactor].variableIndex(l)));
206  }
207  // create factor
208  Grante::Factor* factor = new Grante::Factor(currentFactorType, var_index, currentDataSource);
209  // add factor to graph (graph takes ownership)
210  granteGraph_->AddFactor(factor);
211  }
212  }
213  }
214  }
215 
216  // Perform forward map: update energies upon model change
217  granteGraph_->ForwardMap();
218 
219  // set inference method
220  switch(parameter_.inferenceType_) {
221  case Parameter::BRUTEFORCE : {
222  granteInferenceMethod_ = new Grante::BruteForceExactInference(granteGraph_);
223  break;
224  }
225  case Parameter::BP : {
226  granteInferenceMethod_ = new Grante::BeliefPropagation(granteGraph_, parameter_.BPSchedule_);
227  static_cast<Grante::BeliefPropagation*>(granteInferenceMethod_)->SetParameters(parameter_.verbose_, parameter_.numberOfIterations_, parameter_.tolerance_);
228  break;
229  }
230  case Parameter::DIFFUSION : {
231  granteInferenceMethod_ = new Grante::DiffusionInference(granteGraph_);
232  static_cast<Grante::DiffusionInference*>(granteInferenceMethod_)->SetParameters(parameter_.verbose_, parameter_.numberOfIterations_, parameter_.tolerance_);
233  break;
234  }
235  case Parameter::SA : {
236  granteInferenceMethod_ = new Grante::SimulatedAnnealingInference(granteGraph_, parameter_.verbose_);
237  static_cast<Grante::SimulatedAnnealingInference*>(granteInferenceMethod_)->SetParameters(parameter_.SASteps_, parameter_.SAT0_, parameter_.SATfinal_);
238  break;
239  }
240  default: {
241  throw(RuntimeError("Unknown inference type"));
242  }
243  }
244  // set initial value and lower bound
247  }
248 
249  template<class GM>
251  if(granteInferenceMethod_) {
252  delete granteInferenceMethod_;
253  }
254  for(size_t i = 0; i < granteDataSourceCollector_.size(); i++) {
255  delete granteDataSourceCollector_[i];
256  }
257  if(granteGraph_) {
258  delete granteGraph_;
259  }
260  if(granteModel_) {
261  delete granteModel_;
262  }
263  }
264 
265  template<class GM>
266  inline std::string GRANTE<GM>::name() const {
267  return "GRANTE";
268  }
269 
270  template<class GM>
272  return gm_;
273  }
274 
275  template<class GM>
277  EmptyVisitorType visitor;
278  return this->infer(visitor);
279  }
280 
281  template<class GM>
282  template<class VISITOR>
283  inline InferenceTermination GRANTE<GM>::infer(VISITOR & visitor) {
284  visitor.begin(*this);
285  value_ = granteInferenceMethod_->MinimizeEnergy(granteState_);
286  visitor.end(*this);
287  return NORMAL;
288  }
289 
290  template<class GM>
291  inline InferenceTermination GRANTE<GM>::arg(std::vector<LabelType>& arg, const size_t& n) const {
292  arg.resize(gm_.numberOfVariables());
293  for(IndexType i = 0; i < gm_.numberOfVariables(); i++) {
294  arg[i] = static_cast<LabelType>(granteState_[i]);
295  }
296  return NORMAL;
297  }
298 
299  template<class GM>
300  inline typename GM::ValueType GRANTE<GM>::bound() const {
301  return lowerBound_;
302  }
303 
304  template<class GM>
305  inline typename GM::ValueType GRANTE<GM>::value() const {
306  //sanity check
307  OPENGM_ASSERT(sanityCheck(value_));
308  return value_;
309  }
310 
311  template<class GM>
312  inline bool GRANTE<GM>::sanityCheck(ValueType value) const {
313  if(granteState_.size() > 0) {
314  std::vector<LabelType> result;
315  arg(result);
316  return fabs(value - gm_.evaluate(result)) < OPENGM_FLOAT_TOL;
317  } else {
318  ValueType temp;
319  AccumulationType::neutral(temp);
320  return value == temp;
321  }
322  }
323 
324  template<class GM>
325  inline void GRANTE<GM>::groupFactors(std::vector<std::vector<IndexType> >& groupedFactors) const {
326  // Factors are grouped by function index and the cardinalities of the connected variables.
327  groupedFactors.clear();
328  typedef std::map<std::pair<IndexType, std::vector<LabelType> >, size_t> Map;
329  Map lookupTable;
330  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
331  IndexType currentFunctionIndex = gm_[i].functionIndex();
332  std::vector<LabelType> currentCardinalities;
333  for(IndexType j = 0; j < gm_[i].numberOfVariables(); j++) {
334  currentCardinalities.push_back(gm_.numberOfLabels(gm_[i].variableIndex(j)));
335  }
336  std::pair<IndexType, std::vector<LabelType> > currentKey(currentFunctionIndex, currentCardinalities);
337  typename Map::const_iterator iter = lookupTable.find(currentKey);
338  if(iter != lookupTable.end()) {
339  groupedFactors[iter->second].push_back(i);
340  } else {
341  std::vector<IndexType> newVec(1, i);
342  groupedFactors.push_back(newVec);
343  lookupTable[currentKey] = groupedFactors.size() - 1;
344  }
345  }
346  }
347 
348  template<class GM>
349  inline void GRANTE<GM>::groupFactorTypes(const std::vector<std::vector<IndexType> >& groupedFactors, std::vector<std::vector<IndexType> >& groupedFactorTypes) const {
350  groupedFactorTypes.clear();
351  typedef std::map<std::vector<LabelType>, size_t > Map;
352  Map lookupTable;
353  for(IndexType i = 0; i < groupedFactors.size(); i++) {
354  IndexType currentNumberOfVariables = gm_[groupedFactors[i][0]].numberOfVariables();
355  std::vector<LabelType> currentCardinalities;
356  for(IndexType j = 0; j < currentNumberOfVariables; j++) {
357  currentCardinalities.push_back(gm_.numberOfLabels(gm_[groupedFactors[i][0]].variableIndex(j)));
358  }
359  typename Map::const_iterator iter = lookupTable.find(currentCardinalities);
360  if(iter != lookupTable.end()) {
361  groupedFactorTypes[iter->second].push_back(i);
362  } else {
363  std::vector<IndexType> newVec(1, i);
364  groupedFactorTypes.push_back(newVec);
365  lookupTable[currentCardinalities] = groupedFactorTypes.size() - 1;
366  }
367  }
368  }
369 
370  } // namespace external
371 } // namespace opengm
372 
373 #endif /* GRANTE_HXX_ */
#define OPENGM_FLOAT_TOL
The OpenGM namespace.
Definition: config.hxx:43
Grante::BeliefPropagation::MessageSchedule BPSchedule_
Definition: grante.hxx:55
visitors::TimingVisitor< GRANTE< GM > > TimingVisitorType
Definition: grante.hxx:41
Grante::InferenceMethod * granteInferenceMethod_
Definition: grante.hxx:94
void groupFactors(std::vector< std::vector< IndexType > > &groupedFactors) const
Definition: grante.hxx:325
GRANTE GRANTE inference algorithm class.
Definition: grante.hxx:34
bool sanityCheck(ValueType value) const
Definition: grante.hxx:312
visitors::VerboseVisitor< GRANTE< GM > > VerboseVisitorType
Definition: grante.hxx:39
const GraphicalModelType & graphicalModel() const
Definition: grante.hxx:271
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
std::vector< unsigned int > granteState_
Definition: grante.hxx:95
std::string name() const
Definition: grante.hxx:266
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
Definition: grante.hxx:291
Grante::FactorGraph * granteGraph_
Definition: grante.hxx:93
GraphicalModelType::IndexType IndexType
Definition: inference.hxx:40
reference to a Factor of a GraphicalModel
Definition: view.hxx:13
GM::ValueType value() const
return the solution (value)
Definition: grante.hxx:305
void groupFactorTypes(const std::vector< std::vector< IndexType > > &groupedFactors, std::vector< std::vector< IndexType > > &groupedFactorTypes) const
Definition: grante.hxx:349
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
static T ineutral()
inverse neutral element (with return)
Definition: minimizer.hxx:25
Inference algorithm interface.
Definition: inference.hxx:34
size_t numberOfIterations_
number of iterations for Belief Propagation method
Definition: grante.hxx:48
Grante::FactorGraphModel * granteModel_
Definition: grante.hxx:92
const GraphicalModelType & gm_
Definition: grante.hxx:87
static T neutral()
neutral element (with return)
Definition: minimizer.hxx:16
GRANTE(const GraphicalModelType &gm, const Parameter &para)
Definition: grante.hxx:115
GraphicalModelType::LabelType LabelType
Definition: inference.hxx:39
Minimization as a unary accumulation.
Definition: minimizer.hxx:12
InferenceTermination infer()
Definition: grante.hxx:276
std::vector< Grante::FactorDataSource * > granteDataSourceCollector_
Definition: grante.hxx:96
OpenGM runtime error.
Definition: opengm.hxx:100
opengm::Minimizer AccumulationType
Definition: grante.hxx:37
GM::ValueType bound() const
return a bound on the solution
Definition: grante.hxx:300
InferenceTermination
Definition: inference.hxx:24
visitors::EmptyVisitor< GRANTE< GM > > EmptyVisitorType
Definition: grante.hxx:40