OpenGM  2.3.x
Discrete Graphical Model Library
trws_reparametrization.hxx
Go to the documentation of this file.
1 #ifndef REPARAMETRIZATION_HXX
2 #define REPARAMETRIZATION_HXX
3 #include <valarray>
4 #include <iostream>
5 #include <map>
10 
12 
13 namespace opengm {
14 
15 namespace trws_base{
16 
17 /*
18  * The trivialization solver computes dual variables that trivialize the input problem, making local decisions globally consistent.
19  */
20 
// TrivializationSolver: a MaxSumSolver over one sequence subproblem.
// ForwardMove() runs an ordinary move of the base solver; BackwardMove() then
// walks the sequence in the reverse direction and writes dual variables into the
// DualStorage given to the constructor (see _PushBack() and _setDuals()).
21 template<class GM,class ACC,class InputIterator>
22 class TrivializationSolver : protected MaxSumSolver<GM,ACC,InputIterator>
23 {
24 public:
26  typedef typename parent::ValueType ValueType;
27  typedef typename parent::IndexType IndexType;
28  typedef typename parent::LabelType LabelType;
29 
31  typedef typename parent::Storage Storage;
34  typedef typename parent::UnaryFactor UnaryFactor;
35  typedef typename UnaryFactor::const_iterator const_uIterator;
37 
 // One flag per variable of the sequence (BackwardMove() defaults it to all-true).
38  typedef typename std::vector<bool> MaskType;
 // One MaskType (per-label flags) per variable; in this class only referenced by commented-out code.
39  typedef typename std::vector<MaskType> ImmovableLabelingType;
40 
 // primalstorage and fp are handed to the MaxSumSolver base; dualstorage is kept
 // by reference (it receives the computed duals), so it must outlive this object.
41  TrivializationSolver(Storage& primalstorage,
42  DualStorage& dualstorage,
43  const FactorProperties& fp,
44  bool fastComputations=true)
45  :parent(primalstorage,fp,fastComputations),
46  _dualstorage(dualstorage){};
 // Initializes and runs a full move of the base solver in the given direction.
47  void ForwardMove(MoveDirection direction=Storage::Direct){parent::InitMove(direction); parent::Move();};
 // Reverses the move direction and pushes back along the sequence, writing duals.
 // pmask==0 means every variable is masked (all-true mask).
48  void BackwardMove(const MaskType* pmask=0);
49  //void BackwardMove(const ImmovableLabelingType& immovableLabels);
50 
 // Forwards to the base solver.
51  ValueType GetObjectiveValue()const{return parent::GetObjectiveValue();}
52 private:
 // Processes the current variable of a backward move and advances to the next one.
53  void _PushBack();
54  //void _PushBack(const MaskType& mask);
 // Position of the current variable counted from the start of the current move.
55  IndexType _distanceFromStart();
 // Loads the marginals of variable `index` into the working buffer and counts the masked variables.
56  void _InitBackwardMoveBuffer(IndexType index);
 // Copies values from `it` into the dual storage entries of the pairwise factor next to `index`.
57  void _setDuals(IndexType index,typename SequenceStorage<GM>::MoveDirection moveDir,const_uIterator it);
58  DualStorage& _dualstorage;
 // Mask used by the current backward move (a copy of the caller's mask, or all-true).
59  MaskType _mask;
 // Number of masked variables not yet processed by _PushBack() in the current backward move.
60  IndexType _numberOfBoundaryTerms;
61 
62  // computation optimization
63 
64  //std::vector<ValueType> _multipliers;
65 };
66 
67 //=======================TrivializationSolver implementation ===========================================
68 template<class GM,class ACC,class InputIterator>
69 typename TrivializationSolver<GM,ACC,InputIterator>::IndexType
70 TrivializationSolver<GM,ACC,InputIterator>::_distanceFromStart()
71 {
72  if (parent::_moveDirection==Storage::Direct)
73  return parent::_currentUnaryIndex;
74  else
75  return parent::size()-1-parent::_currentUnaryIndex;
76 }
77 
78 template<class GM,class ACC,class InputIterator>
79 void TrivializationSolver<GM,ACC,InputIterator>::_InitBackwardMoveBuffer(IndexType index)
80 {
81  assert(index < parent::_storage.size());
82  parent::_currentUnaryIndex=index;
83  parent::_currentUnaryFactor.resize(parent::_storage.unaryFactors(parent::_currentUnaryIndex).size());
84  std::copy(parent::_marginals[parent::_currentUnaryIndex].
85  begin(),parent::_marginals[parent::_currentUnaryIndex].end(),parent::_currentUnaryFactor.begin());
86 
87  _numberOfBoundaryTerms=std::count(_mask.begin(),_mask.end(),true);
88 }
89 
90 
91 
// One step of the backward move: turns the working buffer of the current variable
// into dual values, writes them (via _setDuals, current direction), pushes messages
// to the adjacent pairwise factor, advances to the next variable and writes that
// variable's duals for the reverse direction.
// NOTE(review): std::bind2nd (used below) was deprecated in C++11 and removed in
// C++17; this function does not compile with -std=c++17 as written.
92 template<class GM,class ACC,class InputIterator>
93 void TrivializationSolver<GM,ACC,InputIterator>::_PushBack()
94 {
95  OPENGM_ASSERT(_mask.size()==parent::size());
96  ValueType multiplier;
 // Masked variables keep the fraction (k-1)/k of the buffer, where k is the number
 // of masked variables not yet processed; unmasked variables keep it unchanged.
97  if (_mask[parent::_currentUnaryIndex])
98  {
99  multiplier=((ValueType)_numberOfBoundaryTerms-1.0)/_numberOfBoundaryTerms;
100  --_numberOfBoundaryTerms;
101  }
102  else
103  {
104  multiplier=1.0;
105  }
106 
 // buffer := multiplier*buffer - marginals[i] + unaryFactors[i]
107  transform_inplace(parent::_currentUnaryFactor.begin(),
108  parent::_currentUnaryFactor.end(),
109  std::bind2nd(std::multiplies<ValueType>(),multiplier));
110 
111  std::transform(parent::_currentUnaryFactor.begin(),
112  parent::_currentUnaryFactor.end(),
113  parent::_marginals[parent::_currentUnaryIndex].begin(),
114  parent::_currentUnaryFactor.begin(),
115  std::minus<ValueType>());
116  std::transform(parent::_currentUnaryFactor.begin(),parent::_currentUnaryFactor.end(),
117  parent::_storage.unaryFactors(parent::_currentUnaryIndex).begin(),
118  parent::_currentUnaryFactor.begin(),std::plus<ValueType>());
119 
 // Store the result as duals of the pairwise factor ahead of this variable.
120  _setDuals(parent::_currentUnaryIndex,parent::_moveDirection,parent::_currentUnaryFactor.begin());
121 
122  parent::_PushMessagesToFactor();
 // Advance to the next variable; the buffer is reset to zeros before _ClearMessages()
 // (presumably that call writes the incoming message into the buffer — confirm in MaxSumSolver).
123  parent::_currentUnaryIndex=parent::_next(parent::_currentUnaryIndex);//instead of _InitCurrentUnaryBuffer(_next(_currentUnaryIndex));
124  parent::_currentUnaryFactor.assign(parent::_storage.unaryFactors(parent::_currentUnaryIndex).size(),0.0);
125  parent::_ClearMessages();
126 
 // Negate and store as duals of the factor behind the new variable (reverse direction).
127  transform_inplace(parent::_currentUnaryFactor.begin(),
128  parent::_currentUnaryFactor.end(),
129  std::bind2nd(std::multiplies<ValueType>(),-1.0));
130  _setDuals(parent::_currentUnaryIndex,Storage::ReverseDirection(parent::_moveDirection),
131  parent::_currentUnaryFactor.begin());
132 
 // buffer := marginals[next] - buffer, the starting point for the next step.
133  std::transform(parent::_marginals[parent::_currentUnaryIndex].begin(),
134  parent::_marginals[parent::_currentUnaryIndex].end(),
135  parent::_currentUnaryFactor.begin(),parent::_currentUnaryFactor.begin(),
136  std::minus<ValueType>());
137 }
138 
139 //template<class GM,class ACC,class InputIterator>
140 //void TrivializationSolver<GM,ACC,InputIterator>::_PushBack(const MaskType& mask)
141 //{
142 // OPENGM_ASSERT(mask.size()==parent::_currentUnaryFactor.size());
143 // _multipliers.resize(parent::_currentUnaryFactor.size());
144 // bool decrease=false;
145 //
146 // //std::cout << "_numberOfBoundaryTerms="<<_numberOfBoundaryTerms<<std::endl;
147 //
148 // for (IndexType label=0;label<_multipliers.size();++label)
149 // {
150 // if (mask[label]) _multipliers[label]=1.0;
151 // else
152 // {
153 // _multipliers[label]=((ValueType)_numberOfBoundaryTerms-1.0)/_numberOfBoundaryTerms;
154 // decrease=true;
155 // }
156 // }
157 //
158 // if (decrease)
159 // {
160 // --_numberOfBoundaryTerms;
161 // decrease=false;
162 // }
163 //
164 // //std::cout << "_multipliers:" <<_multipliers<<std::endl;
165 //
166 // transform(parent::_currentUnaryFactor.begin(),
167 // parent::_currentUnaryFactor.end(),
168 // _multipliers.begin(),parent::_currentUnaryFactor.begin(),
169 // std::multiplies<ValueType>());
170 //
171 // std::transform(parent::_currentUnaryFactor.begin(),
172 // parent::_currentUnaryFactor.end(),
173 // parent::_marginals[parent::_currentUnaryIndex].begin(),
174 // parent::_currentUnaryFactor.begin(),
175 // std::minus<ValueType>());
176 // std::transform(parent::_currentUnaryFactor.begin(),parent::_currentUnaryFactor.end(),
177 // parent::_storage.unaryFactors(parent::_currentUnaryIndex).begin(),
178 // parent::_currentUnaryFactor.begin(),std::plus<ValueType>());
179 //
180 // _setDuals(parent::_currentUnaryIndex,parent::_moveDirection,parent::_currentUnaryFactor.begin());
181 //
182 // parent::_PushMessagesToFactor();
183 // parent::_currentUnaryIndex=parent::_next(parent::_currentUnaryIndex);//instead of _InitCurrentUnaryBuffer(_next(_currentUnaryIndex));
184 // parent::_currentUnaryFactor.assign(parent::_storage.unaryFactors(parent::_currentUnaryIndex).size(),0.0);
185 // parent::_ClearMessages();
186 //
187 // transform_inplace(parent::_currentUnaryFactor.begin(),
188 // parent::_currentUnaryFactor.end(),
189 // std::bind2nd(std::multiplies<ValueType>(),-1.0));
190 // _setDuals(parent::_currentUnaryIndex,Storage::ReverseDirection(parent::_moveDirection),
191 // parent::_currentUnaryFactor.begin());
192 //
193 // std::transform(parent::_marginals[parent::_currentUnaryIndex].begin(),
194 // parent::_marginals[parent::_currentUnaryIndex].end(),
195 // parent::_currentUnaryFactor.begin(),parent::_currentUnaryFactor.begin(),
196 // std::minus<ValueType>());
197 //}
198 
199 
// BackwardMove(pmask): reverses the current move direction and runs _PushBack()
// size()-1 times, starting from the first variable in the new direction, writing
// dual variables as it goes. pmask==0 means an all-true mask.
200 template<class GM,class ACC,class InputIterator>
202 {
203  if (pmask==0)
204  _mask.assign(parent::size(),true);
205  else
206  _mask=*pmask;
207 
 // Start from index 0 for a Direct move, from the last index otherwise.
208  parent::_moveDirection=Storage::ReverseDirection(parent::_moveDirection);
209  if (parent::_moveDirection==Storage::Direct)
210  _InitBackwardMoveBuffer(0);
211  else
212  _InitBackwardMoveBuffer(parent::size()-1);
213 
 // One step per pairwise link of the sequence.
 // NOTE(review): if size() can be 0 and IndexType is unsigned, size()-1 wraps around.
214  for (IndexType i=0;i<parent::size()-1;++i)
215  _PushBack(); //_Push(size()-i) - the current value of i is known, as _currentIndex
216 
 // Presumably forces the base solver to re-initialize before its next move — confirm.
217  parent::_bInitializationNeeded=true;
218 }
219 
220 //template<class GM,class ACC,class InputIterator>
221 //void TrivializationSolver<GM,ACC,InputIterator>::BackwardMove(const ImmovableLabelingType& immovableLabels)
222 //{
223 // _mask.assign(parent::size(),true);
224 // parent::_moveDirection=Storage::ReverseDirection(parent::_moveDirection);
225 //
226 // if (parent::_moveDirection==Storage::Direct)
227 // {
228 // //std::cout << "Direct"<<std::endl;
229 // _InitBackwardMoveBuffer(0);
230 // }
231 // else
232 // {
233 // //std::cout << "Reverse"<<std::endl;
234 // _InitBackwardMoveBuffer(parent::size()-1);
235 // }
236 //
237 // //for (IndexType i=0;i<parent::size()-1;++i)//!> number of iterations is equal to a number of pairwise factors
238 // for (IndexType i=0;i<parent::size()-1;++i)//!> number of iterations is equal to a number of pairwise factors
239 // _PushBack(immovableLabels[parent::_currentUnaryIndex]); //_Push(size()-i) - the current value of i is known, as _currentIndex
240 //
241 // parent::_bInitializationNeeded=true;
242 //}
243 
// _setDuals(index, movedir, it): copies values starting at `it` into the dual-storage
// entries of the pairwise factor adjacent to variable `index` — the factor after it
// (pwForwardFactor(index)) for a Direct move, the factor before it
// (pwForwardFactor(index-1)) otherwise. varId selects which side of that factor the
// variable is on, taking the factor's stored orientation (pwDirection) into account.
// Exactly as many values are copied as the dual range holds.
// NOTE(review): for a non-Direct move `index` must be > 0.
244 template<class GM,class ACC,class InputIterator>
246 ::_setDuals(IndexType index,typename SequenceStorage<GM>::MoveDirection movedir,const_uIterator it)
247  {
248  IndexType pwId, varId;
249 
250  if (movedir==Storage::Direct)
251  {
252  pwId=parent::_storage.pwForwardFactor(index);
253  varId=(parent::_storage.pwDirection(index)==Storage::Direct ? 0 : 1);
254  }
255  else
256  {
257  pwId=parent::_storage.pwForwardFactor(index-1);
258  varId=(parent::_storage.pwDirection(index-1)==Storage::Direct ? 1 : 0);
259  }
260  std::pair<typename DualStorage::uIterator,typename DualStorage::uIterator> dualIt
261  =_dualstorage.getIterators(pwId,varId);
262  std::copy(it,it+(dualIt.second-dualIt.first),dualIt.first);
263  };
264 }//namespace trws_base
265 
266 
// Parameters of TRWS_Reparametrizer. fastComputations_ is forwarded to every
// subsolver created in the TRWS_Reparametrizer constructor.
// NOTE(review): the ValueType template parameter is not used in the visible code.
267 template<class ValueType>
269 {
271 
272  TRWS_Reparametrizer_Parameters(bool fastComputations=true):
273  fastComputations_(fastComputations)
274  {};
275 };
276 
// TRWS_Reparametrizer: an LPReparametrizer that computes its reparametrization by
// running one TrivializationSolver per submodel of the decomposition Storage; all
// subsolvers write into the reparametrization storage of the LPReparametrizer base.
// NOTE(review): _subSolvers holds raw owning pointers freed in the destructor; unless
// copying is disabled elsewhere, copying this object would free them twice.
277 template<class Storage,class ACC>
278 class TRWS_Reparametrizer : public opengm::LPReparametrizer<typename Storage::GraphicalModelType, ACC>
279 {
280 public:
283  typedef typename GraphicalModelType::ValueType ValueType;
284  typedef typename GraphicalModelType::IndexType IndexType;
285  typedef typename GraphicalModelType::LabelType LabelType;
286 
288  typedef typename parent::MaskType MaskType;
291 
293 
296 
 // storage must outlive this object (it is kept by reference).
297  TRWS_Reparametrizer(Storage& storage,const FunctionParametersType& fparams,const Parameter& params=Parameter());
298  virtual ~TRWS_Reparametrizer();
 // Reparametrizes with a per-variable mask (pmask==0: all variables masked).
299  void reparametrize(const MaskType* pmask=0);
 // Reparametrizes, then zeroes the duals of immovable labels and shifts the others.
300  void reparametrize(const ImmovableLabelingType& immovableLabeling);
301 
302 private:
303  Storage& _storage;
 // One subsolver per submodel, allocated with new in the constructor.
304  std::vector<SubSolverType*> _subSolvers;
305 };
306 
// Destructor: releases every subsolver created in the constructor via DeallocatePointer.
307 template<class Storage,class ACC>
309 {
310  std::for_each(_subSolvers.begin(),_subSolvers.end(),trws_base::DeallocatePointer<SubSolverType>);
311 }
312 
// Constructor: the LPReparametrizer base is built on storage.masterModel(); one
// TrivializationSolver is created per submodel, all sharing the base's
// reparametrization storage as their dual storage.
// NOTE(review): if a `new` throws part-way, the subsolvers already created leak,
// because the destructor does not run for a partially constructed object.
313 template<class Storage,class ACC>
315  const FunctionParametersType& fparams,
316  const Parameter& params):
317  parent(storage.masterModel()),
318  _storage(storage)
319  {
320  _subSolvers.resize(_storage.numberOfModels());
321 
322  for (size_t modelId=0;modelId<_subSolvers.size();++modelId)
323  {
324  _subSolvers[modelId]= new SubSolverType(_storage.subModel(modelId),parent::Reparametrization(),fparams,params.fastComputations_);
325  }
326 }
327 
328 
// reparametrize(pmask): for every submodel, projects the global per-variable mask
// (pmask==0: all-true) onto the submodel's local variables, then runs a forward
// move followed by a masked backward move, which writes the duals.
// NOTE(review): `bound` is accumulated but never used or returned.
329 template<class Storage,class ACC>
331 {
332 
333  MaskType mask(pmask!=0 ? *pmask : MaskType(_storage.masterModel().numberOfVariables(),true));
334  OPENGM_ASSERT(mask.size()==_storage.masterModel().numberOfVariables());
335  ValueType bound=0;
336  MaskType sequenceMask;
337  for (size_t i=0;i<_subSolvers.size();++i)
338  {
339  typename Storage::SubModel& model=_storage.subModel(i);
 // Local mask: entry localInd takes the flag of global variable model.varIndex(localInd).
340  sequenceMask.resize(model.size());
341  for (IndexType localInd=0; localInd<sequenceMask.size();++localInd)
342  {
343  OPENGM_ASSERT(model.varIndex(localInd) < mask.size());
344  sequenceMask[localInd]=mask[model.varIndex(localInd)];
345  }
346 
347  _subSolvers[i]->ForwardMove();
348  _subSolvers[i]->BackwardMove(&sequenceMask);
349  bound+=_subSolvers[i]->GetObjectiveValue();
350  }
351 
352 }
353 
354 //template<class Storage,class ACC>
355 //void TRWS_Reparametrizer<Storage,ACC>::reparametrize(const ImmovableLabelingType& immovableLabeling)
356 //{
357 //
358 // //MaskType mask(pmask!=0 ? *pmask : MaskType(_storage.masterModel().numberOfVariables(),true));
359 // OPENGM_ASSERT(immovableLabeling.size()==_storage.masterModel().numberOfVariables());
360 // ValueType bound=0;
361 // ImmovableLabelingType sequenceLabeling;
362 // for (size_t i=0;i<_subSolvers.size();++i)
363 // {
364 // typename Storage::SubModel& model=_storage.subModel(i);
365 // sequenceLabeling.resize(model.size());
366 // for (IndexType localInd=0; localInd<sequenceLabeling.size();++localInd)
367 // {
368 // OPENGM_ASSERT(model.varIndex(localInd) < immovableLabeling.size());
369 // sequenceLabeling[localInd]=immovableLabeling[model.varIndex(localInd)];
370 // }
371 //
372 // //std::cout << "ForwardMove: ";
373 // _subSolvers[i]->ForwardMove();
374 // //std::cout << "BackwardMove: ";
375 // _subSolvers[i]->BackwardMove(sequenceLabeling);
376 // bound+=_subSolvers[i]->GetObjectiveValue();
377 // }
378 //
379 //}
380 
381 template<class Storage,class ACC>
382 void TRWS_Reparametrizer<Storage,ACC>::reparametrize(const ImmovableLabelingType& immovableLabeling)
383 {
384  OPENGM_ASSERT(immovableLabeling.size()==_storage.masterModel().numberOfVariables());
385  reparametrize();
386 
387  typedef typename parent::RepaStorageType::uIterator uIterator;
388 
389  const typename Storage::GraphicalModelType& gm=_storage.masterModel();
390  for (IndexType factorID=0;factorID < gm.numberOfFactors();++factorID)
391  {
392  if (gm[factorID].numberOfVariables()<2) continue;
393  /*
394  * Make zero potentials for immovable labels
395  */
396 
397  for (IndexType localVarID=0;localVarID<gm[factorID].numberOfVariables();++localVarID)
398  {
399  std::pair<uIterator,uIterator> it=parent::Reparametrization().getIterators(factorID,localVarID);
400  IndexType globalVarID=gm[factorID].variableIndex(localVarID);
401  typename MaskType::const_iterator labIt=immovableLabeling[globalVarID].begin();
402  for (;it.first!=it.second;++it.first)
403  if (*labIt++) *it.first=0;
404  }
405 
406  /*
407  * Make reparametrized pairwise factors non-negative
408  */
409 
410  if (gm[factorID].numberOfVariables()!=2) throw std::runtime_error("TRWS_Reparametrizer<Storage,ACC>::reparametrize(): factors of order higher than 2 are not supported!");
411 
412  std::vector<IndexType> labeling(2);
413  for (IndexType localVarID=0;localVarID<gm[factorID].numberOfVariables();++localVarID)
414  {
415  std::pair<uIterator,uIterator> it=parent::Reparametrization().getIterators(factorID,localVarID);
416  uIterator it_begin=it.first;
417  IndexType globalVarID=gm[factorID].variableIndex(localVarID);
418  typename MaskType::const_iterator labIt=immovableLabeling[globalVarID].begin();
419  ValueType res=ACC::template neutral<ValueType>();
420 
421  for (;it.first!=it.second;++it.first)
422  if (!(*labIt++))
423  {
424  IndexType otherVarID=(localVarID==0 ? 1 : 0);
425  labeling[localVarID]=it.first-it_begin;
426  for (LabelType label=0;label<gm.numberOfLabels(otherVarID);++label)
427  {
428  labeling[otherVarID]=label;
429 
430  ValueType res1=parent::Reparametrization().getFactorValue(factorID,labeling.begin());
431  ACC::op(res,res1,res);
432  }
433  *it.first-=res;
434  }
435  }
436  }
437 
438 }
439 
440 
441 
442 //============ LP reparametrization to TRWS reparametrization ===========================
// LPtoDecompositionStorage(lpRepa, ptrwsRepa): converts an LP reparametrization into
// the unary factors of a TRW-S decomposition. For every variable shared by k>1
// submodels, each submodel's copy of the unary becomes
//   lpRepa.getVariableValue(var,.)/k
//   + the lpRepa dual entries of the pairwise factors adjacent to the variable in that
//     submodel's sequence (the forward one and/or the backward one).
// Variables belonging to a single submodel are left untouched.
// NOTE(review): std::bind2nd (used below) was removed in C++17.
443 template<class GM>
445 {
446  OPENGM_ASSERT(&lpRepa.graphicalModel() == &ptrwsRepa->masterModel());
447 
 // NOTE(review): uIterator is unused in this function.
448  typedef typename LPReparametrisationStorage<GM>::uIterator uIterator;
449  typedef typename GM::ValueType ValueType;
450  typedef typename GM::IndexType IndexType;
451  typedef typename GM::LabelType LabelType;
452  typedef typename trws_base::DecompositionStorage<GM> DecompositionStorage;
453 
454  std::vector<ValueType> repaUnary;
455  //for all variables (and related unary factors)
456  for (IndexType varId=0;varId<lpRepa.graphicalModel().numberOfVariables();++varId)// all variables
457  { const typename DecompositionStorage::SubVariableListType& varList=ptrwsRepa->getSubVariableList(varId);
458 
459  if (varList.size()==1) continue;
460 
461  // compute the common part: the variable's value from lpRepa, split evenly over its submodels
462  repaUnary.resize(lpRepa.graphicalModel().numberOfLabels(varId));
463  for (LabelType label=0;label<repaUnary.size();++label)
464  {
465  repaUnary[label]=lpRepa.getVariableValue(varId,label);
466  }
467 
468  trws_base::transform_inplace(repaUnary.begin(),repaUnary.end(),std::bind2nd(std::multiplies<ValueType>(),1.0/varList.size()));
469 
470  //for all submodels
471  for(typename DecompositionStorage::SubVariableListType::const_iterator modelIt=varList.begin();
472  modelIt!=varList.end();++modelIt) //all related models
473  {
474  typename DecompositionStorage::SubModel& subModel=ptrwsRepa->subModel(modelIt->subModelId_);
475  typename DecompositionStorage::SubModel::UnaryFactor::iterator uit_begin=subModel.ufBegin(modelIt->subVariableId_);
476  typename DecompositionStorage::SubModel::UnaryFactor::iterator uit_end =subModel.ufEnd(modelIt->subVariableId_);
477  //unary=repaUnary/numberTrees
478  std::copy(repaUnary.begin(),repaUnary.end(),uit_begin);
479  //add only potentials belonging to the submodel
480 // std::pair<uIterator,uIterator> repaIt;
481  const typename LPReparametrisationStorage<GM>::UnaryFactor* prepaUF;
 // Not the last variable of the sequence: add the duals of the forward pairwise factor,
 // choosing the side of the factor on which varId sits.
482  if (modelIt->subVariableId_ < subModel.size()-1)
483  {
484  IndexType pwId=subModel.pwForwardFactor(modelIt->subVariableId_);
485  if (lpRepa.graphicalModel()[pwId].variableIndex(0)==varId)
486  prepaUF=&lpRepa.get(pwId,0);
487  else prepaUF=&lpRepa.get(pwId,1);
488 
489  std::transform(uit_begin,uit_end,prepaUF->begin(),uit_begin,std::plus<ValueType>());
490  }
 // Not the first variable: add the duals of the preceding pairwise factor as well.
491  if (modelIt->subVariableId_ >0)
492  {
493  IndexType pwId=subModel.pwForwardFactor(modelIt->subVariableId_-1);
494  if (lpRepa.graphicalModel()[pwId].variableIndex(0)==varId)
495  prepaUF=&lpRepa.get(pwId,0);
496  else prepaUF=&lpRepa.get(pwId,1);
497 
498  std::transform(uit_begin,uit_end,prepaUF->begin(),uit_begin,std::plus<ValueType>());
499  }
500 
501  }
502  }
503 }
504 
505 
506 } //namespace opengm
507 #endif
The OpenGM namespace.
Definition: config.hxx:43
const SubVariableListType & getSubVariableList(IndexType varId) const
Definition: trws_base.hxx:64
TRWS_Reparametrizer_Parameters(bool fastComputations=true)
ValueType getVariableValue(IndexType varIndex, LabelType label) const
virtual void Move()
>initializes move, which is reverse to the current one//TODO: remove virtual ?
trws_base::FunctionParameters< GraphicalModelType > FunctionParametersType
TrivializationSolver(Storage &primalstorage, DualStorage &dualstorage, const FactorProperties &fp, bool fastComputations=true)
trws_base::TrivializationSolver< GraphicalModelType, ACC, typename std::vector< typename GraphicalModelType::ValueType >::const_iterator > SubSolverType
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
std::vector< MaskType > ImmovableLabelingType
TRWS_Reparametrizer_Parameters< ValueType > Parameter
InputIterator transform_inplace(InputIterator first, InputIterator last, UnaryOperator op)
Definition: utilities2.hxx:79
GraphicalModelType::IndexType IndexType
void LPtoDecompositionStorage(const LPReparametrisationStorage< GM > &lpRepa, trws_base::DecompositionStorage< GM > *ptrwsRepa)
TRWS_Reparametrizer(Storage &storage, const FunctionParametersType &fparams, const Parameter &params=Parameter())
LPReparametrisationStorage< GM > RepaStorageType
MaxSumSolver< GM, ACC, InputIterator > parent
opengm::LPReparametrisationStorage< GM > DualStorage
parent::GraphicalModelType GraphicalModelType
parent::ReparametrizedGMType ReparametrizedGMType
const UnaryFactor & get(IndexType factorIndex, IndexType relativeVarIndex) const
GraphicalModelType::ValueType ValueType
opengm::LPReparametrizer< typename Storage::GraphicalModelType, ACC > parent
parent::InputIteratorType InputIteratorType
GraphicalModelType::LabelType LabelType
parent::ImmovableLabelingType ImmovableLabelingType
parent::FactorProperties FactorProperties
parent::RepaStorageType RepaStorageType
SubModel & subModel(IndexType modelId)
Definition: trws_base.hxx:60
RepaStorageType & Reparametrization()
void reparametrize(const MaskType *pmask=0)
void ForwardMove(MoveDirection direction=Storage::Direct)