OpenGM  2.3.x
Discrete Graphical Model Library
messagepassing.hxx
Go to the documentation of this file.
1 #pragma once
2 #ifndef OPENGM_MESSAGE_PASSING_HXX
3 #define OPENGM_MESSAGE_PASSING_HXX
4 
5 #include <vector>
6 #include <map>
7 #include <list>
8 #include <set>
9 
10 #include "opengm/opengm.hxx"
19 
20 namespace opengm {
21 
24 struct MaxDistance {
28  template<class M>
29  static typename M::ValueType
30  op(const M& in1, const M& in2)
31  {
32  typedef typename M::ValueType ValueType;
33  ValueType v1,v2,d1,d2;
34  Maximizer::neutral(v1);
36  for(size_t n=0; n<in1.size(); ++n) {
37  d1=in1(n)-in2(n);
38  d2=-d1;
39  Maximizer::op(d1,v1);
40  Maximizer::op(d2,v2);
41  }
42  Maximizer::op(v2,v1);
43  return v1;
44  }
45 };
46 
/// \brief Generic message passing inference.
///
/// \tparam GM           graphical model type
/// \tparam ACC          accumulation operation (passed on to Inference<GM, ACC>)
/// \tparam UPDATE_RULES supplies FactorHullType, VariableHullType and
///                      SpecialParameterType: the per-factor / per-variable
///                      objects that store and send the messages
/// \tparam DIST         distance between successive messages used by the
///                      convergence measures (default: MaxDistance)
49 template<class GM, class ACC, class UPDATE_RULES, class DIST=opengm::MaxDistance>
50 class MessagePassing : public Inference<GM, ACC> {
51 public:
52  typedef GM GraphicalModelType;
53  typedef ACC Accumulation;
54  typedef ACC AccumulatorType;
56  typedef DIST Distance;
57  typedef typename UPDATE_RULES::FactorHullType FactorHullType;
58  typedef typename UPDATE_RULES::VariableHullType VariableHullType;
59
66
   /// Options of the message passing schedules.
67  struct Parameter {
68  typedef typename UPDATE_RULES::SpecialParameterType SpecialParameterType;
   /// \param maximumNumberOfSteps iteration limit of the parallel and sequential schedules
   /// \param bound                those schedules stop once convergence() is below this value
   /// \param damping              damping handed to the message updates of those schedules
   /// \param specialParameter     update-rule specific options; the hulls receive its address
   /// \param isAcyclic            True / False select the schedule directly; Maybe lets
   ///                             infer() ask the graphical model
69  Parameter
70  (
71  const size_t maximumNumberOfSteps = 100,
72  const ValueType bound = static_cast<ValueType> (0.000000),
73  const ValueType damping = static_cast<ValueType> (0),
74  const SpecialParameterType & specialParameter =SpecialParameterType(),
75  const opengm::Tribool isAcyclic = opengm::Tribool::Maybe
76  )
77  : maximumNumberOfSteps_(maximumNumberOfSteps),
78  bound_(bound),
79  damping_(damping),
80  inferSequential_(false), // parallel schedule unless the caller enables the sequential one
81  useNormalization_(true),
82  specialParameter_(specialParameter),
83  isAcyclic_(isAcyclic)
84  {}
85
   // Visiting order of the variables in the sequential schedule. If it is left
   // empty, the constructor (and reset()) fill it with 0, 1, ..., numberOfVariables-1.
90  std::vector<size_t> sortedNodeList_;
92  //bool useNormalization_;
93  SpecialParameterType specialParameter_;
95  };
96
   /// A message, identified by its sending node (a variable or a factor) and by
   /// the position of the receiving neighbour within that node's neighbourhood.
98  struct Message {
99  Message()
100  : nodeId_(-1), // -1 converts to the largest size_t value (presumably an "unset" marker)
101  internalMessageId_(-1)
102  {}
103  Message(const size_t nodeId, const size_t & internalMessageId)
104  : nodeId_(nodeId),
105  internalMessageId_(internalMessageId)
106  {}
107
108  size_t nodeId_;
109  size_t internalMessageId_;
110  };
112
   /// \param gm model to run on; only a reference is stored, so it must outlive this object
113  MessagePassing(const GraphicalModelType&, const Parameter& = Parameter());
114  std::string name() const; // returns "MP"
115  const GraphicalModelType& graphicalModel() const;
   /// marginal of a single variable
116  InferenceTermination marginal(const size_t, IndependentFactorType& out) const;
   /// marginal over the variables of a single factor
117  InferenceTermination factorMarginal(const size_t, IndependentFactorType & out) const;
   /// convergence measures: largest DIST distance between the previous and the
   /// current messages held by the hulls
118  ValueType convergenceXF() const;
119  ValueType convergenceFX() const;
120  ValueType convergence() const;
   /// rebuild all hulls (same initialization as the constructor)
121  virtual void reset();
   /// run inference; the schedule is chosen from parameter_.isAcyclic_
123  template<class VisitorType>
124  InferenceTermination infer(VisitorType&);
   /// one round of message updates: all variable hulls, then all factor hulls
125  void propagate(const ValueType& = 0);
   /// labeling derived from the factor marginals; only N == 1 is supported
126  InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
127  void setMaxSteps(size_t maxSteps) {parameter_.maximumNumberOfSteps_ = maxSteps;}
128  //InferenceTermination bound(ValueType&) const;
129  //ValueType bound() const;
130
   // message schedules; the overloads without a visitor use an EmptyVisitorType
131 private:
132  void inferAcyclic();
133  void inferParallel();
134  void inferSequential();
135  template<class VisitorType>
136  void inferParallel(VisitorType&);
137  template<class VisitorType>
138  void inferAcyclic(VisitorType&);
139  template<class VisitorType>
140  void inferSequential(VisitorType&);
141 private:
   // held by reference: the graphical model must outlive this object
142  const GraphicalModelType& gm_;
143  Parameter parameter_;
   // one hull per factor / per variable, (re)built by the constructor and reset()
144  std::vector<FactorHullType> factorHulls_;
145  std::vector<VariableHullType> variableHulls_;
146 };
147 
/// Constructor: copies the parameters and builds one variable hull per variable
/// and one factor hull per factor. This is the same initialization that reset()
/// performs.
148 template<class GM, class ACC, class UPDATE_RULES, class DIST>
150 (
151  const GraphicalModelType& gm,
153 )
154 : gm_(gm),
155  parameter_(parameter)
156 {
   // default visiting order for the sequential schedule: 0, 1, ..., numberOfVariables-1
157  if(parameter_.sortedNodeList_.size() == 0) {
158  parameter_.sortedNodeList_.resize(gm.numberOfVariables());
159  for (size_t i = 0; i < gm.numberOfVariables(); ++i)
160  parameter_.sortedNodeList_[i] = i;
161  }
162  OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm.numberOfVariables());
163
164  UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
165
166  // set hulls
   // NOTE(review): every hull receives &parameter_.specialParameter_, i.e. the
   // address of a member of this object. If the hulls keep that pointer, a copy
   // of this MessagePassing would point into the original. Confirm it is never copied.
   // Factor hulls are assigned after all variable hulls because they receive variableHulls_.
167  variableHulls_.resize(gm.numberOfVariables(), VariableHullType ());
168  for (size_t i = 0; i < gm.numberOfVariables(); ++i) {
169  variableHulls_[i].assign(gm, i, &parameter_.specialParameter_);
170  }
171  factorHulls_.resize(gm.numberOfFactors(), FactorHullType ());
172  for (size_t i = 0; i < gm.numberOfFactors(); i++) {
173  factorHulls_[i].assign(gm, i, variableHulls_, &parameter_.specialParameter_);
174  }
175 }
176 
/// Re-initialize the engine.
///
/// Fills sortedNodeList_ with the identity order if it is empty, re-initializes
/// the update-rule specific parameter, and re-assigns every variable and factor
/// hull, exactly as the constructor does.
177 template<class GM, class ACC, class UPDATE_RULES, class DIST>
178 void
180 {
181  if(parameter_.sortedNodeList_.size() == 0) {
182  parameter_.sortedNodeList_.resize(gm_.numberOfVariables());
183  for (size_t i = 0; i < gm_.numberOfVariables(); ++i)
184  parameter_.sortedNodeList_[i] = i;
185  }
186  OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
187  UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
188
189  // set hulls
   // resize() keeps existing hulls; each one is re-assigned by the loops below
190  variableHulls_.resize(gm_.numberOfVariables(), VariableHullType ());
191  for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
192  variableHulls_[i].assign(gm_, i, &parameter_.specialParameter_);
193  }
194  factorHulls_.resize(gm_.numberOfFactors(), FactorHullType ());
195  for (size_t i = 0; i < gm_.numberOfFactors(); i++) {
196  factorHulls_[i].assign(gm_, i, variableHulls_, &parameter_.specialParameter_);
197  }
198 }
199 
/// \return the short name of this inference algorithm, "MP"
200 template<class GM, class ACC, class UPDATE_RULES, class DIST>
201 inline std::string
203  return "MP";
204 }
205 
/// \return the graphical model passed to the constructor (held by reference)
206 template<class GM, class ACC, class UPDATE_RULES, class DIST>
209  return gm_;
210 }
211 
/// Run inference without observation: forwards to infer(VisitorType&) with an
/// EmptyVisitorType.
212 template<class GM, class ACC, class UPDATE_RULES, class DIST>
215  EmptyVisitorType v;
216  return infer(v);
217 }
218 
/// Run inference and pick the schedule from parameter_.isAcyclic_:
///  - True:  inferAcyclic
///  - False: inferSequential if parameter_.inferSequential_ is set, otherwise inferParallel
///  - Maybe: gm_.isAcyclic() decides, and the answer is stored back into
///           parameter_.isAcyclic_ so later calls skip the check
/// Before the acyclic schedule, a useNormalization_ equal to Tribool::Maybe is
/// replaced by false.
/// \return always NORMAL
219 template<class GM, class ACC, class UPDATE_RULES, class DIST>
220 template<class VisitorType>
223 (
224  VisitorType& visitor
225 ) {
226  if (parameter_.isAcyclic_ == opengm::Tribool::True) {
227  if(parameter_.useNormalization_==opengm::Tribool::Maybe)
228  parameter_.useNormalization_=false;
229  inferAcyclic(visitor);
230  } else if (parameter_.isAcyclic_ == opengm::Tribool::False) {
231  if (parameter_.inferSequential_) {
232  inferSequential(visitor);
233  } else {
234  inferParallel(visitor);
235  }
236  } else { // tribool Maybe: ask the model and remember the answer
237  if (gm_.isAcyclic()) {
238  parameter_.isAcyclic_ = opengm::Tribool::True;
239  if(parameter_.useNormalization_==opengm::Tribool::Maybe)
240  parameter_.useNormalization_=false;
241  inferAcyclic(visitor);
242  } else {
243  parameter_.isAcyclic_ = opengm::Tribool::False;
244  if (parameter_.inferSequential_) {
245  inferSequential(visitor);
246  } else {
247  inferParallel(visitor);
248  }
249  }
250  }
251  return NORMAL;
252 }
253 
/// Acyclic schedule without observation: forwards to inferAcyclic(VisitorType&)
/// with an EmptyVisitorType.
259 template<class GM, class ACC, class UPDATE_RULES, class DIST>
260 inline void
262  EmptyVisitorType v;
263  return inferAcyclic(v);
264 }
265 
267 //
/// Message schedule for acyclic graphical models.
///
/// Every message has a counter of the incoming messages it still waits for:
/// all messages into the sending node except the one coming from the receiver.
/// A message whose counter reaches zero is queued. Sending it decrements the
/// counters of the messages it enables. The variable-to-factor and
/// factor-to-variable queues are drained alternately until both are empty.
/// All updates use damping 0. The visitor is called after every round, and a
/// non-zero return value stops the schedule.
273 template<class GM, class ACC, class UPDATE_RULES, class DIST>
274 template<class VisitorType>
275 void
276 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferAcyclic
277 (
278  VisitorType& visitor
279 )
280 {
281  OPENGM_ASSERT(gm_.isAcyclic());
282  visitor.begin(*this);
283  size_t numberOfVariables = gm_.numberOfVariables();
284  size_t numberOfFactors = gm_.numberOfFactors();
285  // number of messages which have not yet been received
286  // but are required for sending
   // counterVar2FacMessage[var][i]: message from var to its i-th factor
   // counterFac2VarMessage[fac][i]: message from fac to its i-th variable
287  std::vector<std::vector<size_t> > counterVar2FacMessage(numberOfVariables);
288  std::vector<std::vector<size_t> > counterFac2VarMessage(numberOfFactors);
289  // list of messages which are ready to send
290  std::vector<Message> ready2SendVar2FacMessage;
291  std::vector<Message> ready2SendFac2VarMessage;
292  ready2SendVar2FacMessage.reserve(100);
293  ready2SendFac2VarMessage.reserve(100);
   // each message waits for all other neighbours of its sender: degree - 1
294  for (size_t fac = 0; fac < numberOfFactors; ++fac) {
295  counterFac2VarMessage[fac].resize(gm_[fac].numberOfVariables(), gm_[fac].numberOfVariables() - 1);
296  }
297  for (size_t var = 0; var < numberOfVariables; ++var) {
298  counterVar2FacMessage[var].resize(gm_.numberOfFactors(var));
299  for (size_t i = 0; i < gm_.numberOfFactors(var); ++i) {
300  counterVar2FacMessage[var][i] = gm_.numberOfFactors(var) - 1;
301  }
302  }
303  // find all messages which are ready for sending
   // (senders of degree 1). Decrementing such a counter past zero (it wraps
   // around) marks the message as already queued.
304  for (size_t var = 0; var < numberOfVariables; ++var) {
305  for (size_t i = 0; i < counterVar2FacMessage[var].size(); ++i) {
306  if (counterVar2FacMessage[var][i] == 0) {
307  --counterVar2FacMessage[var][i];
308  ready2SendVar2FacMessage.push_back(Message(var, i));
309  }
310  }
311  }
312  for (size_t fac = 0; fac < numberOfFactors; ++fac) {
313  for (size_t i = 0; i < counterFac2VarMessage[fac].size(); ++i) {
314  if (counterFac2VarMessage[fac][i] == 0) {
315  --counterFac2VarMessage[fac][i];
316  ready2SendFac2VarMessage.push_back(Message(fac, i));
317  }
318  }
319  }
320  // send messages
   // one round = drain the variable->factor queue, then the factor->variable queue
321  while (ready2SendVar2FacMessage.size() > 0 || ready2SendFac2VarMessage.size() > 0) {
322  while (ready2SendVar2FacMessage.size() > 0) {
323  Message m = ready2SendVar2FacMessage.back();
324  size_t nodeId = m.nodeId_;
325  size_t factorId = gm_.factorOfVariable(nodeId,m.internalMessageId_);
326  // send message
   // damping 0, no normalization
327  variableHulls_[nodeId].propagate(gm_, m.internalMessageId_, 0, false);
328  ready2SendVar2FacMessage.pop_back();
329  //check if new messages can be sent
   // every message from factorId to a variable other than nodeId now waits for one fewer input
330  for (size_t i = 0; i < gm_[factorId].numberOfVariables(); ++i) {
331  if (gm_[factorId].variableIndex(i) != nodeId) {
332  if (--counterFac2VarMessage[factorId][i] == 0) {
333  ready2SendFac2VarMessage.push_back(Message(factorId, i));
334  }
335  }
336  }
337  }
338  while (ready2SendFac2VarMessage.size() > 0) {
339  Message m = ready2SendFac2VarMessage.back();
340  size_t factorId = m.nodeId_;
341  size_t nodeId = gm_[factorId].variableIndex(m.internalMessageId_);
342  // send message
   // damping 0, normalization according to parameter_.useNormalization_
343  factorHulls_[factorId].propagate(m.internalMessageId_, 0, parameter_.useNormalization_);
344  ready2SendFac2VarMessage.pop_back();
345  // check if new messages can be sent
   // every message from nodeId to a factor other than factorId now waits for one fewer input
346  for (size_t i = 0; i < gm_.numberOfFactors(nodeId); ++i) {
347  if (gm_.factorOfVariable(nodeId,i) != factorId) {
348  if (--counterVar2FacMessage[nodeId][i] == 0) {
349  ready2SendVar2FacMessage.push_back(Message(nodeId, i));
350  }
351  }
352  }
353  }
354  if(visitor(*this)!=0)
355  break;
356  }
357  visitor.end(*this);
358
359 }
360 
362 template<class GM, class ACC, class UPDATE_RULES, class DIST>
364 (
365  const ValueType& damping
366 ) {
367  for (size_t i = 0; i < variableHulls_.size(); ++i) {
368  variableHulls_[i].propagateAll(damping, false);
369  }
370  for (size_t i = 0; i < factorHulls_.size(); ++i) {
371  factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
372  }
373 }
374 
/// Parallel schedule without observation: forwards to inferParallel(VisitorType&)
/// with an EmptyVisitorType.
376 template<class GM, class ACC, class UPDATE_RULES, class DIST>
378  EmptyVisitorType v;
379  return inferParallel(v);
380 }
381 
384 template<class GM, class ACC, class UPDATE_RULES, class DIST>
385 template<class VisitorType>
386 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferParallel
387 (
388  VisitorType& visitor
389 )
390 {
391  ValueType c = 0;
392  ValueType damping = parameter_.damping_;
393  visitor.begin(*this);
394 
395  // let all Factors with a order lower than 2 sending their Message
396  for (size_t i = 0; i < factorHulls_.size(); ++i) {
397  if (factorHulls_[i].numberOfBuffers() < 2) {
398  factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
399  factorHulls_[i].propagateAll(0, parameter_.useNormalization_); // 2 times to fill both buffers
400  }
401  }
402  for (unsigned long n = 0; n < parameter_.maximumNumberOfSteps_; ++n) {
403  for (size_t i = 0; i < variableHulls_.size(); ++i) {
404  variableHulls_[i].propagateAll(gm_, damping, false);
405  }
406  for (size_t i = 0; i < factorHulls_.size(); ++i) {
407  if (factorHulls_[i].numberOfBuffers() >= 2)// messages from factors of order <2 do not change
408  factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
409  }
410  if(visitor(*this)!=0)
411  break;
412  c = convergence();
413  if (c < parameter_.bound_) {
414  break;
415  }
416  }
417  visitor.end(*this);
418 
419 }
420 
429 template<class GM, class ACC, class UPDATE_RULES, class DIST>
430 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferSequential() {
431  EmptyVisitorType v;
432  return inferSequential(v);
433 }
434 
445 template<class GM, class ACC, class UPDATE_RULES, class DIST>
446 template<class VisitorType>
447 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferSequential
448 (
449  VisitorType& visitor
450 ) {
451  OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
452  visitor.begin(*this);
453  ValueType damping = parameter_.damping_;
454 
455  // set nodeOrder
456  std::vector<size_t> nodeOrder(gm_.numberOfVariables());
457  for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
458  nodeOrder[parameter_.sortedNodeList_[o]] = o;
459  }
460 
461  // let all Factors with a order lower than 2 sending their Message
462  for (size_t f = 0; f < factorHulls_.size(); ++f) {
463  if (factorHulls_[f].numberOfBuffers() < 2) {
464  factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
465  factorHulls_[f].propagateAll(0, parameter_.useNormalization_); //2 times to fill both buffers
466  }
467  }
468 
469  // calculate inverse positions
470  std::vector<std::vector<size_t> > inversePositions(gm_.numberOfVariables());
471  for(size_t var=0; var<gm_.numberOfVariables();++var) {
472  for(size_t i=0; i<gm_.numberOfFactors(var); ++i) {
473  size_t factorId = gm_.factorOfVariable(var,i);
474  for(size_t j=0; j<gm_.numberOfVariables(factorId);++j) {
475  if(gm_.variableOfFactor(factorId,j)==var) {
476  inversePositions[var].push_back(j);
477  break;
478  }
479  }
480  }
481  }
482 
483 
484  // the following Code is not optimized and maybe too slow for small factors
485  for (unsigned long itteration = 0; itteration < parameter_.maximumNumberOfSteps_; ++itteration) {
486  if(itteration%2==0) {
487  // in increasing ordering
488  for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
489  size_t variableId = parameter_.sortedNodeList_[o];
490  // update messages to the variable node
491  for(size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
492  size_t factorId = gm_.factorOfVariable(variableId,i);
493  factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
494  }
495 
496  // update messages from the variable node
497  variableHulls_[variableId].propagateAll(gm_, damping, false);
498  }
499  }
500  else{
501  // in decreasing ordering
502  for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
503  size_t variableId = parameter_.sortedNodeList_[gm_.numberOfVariables() - 1 - o];
504  // update messages to the variable node
505  for(size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
506  size_t factorId = gm_.factorOfVariable(variableId,i);
507  factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
508  }
509  // update messages from Variable
510  variableHulls_[variableId].propagateAll(gm_, damping, false);
511  }
512  }
513  if(visitor(*this)!=0)
514  break;
515  ValueType c = convergence();
516  if (c < parameter_.bound_) {
517  break;
518  }
519 
520  }
521  visitor.end(*this);
522 }
523 
/// Marginal of a single variable, computed by that variable's hull.
/// parameter_.useNormalization_ is passed on to the hull.
/// \return always NORMAL (the index is only checked by OPENGM_ASSERT)
524 template<class GM, class ACC, class UPDATE_RULES, class DIST>
527 (
528  const size_t variableIndex,
530 ) const {
531  OPENGM_ASSERT(variableIndex < variableHulls_.size());
532  variableHulls_[variableIndex].marginal(gm_, variableIndex, out, parameter_.useNormalization_);
533  return NORMAL;
534 }
535 
/// Marginal over the variables of a single factor.
/// out is first re-assigned to a function over the factor's variables, filled
/// with the neutral element of the model's operator. The factor hull then
/// writes the marginal into it (parameter_.useNormalization_ is passed on).
/// \return always NORMAL (the index is only checked by OPENGM_ASSERT)
536 template<class GM, class ACC, class UPDATE_RULES, class DIST>
539 (
540  const size_t factorIndex,
542 ) const {
543  typedef typename GM::OperatorType OP;
544  OPENGM_ASSERT(factorIndex < factorHulls_.size());
545  out.assign(gm_, gm_[factorIndex].variableIndicesBegin(), gm_[factorIndex].variableIndicesEnd(), OP::template neutral<ValueType>());
546  factorHulls_[factorIndex].marginal(out, parameter_.useNormalization_);
547  return NORMAL;
548 }
549 
/// Convergence measure over the factor hulls: the largest value of
/// FactorHullType::distance<DIST>(i) over all buffers i of all factor hulls.
/// The running maximum starts at 0, so the result is 0 if there are no buffers
/// or all distances are <= 0. Despite being described as "cumulative", it is a
/// maximum, not a sum.
551 template<class GM, class ACC, class UPDATE_RULES, class DIST>
554  ValueType result = 0;
555  for (size_t j = 0; j < factorHulls_.size(); ++j) {
556  for (size_t i = 0; i < factorHulls_[j].numberOfBuffers(); ++i) {
557  ValueType d = factorHulls_[j].template distance<DIST > (i);
558  if (d > result) {
559  result = d;
560  }
561  }
562  }
563  return result;
564 }
565 
/// Convergence measure over the variable hulls: the largest value of
/// VariableHullType::distance<DIST>(i) over all buffers i of all variable hulls.
/// The running maximum starts at 0, so the result is 0 if there are no buffers
/// or all distances are <= 0.
567 template<class GM, class ACC, class UPDATE_RULES, class DIST>
570  ValueType result = 0;
571  for (size_t j = 0; j < variableHulls_.size(); ++j) {
572  for (size_t i = 0; i < variableHulls_[j].numberOfBuffers(); ++i) {
573  ValueType d = variableHulls_[j].template distance<DIST > (i);
574  if (d > result) {
575  result = d;
576  }
577  }
578  }
579  return result;
580 }
581 
/// Convergence measure used by the parallel and sequential schedules
/// (compared against parameter_.bound_).
/// NOTE(review): this returns convergenceXF() only; the measure for the other
/// message direction is not included. Confirm this is intended.
583 template<class GM, class ACC, class UPDATE_RULES, class DIST>
586  return convergenceXF();
587 }
588 
/// Labeling derived from the factor marginals (modeFromFactorMarginal).
/// \param conf receives the labeling
/// \param N    must be 1; only the best configuration is supported
/// \throws RuntimeError if N != 1
/// NOTE(review): both branches of the isAcyclic_ test are identical; the
/// commented-out modeFromMarginal was presumably meant for the cyclic case.
589 template<class GM, class ACC,class UPDATE_RULES, class DIST >
592 (
593  std::vector<LabelType>& conf,
594  const size_t N
595 ) const {
596  if (N != 1) {
597  throw RuntimeError("This implementation of message passing cannot return the k-th optimal configuration.");
598  }
599  else {
600  if (parameter_.isAcyclic_ == opengm::Tribool::True) {
601  return this->modeFromFactorMarginal(conf);
602  }
603  else {
604  return this->modeFromFactorMarginal(conf);
605  //return modeFromMarginal(conf);
606  }
607  }
608 }
609 
610 } // namespace opengm
611 
612 #endif // #ifndef OPENGM_MESSAGE_PASSING_HXX
UPDATE_RULES::SpecialParameterType SpecialParameterType
The OpenGM namespace.
Definition: config.hxx:43
std::vector< size_t > sortedNodeList_
InferenceTermination factorMarginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for all variables connected to a factor
UPDATE_RULES::FactorHullType FactorHullType
A framework for message passing algorithms Cf. F. R. Kschischang, B. J. Frey and H...
SpecialParameterType specialParameter_
visitors::TimingVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > TimingVisitorType
Visitor.
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
visitors::VerboseVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > VerboseVisitorType
Visitor.
ValueType convergenceXF() const
cumulative distance between all pairs of messages from variables to factors (between the previous and...
static T neutral()
neutral element (with return)
Definition: maximizer.hxx:14
visitors::EmptyVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > EmptyVisitorType
Visitor.
InferenceTermination marginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
UPDATE_RULES::VariableHullType VariableHullType
Parameter(const size_t maximumNumberOfSteps=100, const ValueType bound=static_cast< ValueType >(0.000000), const ValueType damping=static_cast< ValueType >(0), const SpecialParameterType &specialParameter=SpecialParameterType(), const opengm::Tribool isAcyclic=opengm::Tribool::Maybe)
virtual ValueType bound() const
return a bound on the solution
Definition: inference.hxx:414
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
Inference algorithm interface.
Definition: inference.hxx:34
void propagate(const ValueType &=0)
invoke one iteration of message passing
std::string name() const
ValueType convergence() const
cumulative distance between all pairs of messages (between the previous and the current iteration) ...
static M::ValueType op(const M &in1, const M &in2)
operation
Variable with three values (true=1, false=0, maybe=-1)
Definition: tribool.hxx:8
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
output a solution
InferenceTermination infer()
static void op(const T1 &in1, T2 &out)
operation (in-place)
Definition: maximizer.hxx:34
void setMaxSteps(size_t maxSteps)
OpenGM runtime error.
Definition: opengm.hxx:100
ValueType convergenceFX() const
cumulative distance between all pairs of messages from factors to variables (between the previous and...
MessagePassing(const GraphicalModelType &, const Parameter &=Parameter())
InferenceTermination
Definition: inference.hxx:24
GraphicalModelType::IndependentFactorType IndependentFactorType
Definition: inference.hxx:44
const GraphicalModelType & graphicalModel() const