2 #ifndef OPENGM_MESSAGE_PASSING_HXX
3 #define OPENGM_MESSAGE_PASSING_HXX
29 static typename M::ValueType
30 op(
const M& in1,
const M& in2)
32 typedef typename M::ValueType ValueType;
33 ValueType v1,v2,d1,d2;
36 for(
size_t n=0; n<in1.size(); ++n) {
49 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST=opengm::MaxDistance>
71 const size_t maximumNumberOfSteps = 100,
73 const ValueType damping = static_cast<ValueType> (0),
101 internalMessageId_(-1)
103 Message(
const size_t nodeId,
const size_t & internalMessageId)
105 internalMessageId_(internalMessageId)
109 size_t internalMessageId_;
113 MessagePassing(
const GraphicalModelType&,
const Parameter& = Parameter());
114 std::string
name()
const;
121 virtual void reset();
123 template<
class VisitorType>
133 void inferParallel();
134 void inferSequential();
135 template<
class VisitorType>
136 void inferParallel(VisitorType&);
137 template<
class VisitorType>
138 void inferAcyclic(VisitorType&);
139 template<
class VisitorType>
140 void inferSequential(VisitorType&);
142 const GraphicalModelType& gm_;
143 Parameter parameter_;
144 std::vector<FactorHullType> factorHulls_;
145 std::vector<VariableHullType> variableHulls_;
148 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
151 const GraphicalModelType& gm,
155 parameter_(parameter)
157 if(parameter_.sortedNodeList_.size() == 0) {
158 parameter_.sortedNodeList_.resize(gm.numberOfVariables());
159 for (
size_t i = 0; i < gm.numberOfVariables(); ++i)
160 parameter_.sortedNodeList_[i] = i;
162 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm.numberOfVariables());
164 UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
167 variableHulls_.resize(gm.numberOfVariables(), VariableHullType ());
168 for (
size_t i = 0; i < gm.numberOfVariables(); ++i) {
169 variableHulls_[i].assign(gm, i, ¶meter_.specialParameter_);
171 factorHulls_.resize(gm.numberOfFactors(), FactorHullType ());
172 for (
size_t i = 0; i < gm.numberOfFactors(); i++) {
173 factorHulls_[i].assign(gm, i, variableHulls_, ¶meter_.specialParameter_);
177 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
181 if(parameter_.sortedNodeList_.size() == 0) {
182 parameter_.sortedNodeList_.resize(gm_.numberOfVariables());
183 for (
size_t i = 0; i < gm_.numberOfVariables(); ++i)
184 parameter_.sortedNodeList_[i] = i;
186 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
187 UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
190 variableHulls_.resize(gm_.numberOfVariables(), VariableHullType ());
191 for (
size_t i = 0; i < gm_.numberOfVariables(); ++i) {
192 variableHulls_[i].assign(gm_, i, ¶meter_.specialParameter_);
194 factorHulls_.resize(gm_.numberOfFactors(), FactorHullType ());
195 for (
size_t i = 0; i < gm_.numberOfFactors(); i++) {
196 factorHulls_[i].assign(gm_, i, variableHulls_, ¶meter_.specialParameter_);
200 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
206 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
212 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
219 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
220 template<
class VisitorType>
228 parameter_.useNormalization_=
false;
229 inferAcyclic(visitor);
231 if (parameter_.inferSequential_) {
232 inferSequential(visitor);
234 inferParallel(visitor);
237 if (gm_.isAcyclic()) {
240 parameter_.useNormalization_=
false;
241 inferAcyclic(visitor);
244 if (parameter_.inferSequential_) {
245 inferSequential(visitor);
247 inferParallel(visitor);
259 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
263 return inferAcyclic(v);
273 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
274 template<
class VisitorType>
276 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferAcyclic
282 visitor.begin(*
this);
283 size_t numberOfVariables = gm_.numberOfVariables();
284 size_t numberOfFactors = gm_.numberOfFactors();
287 std::vector<std::vector<size_t> > counterVar2FacMessage(numberOfVariables);
288 std::vector<std::vector<size_t> > counterFac2VarMessage(numberOfFactors);
290 std::vector<Message> ready2SendVar2FacMessage;
291 std::vector<Message> ready2SendFac2VarMessage;
292 ready2SendVar2FacMessage.reserve(100);
293 ready2SendFac2VarMessage.reserve(100);
294 for (
size_t fac = 0; fac < numberOfFactors; ++fac) {
295 counterFac2VarMessage[fac].resize(gm_[fac].numberOfVariables(), gm_[fac].numberOfVariables() - 1);
297 for (
size_t var = 0; var < numberOfVariables; ++var) {
298 counterVar2FacMessage[var].resize(gm_.numberOfFactors(var));
299 for (
size_t i = 0; i < gm_.numberOfFactors(var); ++i) {
300 counterVar2FacMessage[var][i] = gm_.numberOfFactors(var) - 1;
304 for (
size_t var = 0; var < numberOfVariables; ++var) {
305 for (
size_t i = 0; i < counterVar2FacMessage[var].size(); ++i) {
306 if (counterVar2FacMessage[var][i] == 0) {
307 --counterVar2FacMessage[var][i];
308 ready2SendVar2FacMessage.push_back(Message(var, i));
312 for (
size_t fac = 0; fac < numberOfFactors; ++fac) {
313 for (
size_t i = 0; i < counterFac2VarMessage[fac].size(); ++i) {
314 if (counterFac2VarMessage[fac][i] == 0) {
315 --counterFac2VarMessage[fac][i];
316 ready2SendFac2VarMessage.push_back(Message(fac, i));
321 while (ready2SendVar2FacMessage.size() > 0 || ready2SendFac2VarMessage.size() > 0) {
322 while (ready2SendVar2FacMessage.size() > 0) {
323 Message m = ready2SendVar2FacMessage.back();
324 size_t nodeId = m.nodeId_;
325 size_t factorId = gm_.factorOfVariable(nodeId,m.internalMessageId_);
327 variableHulls_[nodeId].propagate(gm_, m.internalMessageId_, 0,
false);
328 ready2SendVar2FacMessage.pop_back();
330 for (
size_t i = 0; i < gm_[factorId].numberOfVariables(); ++i) {
331 if (gm_[factorId].variableIndex(i) != nodeId) {
332 if (--counterFac2VarMessage[factorId][i] == 0) {
333 ready2SendFac2VarMessage.push_back(Message(factorId, i));
338 while (ready2SendFac2VarMessage.size() > 0) {
339 Message m = ready2SendFac2VarMessage.back();
340 size_t factorId = m.nodeId_;
341 size_t nodeId = gm_[factorId].variableIndex(m.internalMessageId_);
343 factorHulls_[factorId].propagate(m.internalMessageId_, 0, parameter_.useNormalization_);
344 ready2SendFac2VarMessage.pop_back();
346 for (
size_t i = 0; i < gm_.numberOfFactors(nodeId); ++i) {
347 if (gm_.factorOfVariable(nodeId,i) != factorId) {
348 if (--counterVar2FacMessage[nodeId][i] == 0) {
349 ready2SendVar2FacMessage.push_back(Message(nodeId, i));
354 if(visitor(*
this)!=0)
362 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
367 for (
size_t i = 0; i < variableHulls_.size(); ++i) {
368 variableHulls_[i].propagateAll(damping,
false);
370 for (
size_t i = 0; i < factorHulls_.size(); ++i) {
371 factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
376 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
379 return inferParallel(v);
384 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
385 template<
class VisitorType>
386 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferParallel
392 ValueType damping = parameter_.damping_;
393 visitor.begin(*
this);
396 for (
size_t i = 0; i < factorHulls_.size(); ++i) {
397 if (factorHulls_[i].numberOfBuffers() < 2) {
398 factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
399 factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
402 for (
unsigned long n = 0; n < parameter_.maximumNumberOfSteps_; ++n) {
403 for (
size_t i = 0; i < variableHulls_.size(); ++i) {
404 variableHulls_[i].propagateAll(gm_, damping,
false);
406 for (
size_t i = 0; i < factorHulls_.size(); ++i) {
407 if (factorHulls_[i].numberOfBuffers() >= 2)
408 factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
410 if(visitor(*
this)!=0)
413 if (c < parameter_.bound_) {
429 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
430 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferSequential() {
432 return inferSequential(v);
445 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
446 template<
class VisitorType>
447 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferSequential
451 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
452 visitor.begin(*
this);
453 ValueType damping = parameter_.damping_;
456 std::vector<size_t> nodeOrder(gm_.numberOfVariables());
457 for (
size_t o = 0; o < gm_.numberOfVariables(); ++o) {
458 nodeOrder[parameter_.sortedNodeList_[o]] = o;
462 for (
size_t f = 0; f < factorHulls_.size(); ++f) {
463 if (factorHulls_[f].numberOfBuffers() < 2) {
464 factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
465 factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
470 std::vector<std::vector<size_t> > inversePositions(gm_.numberOfVariables());
471 for(
size_t var=0; var<gm_.numberOfVariables();++var) {
472 for(
size_t i=0; i<gm_.numberOfFactors(var); ++i) {
473 size_t factorId = gm_.factorOfVariable(var,i);
474 for(
size_t j=0; j<gm_.numberOfVariables(factorId);++j) {
475 if(gm_.variableOfFactor(factorId,j)==var) {
476 inversePositions[var].push_back(j);
485 for (
unsigned long itteration = 0; itteration < parameter_.maximumNumberOfSteps_; ++itteration) {
486 if(itteration%2==0) {
488 for (
size_t o = 0; o < gm_.numberOfVariables(); ++o) {
489 size_t variableId = parameter_.sortedNodeList_[o];
491 for(
size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
492 size_t factorId = gm_.factorOfVariable(variableId,i);
493 factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
497 variableHulls_[variableId].propagateAll(gm_, damping,
false);
502 for (
size_t o = 0; o < gm_.numberOfVariables(); ++o) {
503 size_t variableId = parameter_.sortedNodeList_[gm_.numberOfVariables() - 1 - o];
505 for(
size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
506 size_t factorId = gm_.factorOfVariable(variableId,i);
507 factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
510 variableHulls_[variableId].propagateAll(gm_, damping,
false);
513 if(visitor(*
this)!=0)
515 ValueType c = convergence();
516 if (c < parameter_.bound_) {
524 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
528 const size_t variableIndex,
532 variableHulls_[variableIndex].marginal(gm_, variableIndex, out, parameter_.useNormalization_);
536 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
540 const size_t factorIndex,
543 typedef typename GM::OperatorType OP;
545 out.assign(gm_, gm_[factorIndex].variableIndicesBegin(), gm_[factorIndex].variableIndicesEnd(), OP::template neutral<ValueType>());
546 factorHulls_[factorIndex].marginal(out, parameter_.useNormalization_);
551 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
555 for (
size_t j = 0; j < factorHulls_.size(); ++j) {
556 for (
size_t i = 0; i < factorHulls_[j].numberOfBuffers(); ++i) {
557 ValueType d = factorHulls_[j].template distance<DIST > (i);
567 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
571 for (
size_t j = 0; j < variableHulls_.size(); ++j) {
572 for (
size_t i = 0; i < variableHulls_[j].numberOfBuffers(); ++i) {
573 ValueType d = variableHulls_[j].template distance<DIST > (i);
583 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
586 return convergenceXF();
589 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST >
593 std::vector<LabelType>& conf,
597 throw RuntimeError(
"This implementation of message passing cannot return the k-th optimal configuration.");
601 return this->modeFromFactorMarginal(conf);
604 return this->modeFromFactorMarginal(conf);
612 #endif // #ifndef OPENGM_BELIEFPROPAGATION_HXX
UPDATE_RULES::SpecialParameterType SpecialParameterType
std::vector< size_t > sortedNodeList_
InferenceTermination factorMarginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for all variables connected to a factor
UPDATE_RULES::FactorHullType FactorHullType
opengm::Tribool isAcyclic_
A framework for message passing algorithms Cf. F. R. Kschischang, B. J. Frey and H...
SpecialParameterType specialParameter_
visitors::TimingVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > TimingVisitorType
Visitor.
#define OPENGM_ASSERT(expression)
opengm::Tribool useNormalization_
visitors::VerboseVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > VerboseVisitorType
Visitor.
ValueType convergenceXF() const
cumulative distance between all pairs of messages from variables to factors (between the previous and...
static T neutral()
neutral element (with return)
visitors::EmptyVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > EmptyVisitorType
Visitor.
InferenceTermination marginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
UPDATE_RULES::VariableHullType VariableHullType
Parameter(const size_t maximumNumberOfSteps=100, const ValueType bound=static_cast< ValueType >(0.000000), const ValueType damping=static_cast< ValueType >(0), const SpecialParameterType &specialParameter=SpecialParameterType(), const opengm::Tribool isAcyclic=opengm::Tribool::Maybe)
virtual ValueType bound() const
return a bound on the solution
GraphicalModelType::ValueType ValueType
Inference algorithm interface.
void propagate(const ValueType &=0)
invoke one iteration of message passing
ValueType convergence() const
cumulative distance between all pairs of messages (between the previous and the current interation) ...
static M::ValueType op(const M &in1, const M &in2)
operation
Variable with three values (true=1, false=0, maybe=-1)
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
output a solution
InferenceTermination infer()
static void op(const T1 &in1, T2 &out)
operation (in-place)
void setMaxSteps(size_t maxSteps)
size_t maximumNumberOfSteps_
ValueType convergenceFX() const
cumulative distance between all pairs of messages from factors to variables (between the previous and...
MessagePassing(const GraphicalModelType &, const Parameter &=Parameter())
GraphicalModelType::IndependentFactorType IndependentFactorType
const GraphicalModelType & graphicalModel() const