OpenGM  2.3.x
Discrete Graphical Model Library
alphabetaswap.hxx
Go to the documentation of this file.
1 #pragma once
2 #ifndef OPENGM_ALPHABEATSWAP_HXX
3 #define OPENGM_ALPHABETASWAP_HXX
4 
5 #include <vector>
6 
9 
10 namespace opengm {
11 
// AlphaBetaSwap: move-making inference for graphical models whose factors have
// at most two variables (the constructor throws RuntimeError otherwise).
// Each iteration of infer() selects a label pair (alpha_, beta_) and builds a
// binary subproblem over the variables currently labeled alpha_ or beta_; the
// solver INF then decides for each of them whether it takes alpha_ (state 0)
// or beta_ (state 1).
// NOTE(review): this is a Doxygen listing export; listing lines 12-13, 20-23,
// 26-27, 31 and 37 are not visible here, so some declarations used below
// (e.g. maxNumberOfIterations_, the visitor typedefs) are not shown.
14 template<class GM, class INF>
15 class AlphaBetaSwap : public Inference<GM, typename INF::AccumulationType> {
16 public:
17  typedef GM GraphicalModelType;
   // Binary solver used for each alpha-beta subproblem.
18  typedef INF InferenceType;
19  typedef typename INF::AccumulationType AccumulationType;
24 
   // Options of the algorithm. infer() reads parameter_.maxNumberOfIterations_,
   // presumably declared on one of the listing lines not shown here -- confirm.
25  struct Parameter {
28  }
29 
   // Parameter object of the inner solver INF.
   // NOTE(review): in the visible code of infer(), INF is constructed without
   // this object -- confirm whether it is meant to be forwarded.
30  typename InferenceType::Parameter parameter_;
32  };
33 
   // Keeps a reference to the model; throws RuntimeError for factors of order > 2.
34  AlphaBetaSwap(const GraphicalModelType&, Parameter = Parameter());
35  std::string name() const;
36  const GraphicalModelType& graphicalModel() const;
   // Runs swap moves, reporting progress to the visitor after each solved move.
38  template<class VISITOR>
39  InferenceTermination infer(VISITOR & );
   // Returns to the all-zero labeling and the initial label pair.
40  void reset();
   // Copies one label per variable starting at the given iterator.
41  void setStartingPoint(typename std::vector<LabelType>::const_iterator);
   // Writes the current labeling; returns UNKNOWN when more than one solution is requested.
42  InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
43 
44 private:
   // Held by reference: the model must outlive this object.
45  const GraphicalModelType& gm_;
46  Parameter parameter_;
   // Current labeling, one entry per variable.
47  std::vector<LabelType> label_;
   // Current label pair of the swap move; increment() keeps alpha_ < beta_.
48  size_t alpha_;
49  size_t beta_;
   // Largest number of labels of any variable (computed in the constructor).
50  size_t maxState_;
   // Advances (alpha_, beta_) to the next label pair.
51  void increment();
   // Add a 2-state unary / 2x2 pairwise factor to the binary subproblem solved by INF.
52  void addUnary(INF&, const size_t var, const ValueType v0, const ValueType v1);
53  void addPairwise(INF&, const size_t var1, const size_t var2, const ValueType v0, const ValueType v1, const ValueType v2, const ValueType v3);
54 };
55 
56 // reset assumes that the structure of the graphical model has not changed
// Restores the initial state: the label pair goes back to (0, 0), so the next
// increment() yields (0, 1) when maxState_ >= 2, and every variable's label is
// set to 0. label_ is not resized, hence the assumption above that the number
// of variables is unchanged.
// NOTE(review): listing line 59 (presumably the `AlphaBetaSwap<GM, INF>::reset()`
// line) is missing from this export.
57 template<class GM, class INF>
58 inline void
60  alpha_ = 0;
61  beta_ = 0;
62  std::fill(label_.begin(),label_.end(),0);
63 }
64 
// Advances (alpha_, beta_) to the next label pair with alpha_ < beta_ in
// lexicographic order: (0,1), (0,2), ..., (0,maxState_-1), (1,2), ...,
// (maxState_-2,maxState_-1), then wraps back to (0,1). One full cycle visits
// maxState_*(maxState_-1)/2 pairs, the value infer() uses as numberOfLabelPairs.
// Requires maxState_ >= 2, otherwise the assertions below fail; infer()'s loop
// does not run when there are no label pairs, so it never calls this in that case.
// NOTE(review): listing line 67 (the qualified name) is missing from this export.
65 template<class GM, class INF>
66 inline void
   // beta_ ran past the last label: advance alpha_ (wrapping to 0 when no larger
   // beta_ would be left) and restart beta_ just above it.
