OpenGM  2.3.x
Discrete Graphical Model Library
srmp.hxx
Go to the documentation of this file.
1 #ifndef OPENGM_EXTERNAL_SRMP_HXX_
2 #define OPENGM_EXTERNAL_SRMP_HXX_
3 
9 
10 #include <srmp/SRMP.h>
11 #include <srmp/FactorTypes/PottsType.h>
12 #include <srmp/FactorTypes/GeneralType.h>
13 
14 namespace opengm {
15 namespace external {
16 
17 /*********************
18  * class definitions *
19  *********************/
// Adapter that exposes the external SRMP solver (srmpLib::Energy) through the
// opengm Inference interface, using opengm::Minimizer as accumulation.
// NOTE(review): this listing elides several original lines (typedefs after
// GraphicalModelType, part of Parameter's constructor, the relaxation flag
// members, the infer() overload without visitor) - confirm against the
// complete header before editing the declaration.
20 template<class GM>
21 class SRMP : public Inference<GM, opengm::Minimizer> {
22 public:
23  typedef GM GraphicalModelType;
29 
 // Solver options: everything from srmpLib::Energy::Options plus the
 // relaxation selection that infer() reads (BLPRelaxation_, FullRelaxation_,
 // FullDualRelaxation_ and the two *Method_ values below).
30  struct Parameter : public srmpLib::Energy::Options {
31  Parameter() : srmpLib::Energy::Options(), BLPRelaxation_(false),
34  // disable verbose mode by default
35  verbose = false;
36  }
37 
40  int FullRelaxationMethod_; // method=0: add all possible pairs (A,B) with B \subset A, with the following exception:
41  // if there exists factor C with B\subset C \subset A then don't add (A,B)
42  // method=1: move all costs to outer factors (converting them to general types first, if they are not already of these types).
43  // Then run method=0 and remove unnecessary edges (i.e. those that do not affect the relaxation).
44  // Note, all edges outgoing from outer factors will be kept.
45  // method=2: similar to method=1, but all edges {A->B, B->C} are replaced with {A->B, A->C} (so this results in a two-layer graph).
46  //
47  // Note, method=1 and method=2 merge duplicate factors while method=0 does not. For this reason the relaxation may be tighter.
48  // (If there are no duplicate factors then the resulting relaxation should be the same in all three cases).
49  //
50  // method=3: run method=2 and then create a new Energy instance with unary and pairwise terms in which nodes correspond
51  // to outer factors of the original energy, and pairwise terms with {0,+\infty} costs enforce consistency between them.
53  int FullDualRelaxationMethod_; // FullDualRelaxationMethod_ has the same meaning as in srmpLib::Energy::Options::sort_flag.
54  };
55 
56  // construction
 // Keeps a reference to gm (the model must outlive this object) and
 // transfers all of its variables and factors to the srmp solver.
57  SRMP(const GraphicalModelType& gm, const Parameter para = Parameter());
58  // destruction
59  ~SRMP();
60  // query
61  std::string name() const;
62  const GraphicalModelType& graphicalModel() const;
63  // inference
65  template<class VISITOR>
66  InferenceTermination infer(VISITOR & visitor);
 // Only n <= 1 is answered with a labeling; larger n yields UNKNOWN.
67  InferenceTermination arg(std::vector<LabelType>& arg, const size_t& n = 1) const;
68  typename GM::ValueType bound() const;
69  typename GM::ValueType value() const;
70 private:
71  const GraphicalModelType& gm_;
72  Parameter parameter_;
73 
 // constTerm_: energy kept outside the srmp solver (order-0 factors and the
 // "equal" value subtracted from Potts factors); added back in bound().
74  ValueType constTerm_;
 // value_: energy of the labeling returned by the solver, set in infer().
75  ValueType value_;
 // lowerBound_: result of srmpSolver_.Solve(), set in infer().
76  ValueType lowerBound_;
77 
78  srmpLib::Energy::Options srmpOptions_;
79  srmpLib::Energy srmpSolver_;
80 
81  std::vector<srmpLib::PottsFactorType*> pottsFactorList_; // list of created potts functions which must be deleted when ~SRMP() is called
82  std::vector<srmpLib::GeneralFactorType*> generalFactorList_; // list of created general functions which must be deleted when ~SRMP() is called
83 
 // Helpers that copy factor number FactorID of gm_ into srmpSolver_.
84  void addUnaryFactor(const IndexType FactorID);
85  void addPairwiseFactor(const IndexType FactorID);
86  void addPottsFactor(const IndexType FactorID);
87  void addGeneralFactor(const IndexType FactorID);
88 };
89 
90 /***********************
91  * class documentation *
92  ***********************/
93 //TODO add documentation
94 
95 /******************
96  * implementation *
97  ******************/
98 
// Builds the srmp energy from the graphical model:
//  - one srmp node per variable, carrying the variable's number of labels,
//  - every factor is forwarded according to its order (see the loop below),
//  - the solver options are copied field by field from parameter_.
// gm is stored by reference; para is stored by value.
99 template<class GM>
100 inline SRMP<GM>::SRMP(const GraphicalModelType& gm, const Parameter para)
101 : gm_(gm), parameter_(para), constTerm_(0.0), value_(), lowerBound_(),
102  srmpOptions_(), srmpSolver_(gm_.numberOfVariables()), pottsFactorList_(),
103  generalFactorList_() {
104  // set states of variables
105  for(IndexType i = 0; i < gm_.numberOfVariables(); ++i) {
106  srmpSolver_.AddNode(gm_.numberOfLabels(i));
107  }
108 
109  // set factors
110  for(IndexType i = 0; i < gm_.numberOfFactors(); ++i) {
111  if(gm_[i].numberOfVariables() == 0) {
112  // constant factors are not supported by srmp, hence the constant term is handled outside of the srmp solver
 // (the label passed to the factor is a dummy; the factor has no variables)
113  LabelType l = 0;
114  constTerm_ += gm_[i](&l);
115  } else if(gm_[i].numberOfVariables() == 1) {
116  // add unary factor
117  addUnaryFactor(i);
118  } else if(gm_[i].numberOfVariables() == 2) {
119  if(gm_[i].numberOfLabels(0) == gm_[i].numberOfLabels(1) && gm_[i].isPotts()) {
120  // add potts factor
121  // srmp potts type does not support potts functions with more than
122  // two variables or with different number of labels
123  addPottsFactor(i);
124  } else {
125  // add pairwise factor
126  addPairwiseFactor(i);
127  }
128  } else {
129  // general factor
130  // TODO srmp provides other function types which can be used instead of general type for some factors (SharedPairwiseType, PatternType, PairwiseDualType)
131  addGeneralFactor(i);
132  }
133  }
134 
135  // set options
 // only the fields listed here reach the solver; infer() passes srmpOptions_ to Solve()
136  srmpOptions_.method = parameter_.method;
137  srmpOptions_.iter_max = parameter_.iter_max;
138  srmpOptions_.time_max = parameter_.time_max;
139  srmpOptions_.eps = parameter_.eps;
140  srmpOptions_.compute_solution_period = parameter_.compute_solution_period;
141  srmpOptions_.print_times = parameter_.print_times;
142  srmpOptions_.sort_flag = parameter_.sort_flag;
143  srmpOptions_.verbose = parameter_.verbose;
144  srmpOptions_.TRWS_weighting = parameter_.TRWS_weighting;
145 
146  // set initial value and lower bound
 // NOTE(review): listing line 147 is not visible in this extract; presumably
 // it initializes value_ - confirm against the complete header.
 // lowerBound_ is set to the accumulator's inverse neutral element.
148  AccumulationType::ineutral(lowerBound_);
149 }
150 
151 template<class GM>
152 inline SRMP<GM>::~SRMP() {
153  for(size_t i = 0; i < pottsFactorList_.size(); ++i) {
154  delete pottsFactorList_[i];
155  }
156  for(size_t i = 0; i < generalFactorList_.size(); ++i) {
157  delete generalFactorList_[i];
158  }
159 }
160 
161 template<class GM>
162 inline std::string SRMP<GM>::name() const {
163  return "SRMP";
164 }
165 
// Accessor returning the graphical model reference stored at construction.
// NOTE(review): the signature line (listing line 167) is not visible in this
// extract; per the class declaration this is graphicalModel() const.
166 template<class GM>
168  return gm_;
169 }
170 
// Convenience overload: runs inference with a default-constructed
// EmptyVisitorType and returns the result of infer(visitor).
// NOTE(review): the signature line (listing line 172) is not visible in this
// extract - confirm against the complete header.
171 template<class GM>
173  EmptyVisitorType visitor;
174  return this->infer(visitor);
175 }
176 
177 template<class GM>
178 template<class VISITOR>
179 inline InferenceTermination SRMP<GM>::infer(VISITOR & visitor) {
180  visitor.begin(*this);
181 
182  if (parameter_.BLPRelaxation_) {
183  srmpSolver_.SetMinimalEdges();
184  } else if (parameter_.FullRelaxation_) {
185  srmpSolver_.SetFullEdges(parameter_.FullRelaxationMethod_);
186  } else if (parameter_.FullDualRelaxation_) {
187  srmpSolver_.SetFullEdgesDual(parameter_.FullDualRelaxationMethod_);
188  }
189 
190  // call solver
191  lowerBound_ = srmpSolver_.Solve(srmpOptions_);
192  std::vector<LabelType> l;
193  arg(l);
194  value_ = gm_.evaluate(l);
195 
196  visitor.end(*this);
197  return NORMAL;
198 }
199 
200 template<class GM>
201 inline InferenceTermination SRMP<GM>::arg(std::vector<LabelType>& arg, const size_t& n) const {
202  if(n > 1) {
203  return UNKNOWN;
204  }
205  else {
206  arg.resize(gm_.numberOfVariables());
207  for(IndexType i = 0; i < gm_.numberOfVariables(); ++i) {
208  arg[i] = srmpSolver_.GetSolution(i);
209  }
210  return NORMAL;
211  }
212 }
213 
214 template<class GM>
215 inline typename GM::ValueType SRMP<GM>::bound() const {
216  return lowerBound_ + constTerm_;
217 }
218 
219 template<class GM>
220 inline typename GM::ValueType SRMP<GM>::value() const {
221  return value_;
222  //return value_ + constTerm_;
223 }
224 
225 template<class GM>
226 inline void SRMP<GM>::addUnaryFactor(const IndexType FactorID) {
227  double* values = new double[gm_[FactorID].numberOfLabels(0)];
228  LabelType label = 0;
229  for(LabelType i = 0; i < gm_[FactorID].numberOfLabels(0); ++i) {
230  values[i] = static_cast<double>(gm_[FactorID](&label));
231  ++label;
232  }
233  srmpSolver_.AddUnaryFactor(static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(0)), values);
234  delete[] values;
235 }
236 
237 template<class GM>
238 inline void SRMP<GM>::addPairwiseFactor(const IndexType FactorID) {
239  double* values = new double[gm_[FactorID].numberOfLabels(0) * gm_[FactorID].numberOfLabels(1)];
240  LabelType labeling[2] = {0, 0};
241  for(LabelType i = 0; i < gm_[FactorID].numberOfLabels(0); ++i) {
242  labeling[0] = i;
243  for(LabelType j = 0; j < gm_[FactorID].numberOfLabels(1); ++j) {
244  labeling[1] = j;
245  values[(i * gm_[FactorID].numberOfLabels(1)) + j] = static_cast<double>(gm_[FactorID](labeling));
246  }
247  }
248  srmpSolver_.AddPairwiseFactor(static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(0)), static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(1)), values);
249  delete[] values;
250 }
251 
252 template<class GM>
253 inline void SRMP<GM>::addPottsFactor(const IndexType FactorID) {
254  ValueType valueEqual;
255  ValueType valueNotEqual;
256 
257  LabelType labeling[2] = {0, 0};
258  valueEqual = gm_[FactorID](labeling);
259  for(IndexType j = 0; j < 2; ++j) {
260  if(gm_[FactorID].numberOfLabels(j) > 1) {
261  labeling[j] = 1;
262  break;
263  }
264  }
265  valueNotEqual = gm_[FactorID](labeling);
266 
267  srmpLib::PottsFactorType* pottsFactor = new srmpLib::PottsFactorType;
268  pottsFactorList_.push_back(pottsFactor);
269 
270  // srmp potts type uses 0.0 as equal value, hence shift values
271  double lambda = valueNotEqual - valueEqual;
272  constTerm_ += valueEqual;
273 
274  srmpLib::Energy::NodeId nodes[2] = {static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(0)), static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(1))};
275 
276  srmpSolver_.AddFactor(2, nodes, &lambda, pottsFactor);
277 }
278 
279 template<class GM>
280 inline void SRMP<GM>::addGeneralFactor(const IndexType FactorID) {
281  double* values = new double[gm_[FactorID].size()];
282 
283  ShapeWalkerSwitchedOrder<typename FactorType::ShapeIteratorType> shapeWalker(gm_[FactorID].shapeBegin(), gm_[FactorID].dimension());
284  for(size_t i = 0; i < gm_[FactorID].size(); ++i) {
285  values[i] = gm_[FactorID](shapeWalker.coordinateTuple().begin());
286  ++shapeWalker;
287  }
288 
289  srmpLib::Energy::NodeId* nodes = new srmpLib::Energy::NodeId[gm_[FactorID].numberOfVariables()];
290  for(IndexType i = 0; i < gm_[FactorID].numberOfVariables(); ++i) {
291  nodes[i] = static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(i));
292  }
293 
294  srmpLib::GeneralFactorType* generalFactor = new srmpLib::GeneralFactorType;
295 
296  srmpSolver_.AddFactor(gm_[FactorID].numberOfVariables(), nodes, values, generalFactor);
297 
298  delete[] nodes;
299  delete[] values;
300 }
301 
302 } // namespace external
303 } // namespace opengm
304 
305 #endif /* OPENGM_EXTERNAL_SRMP_HXX_ */
InferenceTermination arg(std::vector< LabelType > &arg, const size_t &n=1) const
Definition: srmp.hxx:201
The OpenGM namespace.
Definition: config.hxx:43
const GraphicalModelType & graphicalModel() const
Definition: srmp.hxx:167
opengm::Minimizer AccumulationType
Definition: srmp.hxx:24
visitors::VerboseVisitor< SRMP< GM > > VerboseVisitorType
Definition: srmp.hxx:26
std::string name() const
Definition: srmp.hxx:162
visitors::EmptyVisitor< SRMP< GM > > EmptyVisitorType
Definition: srmp.hxx:27
GM::ValueType bound() const
return a bound on the solution
Definition: srmp.hxx:215
GraphicalModelType::IndexType IndexType
Definition: inference.hxx:40
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
static T ineutral()
inverse neutral element (with return)
Definition: minimizer.hxx:25
Inference algorithm interface.
Definition: inference.hxx:34
SRMP(const GraphicalModelType &gm, const Parameter para=Parameter())
Definition: srmp.hxx:100
visitors::TimingVisitor< SRMP< GM > > TimingVisitorType
Definition: srmp.hxx:28
static T neutral()
neutral element (with return)
Definition: minimizer.hxx:16
GraphicalModelType::LabelType LabelType
Definition: inference.hxx:39
Minimization as a unary accumulation.
Definition: minimizer.hxx:12
InferenceTermination infer()
Definition: srmp.hxx:172
InferenceTermination
Definition: inference.hxx:24
GM::ValueType value() const
return the solution (value)
Definition: srmp.hxx:220