OpenGM  2.3.x
Discrete Graphical Model Library
trws_trws.hxx
Go to the documentation of this file.
1 #ifndef TRWS_INTERFACE_HXX_
2 #define TRWS_INTERFACE_HXX_
6 
7 namespace opengm{
8 
// TRWSi_Parameter: parameter set for the TRWSi inference class.
// Extends trws_base::MaxSumTRWS_Parameters<GM::ValueType> (referred to as
// `parent` in the accessors below) with a verbosity flag and an optional
// initial point for the decomposition.
// NOTE(review): this rendered listing omits several declarations (the
// `parent`/`Storage` typedefs, the decompositionType constructor argument and
// member, and some accessors); they are only named in the definitions index.
9 template<class GM>
10 struct TRWSi_Parameter : public trws_base::MaxSumTRWS_Parameters<typename GM::ValueType>
11 {
12  typedef typename GM::ValueType ValueType;
// Vector type used for the optional initial point initPoint_.
15  typedef std::vector<typename GM::ValueType> DDVectorType;
16 
// Constructor.
//  maxIternum        - maximal number of iterations, forwarded to the parent
//                      parameters (presumably stored as maxNumberOfIterations_).
//                      NOTE: the default 0 is not usable as-is - TRWSi's
//                      constructor throws std::runtime_error when
//                      maxNumberOfIterations_ == 0, so set a positive value
//                      before constructing TRWSi.
//  precision         - precision threshold forwarded to the parent parameters
//                      (presumably the stopping tolerance - confirm in trws_base).
//  absolutePrecision - true: precision is absolute; false: relative w.r.t. the
//                      dual value.
//  verbose           - intended to initialize verbose_.
// NOTE(review): no initializer for verbose_ is visible in this listing;
// confirm the `verbose` argument is actually stored, otherwise verbose_ is
// left uninitialized.
// initPoint_ starts empty, which TRWSi treats as "no initial point".
17  TRWSi_Parameter(size_t maxIternum=0,
19  ValueType precision=1.0,
20  bool absolutePrecision=true,
21  bool verbose=false)
22  :parent(maxIternum,precision,absolutePrecision),
25  initPoint_(0)
26 {
27 }
28 
// Verbosity flag. With TRWS_DEBUG_OUTPUT defined, TRWSi writes its diagnostic
// output to the user stream when this is true and to a null stream otherwise.
30  bool verbose_;
// Optional initial point. When empty, TRWSi passes a null pointer to the
// decomposition storage; otherwise it passes the address of this vector.
31  DDVectorType initPoint_;
32 
// Dead code: an unfinished setter (note the empty `if ()`), kept commented out.
35  //void setMaxNumberOfIterations(size_t maxNumberOfIterations) {parent::maxNumberOfIterations_=maxNumberOfIterations; if ()}
36 
// Mutable / read-only access to the precision stored in the parent parameters.
37  ValueType& precision(){return parent::precision_;}
38  const ValueType& precision()const{return parent::precision_;}
39 
40  bool& isAbsolutePrecision(){return parent::absolutePrecision_;};//true for absolute precision, false for relative w.r.t. dual value
41  const bool& isAbsolutePrecision()const{return parent::absolutePrecision_;};//true for absolute precision, false for relative w.r.t. dual value
42 
45 
// Read-only access to parent::fastComputations_.
47  const bool& fastComputations()const{return parent::fastComputations_;}
48 
51 
54 
// Mutable / read-only access to verbose_.
55  bool& verbose(){return verbose_;};
56  const bool& verbose()const{return verbose_;};
57 
// print(): writes the scalar parameters as `name=value` lines to fout
// (initPoint_ is not printed). Only compiled when TRWS_DEBUG_OUTPUT is
// defined. It uses accessors (maxNumberOfIterations(),
// minRelativeDualImprovement(), canonicalNormalization(), decompositionType())
// that are declared on lines not shown in this listing.
58 #ifdef TRWS_DEBUG_OUTPUT
59  void print(std::ostream& fout)const
60  {
61  fout << "maxNumberOfIterations="<<maxNumberOfIterations()<<std::endl;
62  fout <<"precision="<<precision()<<std::endl;
63  fout <<"isAbsolutePrecision="<<isAbsolutePrecision()<<std::endl;
64  fout <<"minRelativeDualImprovement="<<minRelativeDualImprovement()<<std::endl;
65  fout <<"fastComputations="<<fastComputations()<<std::endl;
66  fout <<"canonicalNormalization="<<canonicalNormalization()<<std::endl;
67  fout << "decompositionType=" << Storage::getString(decompositionType()) << std::endl;
68 
69  fout <<"verbose="<<verbose()<<std::endl;
70  fout <<"treeAgreeMaxStableIter="<<parent::treeAgreeMaxStableIter()<<std::endl;
71  }
72 #endif
73 };
74 
91 
// TRWSi: OpenGM Inference<GM, ACC> wrapper around the tree-reweighted
// sequential message passing solver (typedef Solver = trws_base::MaxSumTRWS).
// It owns a decomposition storage (_storage) and a solver (_solver) built on
// top of it. It forwards inference and labeling queries to the solver, and
// bound/value queries as well.
// NOTE(review): this rendered listing omits the Solver/Storage/Parameter,
// visitor, ReparametrizerType and DDVectorType typedefs; they are only named
// in the definitions index.
92 template<class GM, class ACC>
93 class TRWSi : public Inference<GM, ACC>
94 {
95 public:
96  typedef ACC AccumulationType;
97  typedef GM GraphicalModelType;
104 
106 // typedef typename Solver::ReparametrizerType ReparametrizerType;
109 
// Constructor: builds the storage, then the solver.
// - _storage is built from gm, param.decompositionType_ and param.initPoint_.
//   An empty initPoint_ is passed as a null pointer.
// - _solver is then built from _storage and param.
// _storage is declared before _solver (see the private section), so it is
// fully constructed by the time _solver(_storage, ...) runs.
// With TRWS_DEBUG_OUTPUT defined, output goes to fout when param.verbose_ is
// set and to a null stream otherwise, and the parameters are printed first.
// Throws std::runtime_error if param.maxNumberOfIterations_ is 0.
// NOTE(review): that check only runs after _storage and _solver have been
// fully built; validating param up front would avoid the wasted work.
110  TRWSi(const GraphicalModelType& gm, const Parameter& param
111 #ifdef TRWS_DEBUG_OUTPUT
112  ,std::ostream& fout=std::cout
113 #endif
114  ):
115  _storage(gm,param.decompositionType_,(param.initPoint_.size()==0 ? 0 : &param.initPoint_)),
116  _solver(_storage,param
117 #ifdef TRWS_DEBUG_OUTPUT
118  ,(param.verbose_ ? fout : *OUT::nullstream::Instance()) //fout
119 #endif
120  ){
121 #ifdef TRWS_DEBUG_OUTPUT
122  std::ostream& out=(param.verbose_ ? fout : *OUT::nullstream::Instance());
123  out << "Parameters of the "<< name() <<" algorithm:"<<std::endl;
124  param.print(out);
125 #endif
126 
127  if (param.maxNumberOfIterations_==0) throw
128  std::runtime_error("TRWSi: Maximal number of iterations (> 0) has to be specified!");
129  }
// Algorithm name; also used in the debug parameter dump above.
130  std::string name() const{ return "TRWSi"; }
// The graphical model, as held by the decomposition storage.
131  const GraphicalModelType& graphicalModel() const { return _storage.masterModel(); }
// NOTE(review): the signature line of the no-argument overload
// (`InferenceTermination infer()`) is missing from this listing. Its body runs
// the solver and always reports NORMAL.
133  _solver.infer();
134  return NORMAL;
135  };
136 
// Runs the solver, reporting progress to `visitor` through a VisitorWrapper
// bound to this object. Always returns NORMAL.
137  template<class VISITOR> InferenceTermination infer(VISITOR & visitor){
138  visitors::VisitorWrapper<VISITOR,TRWSi<GM, ACC> > visiwrap(&visitor,this);
139  _solver.infer(visiwrap);
140  return NORMAL;
141  };
142 
// Copies the solver's current labeling into `out`. Always returns NORMAL.
// NOTE(review): the second (unnamed) argument is ignored - presumably the
// index of the requested solution in Inference::arg, so every request
// receives the same labeling.
143  InferenceTermination arg(std::vector<LabelType>& out, const size_t = 1) const
144  {
145  out = _solver.arg();
146  return opengm::NORMAL;}
// Bound and value, as computed by the solver.
147  virtual ValueType bound() const{return _solver.bound();}
148  virtual ValueType value() const{return _solver.value();}
// Forwards to the solver. plabeling and ptreeLabelings are optional and
// default to null.
149  void getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling=0,std::vector<std::vector<LabelType> >* ptreeLabelings=0){_solver.getTreeAgreement(out,plabeling,ptreeLabelings);}
// Mutable access to the decomposition storage (the const overload is
// commented out).
150  //const Storage& getDecompositionStorage()const{return _storage;}
151  Storage& getDecompositionStorage(){return _storage;}
152  const typename Solver::FactorProperties& getFactorProperties()const {return _solver.getFactorProperties();}
153 
// Earlier const variant that delegated to the solver; kept commented out.
154 // ReparametrizerType* getReparametrizer(const typename ReparametrizerType::Parameter& params= typename ReparametrizerType::Parameter())const
155 // {return _solver.getReparametrizer(params);}
156 
157 
// Allocates a new reparametrizer over _storage and the solver's factor
// properties with `new`. The returned pointer is owned by the caller, who
// must delete it; TRWSi keeps no reference to it. Not const yet (see TODO).
158  ReparametrizerType * getReparametrizer(const typename ReparametrizerType::Parameter& params=typename ReparametrizerType::Parameter())//const //TODO: make it constant
159  {return new ReparametrizerType(_storage,_solver.getFactorProperties(),params);}
160 
// Fills *pddvector via the storage's getDDVector().
161  void getDDVector(DDVectorType* pddvector)const{_storage.getDDVector(pddvector);}
162 
// Declaration order matters: _storage must stay before _solver because
// _solver is constructed from it in the constructor's initializer list.
163  private:
164  Storage _storage;
165  Solver _solver;
166 };
167 
168 }
169 #endif
170 
trws_base::DecompositionStorage< GM > Storage
Definition: trws_trws.hxx:100
TRWSi(const GraphicalModelType &gm, const Parameter &param)
Definition: trws_trws.hxx:110
bool & isAbsolutePrecision()
Definition: trws_trws.hxx:40
virtual ValueType value() const
Definition: trws_base.hxx:234
The OpenGM namespace.
Definition: config.hxx:43
Storage::StructureType & decompositionType()
Definition: trws_trws.hxx:52
virtual InferenceTermination infer()
Definition: trws_base.hxx:254
GM::ValueType ValueType
Definition: trws_trws.hxx:12
const ValueType & precision() const
Definition: trws_trws.hxx:38
Storage & getDecompositionStorage()
Definition: trws_trws.hxx:151
const bool & canonicalNormalization() const
Definition: trws_trws.hxx:50
DDVectorType initPoint_
Definition: trws_trws.hxx:31
Storage::StructureType decompositionType_
Definition: trws_trws.hxx:29
InferenceTermination infer()
Definition: trws_trws.hxx:132
visitors::EmptyVisitor< TRWSi< GM, ACC > > EmptyVisitorType
Definition: trws_trws.hxx:103
const FactorProperties & getFactorProperties() const
Definition: trws_base.hxx:265
const Solver::FactorProperties & getFactorProperties() const
Definition: trws_trws.hxx:152
TRWSi_Parameter(size_t maxIternum=0, typename Storage::StructureType decompositionType=Storage::GENERALSTRUCTURE, ValueType precision=1.0, bool absolutePrecision=true, bool verbose=false)
Definition: trws_trws.hxx:17
virtual const std::vector< LabelType > & arg() const
Definition: trws_base.hxx:236
InferenceTermination infer(VISITOR &visitor)
Definition: trws_trws.hxx:137
void getTreeAgreement(std::vector< bool > &out, std::vector< LabelType > *plabeling=0, std::vector< std::vector< LabelType > > *ptreeLabelings=0)
Definition: trws_base.hxx:1152
GM GraphicalModelType
Definition: trws_trws.hxx:97
const ValueType & minRelativeDualImprovement() const
Definition: trws_trws.hxx:44
std::vector< typename GM::ValueType > DDVectorType
Definition: trws_trws.hxx:15
ReparametrizerType * getReparametrizer(const typename ReparametrizerType::Parameter &params=typename ReparametrizerType::Parameter())
Definition: trws_trws.hxx:158
TRWS_Reparametrizer< Storage, ACC > ReparametrizerType
Definition: trws_trws.hxx:107
ValueType & minRelativeDualImprovement()
Definition: trws_trws.hxx:43
const Storage::StructureType & decompositionType() const
Definition: trws_trws.hxx:53
const GraphicalModelType & graphicalModel() const
Definition: trws_trws.hxx:131
void getTreeAgreement(std::vector< bool > &out, std::vector< LabelType > *plabeling=0, std::vector< std::vector< LabelType > > *ptreeLabelings=0)
Definition: trws_trws.hxx:149
Storage::DDVectorType DDVectorType
Definition: trws_trws.hxx:108
Inference algorithm interface.
Definition: inference.hxx:34
ACC AccumulationType
Definition: trws_trws.hxx:96
size_t & maxNumberOfIterations()
Definition: trws_trws.hxx:33
visitors::VerboseVisitor< TRWSi< GM, ACC > > VerboseVisitorType
Definition: trws_trws.hxx:101
std::vector< typename GM::ValueType > DDVectorType
Definition: trws_base.hxx:51
const bool & verbose() const
Definition: trws_trws.hxx:56
virtual ValueType bound() const
return a bound on the solution
Definition: trws_trws.hxx:147
trws_base::DecompositionStorage< GM > Storage
Definition: trws_trws.hxx:14
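[class trwsi] TRWSi - tree-reweighted sequential message passing. Based on the paper: V. Kolmogorov, "Convergent Tree-Reweighted Message Passing for Energy Minimization", IEEE Transactions on Pattern Analysis and Machine Intelligence, 28(10):1568-1583, 2006.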
Definition: trws_trws.hxx:93
const size_t & maxNumberOfIterations() const
Definition: trws_trws.hxx:34
trws_base::MaxSumTRWS_Parameters< ValueType > parent
Definition: trws_trws.hxx:13
virtual ValueType value() const
return the solution (value)
Definition: trws_trws.hxx:148
virtual ValueType bound() const
Definition: trws_base.hxx:235
const bool & isAbsolutePrecision() const
Definition: trws_trws.hxx:41
TRWSi_Parameter< GM > Parameter
Definition: trws_trws.hxx:105
const bool & fastComputations() const
Definition: trws_trws.hxx:47
InferenceTermination arg(std::vector< LabelType > &out, const size_t=1) const
output a solution
Definition: trws_trws.hxx:143
void getDDVector(DDVectorType *pddvector) const
Definition: trws_trws.hxx:161
void getDDVector(DDVectorType *ddvector) const
Definition: trws_base.hxx:1033
static std::string getString(StructureType structure)
Definition: trws_base.hxx:37
std::string name() const
Definition: trws_trws.hxx:130
visitors::TimingVisitor< TRWSi< GM, ACC > > TimingVisitorType
Definition: trws_trws.hxx:102
ValueType & precision()
Definition: trws_trws.hxx:37
bool & canonicalNormalization()
Definition: trws_trws.hxx:49
InferenceTermination
Definition: inference.hxx:24
trws_base::MaxSumTRWS< GM, ACC > Solver
Definition: trws_trws.hxx:99