68  if (++beta_ >= maxState_) {
69  if (++alpha_ >= maxState_ - 1) {
70  alpha_ = 0;
71  }
72  beta_ = alpha_ + 1;
73  }
74  OPENGM_ASSERT(alpha_ < maxState_);
75  OPENGM_ASSERT(beta_ < maxState_);
76  OPENGM_ASSERT(alpha_ < beta_);
77 }
78 
// Returns the display name of the algorithm, "Alpha-Beta-Swap".
// NOTE(review): listing line 81 (the qualified name) is missing from this export.
79 template<class GM, class INF>
80 inline std::string
82  return "Alpha-Beta-Swap";
83 }
84 
// Returns the graphical model passed to the constructor (held by reference).
// NOTE(review): listing line 87 (the qualified name) is missing from this export.
85 template<class GM, class INF>
86 inline const typename AlphaBetaSwap<GM, INF>::GraphicalModelType&
88  return gm_;
89 }
90 
// Constructor: keeps a reference to gm (which must outlive this object), copies
// the parameters, starts from the all-zero labeling and the label pair (0, 0),
// and sets maxState_ to the largest number of labels of any variable.
// Throws RuntimeError if any factor depends on more than two variables.
// NOTE(review): listing line 92 (the qualified constructor name) is missing
// from this export.
91 template<class GM, class INF>
93 (
94  const GraphicalModelType& gm,
95  Parameter para
96 )
97 : gm_(gm)
98 {
99  parameter_ = para;
100  label_.resize(gm_.numberOfVariables(), 0);
101  alpha_ = 0;
102  beta_ = 0;
   // Only unary and pairwise factors can be mapped onto the binary subproblem.
103  for (size_t j = 0; j < gm_.numberOfFactors(); ++j) {
104  if (gm_[j].numberOfVariables() > 2) {
105  throw RuntimeError("This implementation of Alpha-Beta-Swap supports only factors of order <= 2.");
106  }
107  }
   // maxState_ bounds the labels increment() cycles through.
108  maxState_ = 0;
109  for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
110  size_t numSt = gm_.numberOfLabels(i);
111  if (numSt > maxState_)
112  maxState_ = numSt;
113  }
114 }
115 
// Replaces the current labeling with gm_.numberOfVariables() labels read
// starting at `begin`. The labels are not checked against each variable's
// number of labels.
// NOTE(review): if the caller's range is shorter than numberOfVariables,
// `begin + gm_.numberOfVariables()` is undefined behavior, not an exception, so
// the catch below cannot detect it; it only converts exceptions thrown by
// assign (e.g. std::bad_alloc) into RuntimeError.
// NOTE(review): listing line 118 (the qualified name) is missing from this export.
116 template<class GM, class INF>
117 inline void
119 (
120  typename std::vector<typename AlphaBetaSwap<GM,INF>::LabelType>::const_iterator begin
121 ) {
122  try{
123  label_.assign(begin, begin+gm_.numberOfVariables());
124  }
125  catch(...) {
126  throw RuntimeError("unsuitable starting point");
127  }
128 }
129 
// Adds a unary factor with two states on node `var1` of the binary subproblem:
// state 0 costs v0 and state 1 costs v1. infer() passes the value at alpha_ as
// v0 and the value at beta_ as v1, and maps state 0 back to alpha_ and state 1
// to beta_.
// NOTE(review): listing line 132 (the qualified name) is missing from this
// export; the class declaration names this parameter `var`, the definition `var1`.
130 template<class GM, class INF>
131 inline void
133 (
134  INF& inf,
135  const size_t var1,
136  const ValueType v0,
137  const ValueType v1
138 ) {
139  const size_t shape[] = {2};
140  const size_t vars[] = {var1};
141  opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 1, shape, shape + 1);
142  fac(0) = v0;
143  fac(1) = v1;
144  inf.addFactor(fac);
145 }
146 
147 template<class GM, class INF>
148 inline void
149 AlphaBetaSwap<GM, INF>::addPairwise
150 (
151  INF& inf,
152  const size_t var1,
153  const size_t var2,
154  const ValueType v0,
155  const ValueType v1,
156  const ValueType v2,
157  const ValueType v3
158 ) {
159  const size_t shape[] = {2, 2};
160  const size_t vars[] = {var1, var2};
161  opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 2, shape, shape + 2);
162  fac(0, 0) = v0;
163  fac(0, 1) = v1;
164  fac(1, 0) = v2;
165  fac(1, 1) = v3;
166  OPENGM_ASSERT(v1 + v2 - v0 - v3 >= 0);
167  inf.addFactor(fac);
168 }
// Runs infer(VISITOR&) with a default-constructed EmptyVisitorType visitor and
// returns its result.
// NOTE(review): listing lines 170-171 (return type and qualified name) are
// missing from this export.
169 template<class GM, class INF>
172  EmptyVisitorType v;
173  return infer(v);
174 }
175 
176 template<class GM, class INF>
177 template<class VISITOR>
180 (
181  VISITOR & visitor
182 ) {
183  bool exitInf=false;
184  visitor.begin(*this);
185  size_t it = 0;
186  size_t countUnchanged = 0;
187  size_t numberOfVariables = gm_.numberOfVariables();
188  std::vector<size_t> variable2Node(numberOfVariables, 0);
189  ValueType energy = gm_.evaluate(label_);
190  size_t vecA[1];
191  size_t vecB[1];
192  size_t vecAA[2];
193  size_t vecAB[2];
194  size_t vecBA[2];
195  size_t vecBB[2];
196  size_t vecAX[2];
197  size_t vecBX[2];
198  size_t vecXA[2];
199  size_t vecXB[2];
200  size_t numberOfLabelPairs = maxState_*(maxState_ - 1)/2;
201  while (it++ < parameter_.maxNumberOfIterations_ && countUnchanged < numberOfLabelPairs && exitInf == false) {
202  increment();
203  size_t counter = 0;
204  std::vector<size_t> numFacDim(4, 0);
205  for (size_t i = 0; i < numberOfVariables; ++i) {
206  if (label_[i] == alpha_ || label_[i] == beta_) {
207  variable2Node[i] = counter++;
208  }
209  }
210  if (counter == 0) {
211  continue;
212  }
213  INF inf(counter, numFacDim);
214  vecA[0] = alpha_;
215  vecB[0] = beta_;
216  vecAA[0] = alpha_;
217  vecAA[1] = alpha_;
218  vecBB[0] = beta_;
219  vecBB[1] = beta_;
220  vecBA[0] = beta_;
221  vecBA[1] = alpha_;
222  vecAB[0] = alpha_;
223  vecAB[1] = beta_;
224  vecAX[0] = alpha_;
225  vecBX[0] = beta_;
226  vecXA[1] = alpha_;
227  vecXB[1] = beta_;
228  for (size_t k = 0; k < gm_.numberOfFactors(); ++k) {
229  const FactorType& factor = gm_[k];
230  if (factor.numberOfVariables() == 1) {
231  size_t var = factor.variableIndex(0);
232  size_t node = variable2Node[var];
233  if (label_[var] == alpha_ || label_[var] == beta_) {
234  OPENGM_ASSERT(alpha_ < gm_.numberOfLabels(var));
235  OPENGM_ASSERT(beta_ < gm_.numberOfLabels(var));
236  addUnary(inf, node, factor(vecA), factor(vecB));
237  //inf.addUnary(node, factor(vecA), factor(vecB));
238  }
239  } else if (factor.numberOfVariables() == 2) {
240  size_t var1 = factor.variableIndex(0);
241  size_t var2 = factor.variableIndex(1);
242  size_t node1 = variable2Node[var1];
243  size_t node2 = variable2Node[var2];
244 
245  if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] == alpha_ || label_[var2] == beta_)) {
246  addPairwise(inf, node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
247  //inf.addPairwise(node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
248  } else if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] != alpha_ && label_[var2] != beta_)) {
249  vecAX[1] = vecBX[1] = label_[var2];
250  addUnary(inf, node1, factor(vecAX), factor(vecBX));
251  //inf.addUnary(node1, factor(vecAX), factor(vecBX));
252  } else if ((label_[var2] == alpha_ || label_[var2] == beta_) && (label_[var1] != alpha_ && label_[var1] != beta_)) {
253  vecXA[0] = vecXB[0] = label_[var1];
254  addUnary(inf, node2, factor(vecXA), factor(vecXB));
255  //inf.addUnary(node2, factor(vecXA), factor(vecXB));
256  }
257  }
258  }
259  std::vector<LabelType> state; //(counter);
260  inf.infer();
261  inf.arg(state);
262  OPENGM_ASSERT(state.size() == counter);
263  for (size_t var = 0; var < numberOfVariables; ++var) {
264  if (label_[var] == alpha_ || label_[var] == beta_) {
265  if (state[variable2Node[var]] == 0)
266  label_[var] = alpha_;
267  else
268  label_[var] = beta_;
269  } else {
270  //do nothing
271  }
272  }
273  ValueType energy2 = gm_.evaluate(label_);
274  if( visitor(*this) != visitors::VisitorReturnFlag::ContinueInf ){
275  exitInf=true;
276  }
277  OPENGM_ASSERT(!AccumulationType::ibop(energy2, energy));
278  if (AccumulationType::bop(energy2, energy)) {
279  energy = energy2;
280  } else {
281  ++countUnchanged;
282  }
283  }
284  visitor.end(*this);
285  return NORMAL;
286 }
287 
// Copies the current labeling into `arg`, resizing it to the number of labels
// stored (one per variable). Only a single solution is kept: for n > 1 the
// function returns UNKNOWN and leaves `arg` untouched; for n == 0 or n == 1 it
// returns NORMAL.
// NOTE(review): listing line 289 (the return type) is missing from this export.
288 template<class GM, class INF>
290 AlphaBetaSwap<GM, INF>::arg(std::vector<LabelType>& arg, const size_t n) const {
291  if (n > 1) {
292  return UNKNOWN;
293  } else {
294  OPENGM_ASSERT(label_.size() == gm_.numberOfVariables());
295  arg.resize(label_.size());
296  for (size_t i = 0; i < label_.size(); ++i)
297  arg[i] = label_[i];
298  return NORMAL;
299  }
300 }
301 
302 } // namespace opengm
303 
304 #endif // closes the include guard opened at the top of this file
const GraphicalModelType & graphicalModel() const
The OpenGM namespace.
Definition: config.hxx:43
Factor (with corresponding function and variable indices), independent of a GraphicalModel.
opengm::visitors::EmptyVisitor< AlphaBetaSwap< GM, INF > > EmptyVisitorType
opengm::visitors::VerboseVisitor< AlphaBetaSwap< GM, INF > > VerboseVisitorType
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
AlphaBetaSwap(const GraphicalModelType &, Parameter=Parameter())
void setStartingPoint(typename std::vector< LabelType >::const_iterator)
set initial labeling
InferenceType::Parameter parameter_
opengm::visitors::TimingVisitor< AlphaBetaSwap< GM, INF > > TimingVisitorType
GraphicalModelType::FactorType FactorType
Definition: inference.hxx:43
Alpha-Beta-Swap Algorithm.
INF::AccumulationType AccumulationType
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
output a solution
Inference algorithm interface.
Definition: inference.hxx:34
std::string name() const
GraphicalModelType::LabelType LabelType
Definition: inference.hxx:39
InferenceTermination infer()
OpenGM runtime error.
Definition: opengm.hxx:100
InferenceTermination
Definition: inference.hxx:24