OpenGM  2.3.x
Discrete Graphical Model Library
dynamicprogramming.hxx
Go to the documentation of this file.
1 #pragma once
2 #ifndef OPENGM_DYNAMICPROGRAMMING_HXX
3 #define OPENGM_DYNAMICPROGRAMMING_HXX
4 
5 #include <typeinfo>
6 #include <limits>
9 
10 namespace opengm {
11 
/// Inference by dynamic programming on an acyclic graphical model whose factors have at most two variables.
/// infer() accumulates values from the leaves towards the roots of a traversal forest.
/// arg() then backtracks from the roots using the best child labels recorded in the state buffers.
/// The constructor asserts gm.isAcyclic(); infer() throws std::runtime_error for factors with more than two variables.
/// NOTE(review): listing lines 21-26, 32 and 36 are missing from this rendering. They presumably hold the
/// OperatorType / visitor / MyValueType / MyStateType typedefs and the destructor and infer() declarations; confirm against the header.
/// NOTE(review): the destructor free()s raw buffers, but no copy constructor or copy assignment is visible here.
/// Copying an instance would therefore double-free; consider declaring both private/deleted.
15  template<class GM, class ACC>
16  class DynamicProgramming : public Inference<GM, ACC> {
17  public:
18  typedef ACC AccumulationType;
19  typedef ACC AccumulatorType;
20  typedef GM GraphicalModelType;
// Parameters of the solver.
27  struct Parameter {
// Variables used (in this order) as starting points of the traversal order before any other variable; may be empty.
28  std::vector<IndexType> roots_;
29  };
30 
31  DynamicProgramming(const GraphicalModelType&, const Parameter& = Parameter());
33 
34  std::string name() const;
35  const GraphicalModelType& graphicalModel() const;
37  template< class VISITOR>
38  InferenceTermination infer(VISITOR &);
39  InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
40 
41 
// Per label of Inode: the accumulated value and the labels chosen for Inode's descendants.
42  void getNodeInfo(const IndexType Inode, std::vector<ValueType>& values, std::vector<std::vector<LabelType> >& substates, std::vector<IndexType>& nodes) const;
43 
44 
45  private:
// Model is held by reference; it must outlive the solver.
46  const GraphicalModelType& gm_;
47  Parameter para_;
// Single malloc'd block: one value per (variable, label).
48  MyValueType* valueBuffer_;
// Single malloc'd block: for every (variable, child, label) the best label of that child.
49  MyStateType* stateBuffer_;
// Per-variable pointers into valueBuffer_.
50  std::vector<MyValueType*> valueBuffers_;
// Per-variable pointers into stateBuffer_.
51  std::vector<MyStateType*> stateBuffers_;
// nodeOrder_[variable] = position of the variable in the traversal order.
52  std::vector<size_t> nodeOrder_;
// orderedNodes_[position] = variable at that position (inverse of nodeOrder_).
53  std::vector<size_t> orderedNodes_;
// Set by infer(); while false, arg() reports UNKNOWN.
54  bool inferenceStarted_;
55  };
56 
// Returns the fixed algorithm name "DynamicProgramming".
// NOTE(review): listing line 59 (the qualified signature) is missing from this rendering.
57  template<class GM, class ACC>
58  inline std::string
60  return "DynamicProgramming";
61  }
62 
// Returns the graphical model passed to the constructor (held by reference in gm_).
// NOTE(review): listing lines 64-65 (return type and qualified signature) are missing from this rendering.
63  template<class GM, class ACC>
66  return gm_;
67  }
68 
// Destructor: releases the value and state buffers that the constructor obtained with malloc.
// NOTE(review): listing line 70 (the destructor's qualified name) is missing from this rendering.
69  template<class GM, class ACC>
71  {
72  free(valueBuffer_);
73  free(stateBuffer_);
74  }
75 
76  template<class GM, class ACC>
78  (
79  const GraphicalModelType& gm,
80  const Parameter& para
81  )
82  : gm_(gm), inferenceStarted_(false)
83  {
84  OPENGM_ASSERT(gm_.isAcyclic());
85  para_ = para;
86 
87  // Set nodeOrder
88  std::vector<size_t> numChildren(gm_.numberOfVariables(),0);
89  std::vector<size_t> nodeList;
90  size_t orderCount = 0;
91  size_t varCount = 0;
92  nodeOrder_.resize(gm_.numberOfVariables(),std::numeric_limits<std::size_t>::max());
93  size_t rootCounter=0;
94  while(varCount < gm_.numberOfVariables() && orderCount < gm_.numberOfVariables()){
95  if(rootCounter<para_.roots_.size()){
96  nodeOrder_[para_.roots_[rootCounter]] = orderCount++;
97  nodeList.push_back(para_.roots_[rootCounter]);
98  ++rootCounter;
99  }
100  else if(nodeOrder_[varCount]==std::numeric_limits<std::size_t>::max()){
101  nodeOrder_[varCount] = orderCount++;
102  nodeList.push_back(varCount);
103  }
104  ++varCount;
105  while(nodeList.size()>0){
106  size_t node = nodeList.back();
107  nodeList.pop_back();
108  for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
109  const typename GM::FactorType& factor = gm_[(*it)];
110  if( factor.numberOfVariables() == 2 ){
111  if( factor.variableIndex(1) == node && nodeOrder_[factor.variableIndex(0)]==std::numeric_limits<std::size_t>::max() ){
112  nodeOrder_[factor.variableIndex(0)] = orderCount++;
113  nodeList.push_back(factor.variableIndex(0));
114  ++numChildren[node];
115  }
116  if( factor.variableIndex(0) == node && nodeOrder_[factor.variableIndex(1)]==std::numeric_limits<std::size_t>::max() ){
117  nodeOrder_[factor.variableIndex(1)] = orderCount++;
118  nodeList.push_back(factor.variableIndex(1));
119  ++numChildren[node];
120  }
121  }
122  }
123  }
124  }
125 
126  // Allocate memmory
127  size_t memSizeValue = 0;
128  size_t memSizeState = 0;
129  for(size_t i=0; i<gm_.numberOfVariables();++i){
130  memSizeValue += gm_.numberOfLabels(i);
131  memSizeState += gm.numberOfLabels(i) * numChildren[i];
132  }
133  valueBuffer_ = (MyValueType*) malloc(memSizeValue*sizeof(MyValueType));
134  stateBuffer_ = (MyStateType*) malloc(memSizeState*sizeof(MyStateType));
135  valueBuffers_.resize(gm_.numberOfVariables());
136  stateBuffers_.resize(gm_.numberOfVariables());
137 
138  MyValueType* valuePointer = valueBuffer_;
139  MyStateType* statePointer = stateBuffer_;
140  for(size_t i=0; i<gm_.numberOfVariables();++i){
141  valueBuffers_[i] = valuePointer;
142  valuePointer += gm.numberOfLabels(i);
143  stateBuffers_[i] = statePointer;
144  statePointer += gm.numberOfLabels(i) * numChildren[i];
145  }
146 
147  orderedNodes_.resize(gm_.numberOfVariables(),std::numeric_limits<std::size_t>::max());
148  for(size_t i=0; i<gm_.numberOfVariables(); ++i)
149  orderedNodes_[nodeOrder_[i]] = i;
150 
151  }
152 
// Runs inference with a default-constructed EmptyVisitorType by forwarding to infer(VISITOR&).
// NOTE(review): listing line 155 (the qualified signature) is missing from this rendering.
153  template<class GM, class ACC>
154  inline InferenceTermination
156  EmptyVisitorType v;
157  return infer(v);
158  }
159 
160  template<class GM, class ACC>
161  template<class VISITOR>
162  inline InferenceTermination
164  (
165  VISITOR & visitor
166  ){
167  visitor.begin(*this);
168  inferenceStarted_ = true;
169  for(size_t i=1; i<=gm_.numberOfVariables();++i){
170  const size_t node = orderedNodes_[gm_.numberOfVariables()-i];
171  // set buffer neutral
172  for(size_t n=0; n<gm_.numberOfLabels(node); ++n){
173  OperatorType::neutral(valueBuffers_[node][n]);
174  }
175  // accumulate messages
176  size_t childrenCounter = 0;
177  for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
178  const typename GM::FactorType& factor = gm_[(*it)];
179 
180  // unary
181  if(factor.numberOfVariables()==1 ){
182  for(size_t n=0; n<gm_.numberOfLabels(node); ++n){
183  const ValueType fac = factor(&n);
184  OperatorType::op(fac, valueBuffers_[node][n]);
185  }
186  }
187 
188  //pairwise
189  if( factor.numberOfVariables()==2 ){
190  size_t vec[] = {0,0};
191  if(factor.variableIndex(0) == node && nodeOrder_[factor.variableIndex(1)]>nodeOrder_[node] ){
192  const size_t node2 = factor.variableIndex(1);
193  MyStateType s;
194  MyValueType v,v2;
195  for(vec[0]=0; vec[0]<gm_.numberOfLabels(node); ++vec[0]){
196  ACC::neutral(v);
197  for(vec[1]=0; vec[1]<gm_.numberOfLabels(node2); ++vec[1]){
198  const ValueType fac = factor(vec);
199  OperatorType::op(fac,valueBuffers_[node2][vec[1]],v2) ;
200  if(ACC::bop(v2,v)){
201  v=v2;
202  s=vec[1];
203  }
204  }
205  stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+vec[0]] = s;
206  OperatorType::op(v,valueBuffers_[node][vec[0]]);
207  }
208  ++childrenCounter;
209 
210  }
211  if(factor.variableIndex(1) == node && nodeOrder_[factor.variableIndex(0)]>nodeOrder_[node]){
212  const size_t node2 = factor.variableIndex(0);
213  MyStateType s;
214  MyValueType v,v2;
215  for(vec[1]=0; vec[1]<gm_.numberOfLabels(node); ++vec[1]){
216  ACC::neutral(v);
217  for(vec[0]=0; vec[0]<gm_.numberOfLabels(node2); ++vec[0]){
218  const ValueType fac = factor(vec);
219  OperatorType::op(fac,valueBuffers_[node2][vec[0]],v2);
220  if(ACC::bop(v2,v)){
221  v=v2;
222  s=vec[0];
223  }
224  }
225  stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+vec[1]] = s;
226  OperatorType::op(v,valueBuffers_[node][vec[1]]);
227  }
228  ++childrenCounter;
229  }
230  }
231  // higher order
232  if( factor.numberOfVariables()>2 ){
233  throw std::runtime_error("This implementation of Dynamic Programming does only support second order models so far, but could be extended.");
234  }
235 
236  }
237  }
238  visitor.end(*this);
239  return NORMAL;
240  }
241 
242  template<class GM, class ACC>
244  (
245  std::vector<LabelType>& arg,
246  const size_t n
247  ) const {
248  if(n > 1) {
249  arg.assign(gm_.numberOfVariables(), 0);
250  return UNKNOWN;
251  }
252  else {
253  if(inferenceStarted_) {
254  std::vector<size_t> nodeList;
255  arg.assign(gm_.numberOfVariables(), std::numeric_limits<LabelType>::max() );
256  size_t var = 0;
257  while(var < gm_.numberOfVariables()){
258  if(arg[var]==std::numeric_limits<LabelType>::max()){
259  MyValueType v; ACC::neutral(v);
260  for(size_t i=0; i<gm_.numberOfLabels(var); ++i){
261  if(ACC::bop(valueBuffers_[var][i], v)){
262  v = valueBuffers_[var][i];
263  arg[var]=i;
264  }
265  }
266  nodeList.push_back(var);
267  }
268  ++var;
269  while(nodeList.size()>0){
270  size_t node = nodeList.back();
271  size_t childrenCounter = 0;
272  nodeList.pop_back();
273  for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
274  const typename GM::FactorType& factor = gm_[(*it)];
275  if(factor.numberOfVariables()==2 ){
276  if(factor.variableIndex(1)==node && nodeOrder_[factor.variableIndex(0)] > nodeOrder_[node] ){
277  arg[factor.variableIndex(0)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
278  nodeList.push_back(factor.variableIndex(0));
279  ++childrenCounter;
280  }
281  if(factor.variableIndex(0)==node && nodeOrder_[factor.variableIndex(1)] > nodeOrder_[node] ){
282  arg[factor.variableIndex(1)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
283  nodeList.push_back(factor.variableIndex(1));
284  ++childrenCounter;
285  }
286  }
287  }
288  }
289  }
290  return NORMAL;
291  } else {
292  arg.assign(gm_.numberOfVariables(), 0);
293  return UNKNOWN;
294  }
295  }
296  }
297 
298  template<class GM, class ACC>
299  inline void DynamicProgramming<GM, ACC>::getNodeInfo(const IndexType Inode, std::vector<ValueType>& values, std::vector<std::vector<LabelType> >& substates, std::vector<IndexType>& nodes) const{
300  values.clear();
301  substates.clear();
302  nodes.clear();
303  values.resize(gm_.numberOfLabels(Inode));
304  substates.resize(gm_.numberOfLabels(Inode));
305  std::vector<LabelType> arg;
306  bool firstround = true;
307  std::vector<size_t> nodeList;
308  for(IndexType i=0;i<gm_.numberOfLabels(Inode); ++i){
309  arg.assign(gm_.numberOfVariables(), std::numeric_limits<LabelType>::max() );
310  arg[Inode]=i;
311  values[i]=valueBuffers_[Inode][i];
312  nodeList.push_back(Inode);
313  if(i!=0){
314  firstround=false;
315  }
316 
317  while(nodeList.size()>0){
318  size_t node = nodeList.back();
319  size_t childrenCounter = 0;
320  nodeList.pop_back();
321  for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
322  const typename GM::FactorType& factor = gm_[(*it)];
323  if(factor.numberOfVariables()==2 ){
324  if(factor.variableIndex(1)==node && nodeOrder_[factor.variableIndex(0)] > nodeOrder_[node] ){
325  arg[factor.variableIndex(0)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
326  substates[i].push_back(stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]]);
327  if(firstround==true){
328  nodes.push_back(factor.variableIndex(0));
329  }
330  nodeList.push_back(factor.variableIndex(0));
331  ++childrenCounter;
332  }
333  if(factor.variableIndex(0)==node && nodeOrder_[factor.variableIndex(1)] > nodeOrder_[node] ){
334  arg[factor.variableIndex(1)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
335  substates[i].push_back(stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]]);
336  if(firstround==true){
337  nodes.push_back(factor.variableIndex(1));
338  }
339  nodeList.push_back(factor.variableIndex(1));
340  ++childrenCounter;
341  }
342  }
343  }
344  }
345  }
346  }
347 
348 
349 } // namespace opengm
350 
351 #endif // #ifndef OPENGM_DYNAMICPROGRAMMING_HXX
352 
The OpenGM namespace.
Definition: config.hxx:43
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
GraphicalModelType::IndexType IndexType
Definition: inference.hxx:40
void getNodeInfo(const IndexType Inode, std::vector< ValueType > &values, std::vector< std::vector< LabelType > > &substates, std::vector< IndexType > &nodes) const
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:41
Inference algorithm interface.
Definition: inference.hxx:34
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
output a solution
const GraphicalModelType & graphicalModel() const
visitors::VerboseVisitor< DynamicProgramming< GM, ACC > > VerboseVisitorType
visitors::EmptyVisitor< DynamicProgramming< GM, ACC > > EmptyVisitorType
InferenceTermination infer()
GraphicalModelType::LabelType LabelType
Definition: inference.hxx:39
DynamicProgramming(const GraphicalModelType &, const Parameter &=Parameter())
InferenceTermination
Definition: inference.hxx:24
visitors::TimingVisitor< DynamicProgramming< GM, ACC > > TimingVisitorType