2 #ifndef OPENGM_EXTERNAL_AD3_HXX
3 #define OPENGM_EXTERNAL_AD3_HXX
11 #include "ad3/FactorGraph.h"
22 template<
class GM,
class ACC>
42 const double eta = 0.1,
43 const bool adaptEta =
true,
45 const double residualThreshold = 1e-6,
46 const int verbosity = 0
71 std::string
name()
const;
75 template<
class VisitorType>
80 return gm_.evaluate(arg_);
94 if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Minimizer>::value){
97 else if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Maximizer>::value){
103 if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Minimizer>::value){
106 else if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Maximizer>::value){
114 template<
class N_LABELS_ITER>
115 AD3Inf(N_LABELS_ITER nLabelsBegin,N_LABELS_ITER nLabelsEnd,
const Parameter para = Parameter());
121 template<
class VI_ITERATOR,
class FUNCTION>
122 void addFactor(VI_ITERATOR viBegin,VI_ITERATOR viEnd,
const FUNCTION &
function);
130 return additional_posteriors_;
135 const GraphicalModelType& gm_;
136 Parameter parameter_;
140 AD3::FactorGraph factor_graph_;
141 std::vector<AD3::MultiVariable*> multi_variables_;
143 std::vector<double> posteriors_;
144 std::vector<double> additional_posteriors_;
147 std::vector<LabelType> arg_;
151 std::vector<LabelType> space_;
159 template<
class GM,
class ACC>
167 numVar_(gm.numberOfVariables()),
169 multi_variables_(gm.numberOfVariables()),
171 additional_posteriors_(),
173 arg_(gm.numberOfVariables(),static_cast<
LabelType>(0)),
174 inferenceDone_(false),
178 if(meta::Compare<OperatorType,Adder>::value==
false){
179 throw RuntimeError(
"AD3 does not only support opengm::Adder as Operator");
182 if(meta::Compare<AccumulationType,Minimizer>::value==
false and meta::Compare<AccumulationType,Maximizer>::value==
false ){
183 throw RuntimeError(
"AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
187 bound_ = ACC::template ineutral<ValueType>();
191 factor_graph_.SetVerbosity(parameter_.verbosity_);
193 for(IndexType fi=0;fi<gm_.numberOfFactors();++fi){
194 maxFactorSize=std::max(static_cast<UInt64Type>(gm_[fi].size()),maxFactorSize);
197 ValueType * facVal =
new ValueType[maxFactorSize];
203 for(IndexType vi=0;vi<gm_.numberOfVariables();++vi){
204 multi_variables_[vi] = factor_graph_.CreateMultiVariable(gm_.numberOfLabels(vi));
205 for(
LabelType l=0;l<gm_.numberOfLabels(vi);++l){
206 multi_variables_[vi]->SetLogPotential(l,0.0);
213 for(IndexType fi=0;fi<gm_.numberOfFactors();++fi){
215 gm_[fi].copyValuesSwitchedOrder(facVal);
216 const IndexType nVar=gm_[fi].numberOfVariables();
219 const IndexType vi0 = gm_[fi].variableIndex(0);
220 const IndexType nl0 = gm_.numberOfLabels(vi0);
223 const ValueType logP = multi_variables_[vi0]->GetLogPotential(l);
224 const ValueType val = this->valueToMaxSum(facVal[l]);
225 multi_variables_[vi0]->SetLogPotential(l,logP+val);
231 std::vector<double> additional_log_potentials(gm_[fi].size());
232 for(IndexType i=0;i<gm_[fi].size();++i){
233 additional_log_potentials[i]=this->valueToMaxSum(facVal[i]);
237 std::vector<AD3::MultiVariable*> multi_variables_local(nVar);
238 for(IndexType v=0;v<nVar;++v){
239 multi_variables_local[v]=multi_variables_[gm_[fi].variableIndex(v)];
243 factor_graph_.CreateFactorDense(multi_variables_local,additional_log_potentials);
246 OPENGM_CHECK(
false,
"const factors are not yet implemented");
255 template<
class GM,
class ACC>
256 template<
class N_LABELS_ITER>
258 N_LABELS_ITER nLabelsBegin,
259 N_LABELS_ITER nLabelsEnd,
264 numVar_(
std::distance(nLabelsBegin,nLabelsEnd)),
266 multi_variables_(
std::distance(nLabelsBegin,nLabelsEnd)),
268 additional_posteriors_(),
270 arg_(
std::distance(nLabelsBegin,nLabelsEnd),static_cast<
LabelType>(0)),
271 space_(nLabelsBegin,nLabelsEnd)
274 if(meta::Compare<OperatorType,Adder>::value==
false){
275 throw RuntimeError(
"AD3 does not only support opengm::Adder as Operator");
277 if(meta::Compare<AccumulationType,Minimizer>::value==
false and meta::Compare<AccumulationType,Maximizer>::value==
false ){
278 throw RuntimeError(
"AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
280 bound_ = ACC::template ineutral<ValueType>();
281 factor_graph_.SetVerbosity(parameter_.
verbosity_);
285 multi_variables_[vi] = factor_graph_.CreateMultiVariable(space_[vi]);
287 multi_variables_[vi]->SetLogPotential(l,0.0);
292 template<
class GM,
class ACC>
303 multi_variables_(nVar),
305 additional_posteriors_(),
311 if(meta::Compare<OperatorType,Adder>::value==
false){
312 throw RuntimeError(
"AD3 does not only support opengm::Adder as Operator");
314 if(meta::Compare<AccumulationType,Minimizer>::value==
false and meta::Compare<AccumulationType,Maximizer>::value==
false ){
315 throw RuntimeError(
"AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
317 bound_ = ACC::template ineutral<ValueType>();
318 factor_graph_.SetVerbosity(parameter_.
verbosity_);
320 multi_variables_[vi] = factor_graph_.CreateMultiVariable(space_[vi]);
322 multi_variables_[vi]->SetLogPotential(l,0.0);
328 template<
class GM,
class ACC>
329 template<
class VI_ITERATOR,
class FUNCTION>
332 VI_ITERATOR visBegin,
334 const FUNCTION &
function
336 const IndexType nVis = std::distance(visBegin,visEnd);
337 OPENGM_CHECK_OP(nVis,==,
function.dimension(),
"functions dimension does not match number of variabole indices");
340 OPENGM_CHECK_OP(space_[visBegin[v]],==,
function.shape(v),
"functions shape does not match space");
346 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0]){
347 const ValueType logP = multi_variables_[visBegin[0]]->GetLogPotential(l[0]);
348 const ValueType val = this->valueToMaxSum(
function(l));
349 multi_variables_[visBegin[0]]->SetLogPotential(l[0],logP+val);
356 std::vector<AD3::MultiVariable*> multi_variables_local(nVis);
358 multi_variables_local[v]=multi_variables_[visBegin[v]];
362 std::vector<double> additional_log_potentials(
function.size());
369 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
370 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1]){
371 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
378 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
379 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
380 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2]){
381 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
388 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
389 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
390 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
391 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3]){
392 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
399 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
400 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
401 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
402 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
403 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4]){
404 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
411 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
412 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
413 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
414 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
415 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
416 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5]){
417 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
424 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
425 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
426 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
427 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
428 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
429 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
430 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6]){
431 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
438 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
439 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
440 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
441 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
442 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
443 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
444 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
445 for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
447 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
454 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
455 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
456 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
457 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
458 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
459 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
460 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
461 for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
462 for(l[8]=0; l[8]<space_[visBegin[8]]; ++l[8])
464 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
471 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
472 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
473 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
474 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
475 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
476 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
477 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
478 for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
479 for(l[8]=0; l[8]<space_[visBegin[8]]; ++l[8])
480 for(l[9]=0; l[9]<space_[visBegin[9]]; ++l[9])
482 additional_log_potentials[c]=this->valueToMaxSum(
function(l));
487 throw RuntimeError(
"order must be <=10 for inplace building of Ad3Inf (call us if you need higher order)");
495 factor_graph_.CreateFactorDense(multi_variables_local,additional_log_potentials);
502 template<
class GM,
class ACC>
508 template<
class GM,
class ACC>
515 template<
class GM,
class ACC>
522 template<
class GM,
class ACC>
530 template<
class GM,
class ACC>
531 template<
class VisitorType>
535 visitor.begin(*
this);
538 if(parameter_.solverType_ == AD3_LP || parameter_.solverType_ == AD3_ILP){
539 factor_graph_.SetEtaAD3(parameter_.eta_);
540 factor_graph_.AdaptEtaAD3(parameter_.adaptEta_);
541 factor_graph_.SetMaxIterationsAD3(parameter_.steps_);
542 factor_graph_.SetResidualThresholdAD3(parameter_.residualThreshold_);
544 if(parameter_.solverType_ == PSDD_LP){
545 factor_graph_.SetEtaPSDD(parameter_.eta_);
546 factor_graph_.SetMaxIterationsPSDD(parameter_.steps_);
552 if ( parameter_.solverType_ == AD3_LP){
554 factor_graph_.SolveLPMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
556 if ( parameter_.solverType_ == AD3_ILP){
558 factor_graph_.SolveExactMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
560 if (parameter_.solverType_ == PSDD_LP){
562 factor_graph_.SolveExactMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
566 bound_ =this->valueFromMaxSum(bound_);
570 for(
IndexType vi = 0; vi < numVar_; ++vi) {
572 double bestVal = -100000;
573 const LabelType nLabels = (space_.size()==0 ? gm_.numberOfLabels(vi) : space_[vi] );
575 const double val = posteriors_[c];
578 if(bestVal<0 || val>bestVal){
593 template<
class GM,
class ACC>
596 ::arg(std::vector<LabelType>& arg,
const size_t& n)
const {
602 std::copy(arg_.begin(),arg_.end(),arg.begin());
611 #endif // #ifndef OPENGM_EXTERNAL_AD3Inf_HXX
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
const GraphicalModelType & graphicalModel() const
AD3Inf(const GraphicalModelType &gm, const Parameter para=Parameter())
void addFactor(VI_ITERATOR viBegin, VI_ITERATOR viEnd, const FUNCTION &function)
ValueType bound() const
return a bound on the solution
Parameter(const SolverType solverType=AD3_ILP, const double eta=0.1, const bool adaptEta=true, UInt64Type steps=1000, const double residualThreshold=1e-6, const int verbosity=0)
ValueType valueToMaxSum(const ValueType val) const
detail_types::UInt64Type UInt64Type
uint64
const std::vector< double > & posteriors() const
double residualThreshold_
GraphicalModelType::IndexType IndexType
ValueType value() const
return the solution (value)
visitors::EmptyVisitor< AD3Inf< GM, ACC > > EmptyVisitorType
GraphicalModelType::ValueType ValueType
const std::vector< double > & higherOrderPosteriors() const
Inference algorithm interface.
#define OPENGM_CHECK_OP(A, OP, B, TXT)
visitors::VerboseVisitor< AD3Inf< GM, ACC > > VerboseVisitorType
InferenceTermination infer()
#define OPENGM_CHECK(B, TXT)
visitors::TimingVisitor< AD3Inf< GM, ACC > > TimingVisitorType
GraphicalModelType::LabelType LabelType
ValueType valueFromMaxSum(const ValueType val) const