2 #ifndef OPENGM_ASTAR_HXX
3 #define OPENGM_ASTAR_HXX
26 template<
class FactorType>
struct AStarNode {
27 typename std::vector<typename FactorType::LabelType> conf;
28 typename FactorType::ValueType value;
63 template<
class GM,
class ACC>
74 typedef typename std::vector<LabelType>
ConfVec ;
/// Returns the name of this inference algorithm: the fixed string "AStar".
124 virtual std::string
name()
const {
return "AStar";}
127 virtual void reset();
138 Parameter parameter_;
141 Parameter parameterInitial_;
142 std::vector<AStarNode<IndependentFactorType> > array_;
143 std::vector<size_t> numStates_;
145 std::vector<IndependentFactorType> treeFactor_;
146 std::vector<IndependentFactorType> optimizedFactor_;
147 std::vector<ConfVec > optConf_;
148 std::vector<bool> isTreeFactor_;
151 template<
class VisitorType>
void expand(VisitorType& vistitor);
152 std::vector<ValueType> fastHeuristic(ConfVec conf);
/// Ordering predicate on search nodes: compares the `value` members of two
/// nodes with AccumulationType::ibop.
/// This is the comparator handed to make_heap / push_heap / pop_heap on array_
/// elsewhere in this file, i.e. it defines the heap order of the open list.
/// NOTE(review): ibop is presumably the inverse of the AccumulationType::bop
/// comparison used by comp2 -- confirm against the accumulator definition.
153 inline static bool comp1(
const AStarNode<IndependentFactorType>& a,
const AStarNode<IndependentFactorType>& b)
154 {
return AccumulationType::ibop(a.value,b.value);};
/// Ordering predicate on search nodes: compares the `value` members of two
/// nodes with AccumulationType::bop.
/// This is the comparator handed to partial_sort when array_ is cut down to
/// maxHeapSize_/2 entries elsewhere in this file.
/// NOTE(review): presumably the opposite ordering to comp1 (which uses ibop),
/// so that the partial sort keeps the most promising nodes -- confirm.
155 inline static bool comp2(
const AStarNode<IndependentFactorType>& a,
const AStarNode<IndependentFactorType>& b)
156 {
return AccumulationType::bop(a.value,b.value);};
169 template<
class GM,
class ACC >
177 parameterInitial_=para;
179 if( parameter_.heuristic_ == Parameter::DEFAULTHEURISTIC) {
180 if(gm_.factorOrder()<=2)
181 parameter_.
heuristic_ = Parameter::FASTHEURISTIC;
183 parameter_.heuristic_ = Parameter::STANDARDHEURISTIC;
185 OPENGM_ASSERT(parameter_.heuristic_ == Parameter::FASTHEURISTIC || parameter_.heuristic_ == Parameter::STANDARDHEURISTIC);
186 ACC::ineutral(belowBound_);
187 ACC::neutral(aboveBound_);
189 isTreeFactor_.resize(gm_.numberOfFactors());
190 numStates_.resize(gm_.numberOfVariables());
191 numNodes_ = gm_.numberOfVariables();
192 for(
size_t i=0; i<numNodes_;++i)
193 numStates_[i] = gm_.numberOfLabels(i);
195 if(parameter_.nodeOrder_.size()==0) {
196 parameter_.nodeOrder_.resize(numNodes_);
197 std::vector<std::pair<IndexType,IndexType> > tmp(numNodes_,std::pair<IndexType,IndexType>());
198 for(
size_t i=0; i<numNodes_; ++i){
199 tmp[i].first = gm_.numberOfFactors(i);
202 std::sort(tmp.begin(),tmp.end());
203 for(
size_t i=0; i<numNodes_; ++i){
204 parameter_.nodeOrder_[i] = tmp[numNodes_-i-1].second;
208 if(parameter_.nodeOrder_.size()!=numNodes_)
209 throw RuntimeError(
"The node order does not fit to the model.");
210 OPENGM_ASSERT(std::set<size_t>(parameter_.nodeOrder_.begin(), parameter_.nodeOrder_.end()).size()==numNodes_);
211 for(
size_t i=0;i<numNodes_; ++i) {
216 if(parameter_.treeFactorIds_.size()==0) {
218 for(
size_t i=0; i<gm_.numberOfFactors(); ++i) {
219 if((gm_[i].numberOfVariables()==2) &&
220 (gm_[i].variableIndex(0)==parameter_.nodeOrder_.back() || gm_[i].variableIndex(1)==parameter_.nodeOrder_.back())
222 parameter_.addTreeFactorId(i);
225 for(
size_t i=0; i<parameter_.treeFactorIds_.size(); ++i)
226 OPENGM_ASSERT(gm_.numberOfFactors() > parameter_.treeFactorIds_[i]);
228 optimizedFactor_.resize(gm_.numberOfFactors());
229 for(
size_t i=0; i<gm_.numberOfFactors(); ++i) {
230 if(gm_[i].numberOfVariables()<=1)
continue;
231 std::vector<size_t> index(gm_[i].numberOfVariables());
232 gm_[i].variableIndices(index.begin());
233 optimizedFactor_[i].assign(gm_ ,index.end()-1, index.end());
234 opengm::accumulate<ACC>(gm[i],index.begin()+1,index.end(),optimizedFactor_[i]);
237 OPENGM_ASSERT(optimizedFactor_[i].variableIndex(0) == index[0]);
240 AStarNode<IndependentFactorType> a;
244 make_heap(array_.begin(), array_.end(), comp1);
246 if(parameter_.heuristic_ == parameter_.FASTHEURISTIC) {
247 for(
size_t i=0; i<parameter_.treeFactorIds_.size(); ++i) {
248 if(gm_[parameter_.treeFactorIds_[i]].numberOfVariables() > 2) {
249 throw RuntimeError(
"The heuristic includes factor of order > 2.");
255 for(
size_t i=0; i<gm_.numberOfFactors(); ++i)
256 isTreeFactor_[i] =
false;
257 for(
size_t i=0; i<parameter_.treeFactorIds_.size(); ++i) {
258 int factorId = parameter_.treeFactorIds_[i];
259 isTreeFactor_[factorId] =
true;
260 treeFactor_.push_back(gm_[factorId]);
270 template<
class GM,
class ACC >
277 template <
class GM,
class ACC>
287 template<
class GM,
class ACC>
288 template<
class VisitorType>
293 visitor.begin(*
this);
294 while(array_.size()>0 && exitFlag==0) {
295 if(parameter_.numberOfOpt_ == optConf_.size()) {
299 while(array_.front().conf.size() < numNodes_ && exitFlag==0) {
301 belowBound_ = array_.front().value;
302 exitFlag = visitor(*
this);
305 if(array_.front().conf.size()>=numNodes_){
308 std::vector<LabelType> conf(numNodes_);
309 for(
size_t n=0; n<numNodes_; ++n) {
310 conf[parameter_.nodeOrder_[n]] = array_.front().conf[n];
312 optConf_.push_back(conf);
314 if(ACC::bop(parameter_.objectiveBound_, value)) {
319 pop_heap(array_.begin(), array_.end(), comp1);
326 template<
class GM,
class ACC>
329 if(optConf_.size()>=1){
330 return gm_.evaluate(optConf_[0]);
333 return ACC::template neutral<ValueType>();
337 template<
class GM,
class ACC>
339 ::arg(ConfVec& conf,
const size_t n)
const
341 if(n>optConf_.size()) {
342 conf.resize(gm_.numberOfVariables(),0);
354 template<
class GM,
class ACC>
362 template<
class GM,
class ACC>
363 template<
class VisitorType>
367 if(array_.size()>parameter_.maxHeapSize_*0.99) {
368 partial_sort(array_.begin(), array_.begin()+(int)(parameter_.maxHeapSize_/2), array_.end(), comp2);
369 array_.resize((
int)(parameter_.maxHeapSize_/2));
372 AStarNode<IndependentFactorType> a = array_.front();
373 size_t subconfsize = a.conf.size();
376 pop_heap(array_.begin(), array_.end(), comp1);
378 if( parameter_.heuristic_ == parameter_.STANDARDHEURISTIC) {
383 typedef typename meta::TypeListGenerator< ExplicitFunction<ValueType,IndexType,LabelType>, ViewFixVariablesFunction<GM>, ViewFunction<GM>, ConstantFunction<ValueType, IndexType, LabelType> >::type MFunctionTypeList;
384 typedef GraphicalModel<ValueType, typename GM::OperatorType, MFunctionTypeList, MSpaceType> MGM;
386 IndexType numberOfVariables = 0;
387 std::vector<IndexType> varMap(gm_.numberOfVariables(),0);
388 std::vector<LabelType> fixVariableLabel(gm_.numberOfVariables(),0);
389 std::vector<bool> fixVariable(gm_.numberOfVariables(),
false);
390 for(
size_t i =0; i<subconfsize ; ++i) {
391 fixVariableLabel[parameter_.nodeOrder_[i]] = a.conf[i];
392 fixVariable[parameter_.nodeOrder_[i]] =
true;
395 for(IndexType var=0; var<gm_.numberOfVariables();++var){
396 if(fixVariable[var]==
false){
397 varMap[var] = numberOfVariables++;
400 std::vector<LabelType> shape(numberOfVariables,0);
401 for(IndexType var=0; var<gm_.numberOfVariables();++var){
402 if(fixVariable[var]==
false){
403 shape[varMap[var]] = gm_.numberOfLabels(var);
406 MSpaceType space(shape.begin(),shape.end());
409 std::vector<PositionAndLabel<IndexType,LabelType> > fixedVars;
410 std::vector<IndexType> MVars;
412 GM::OperatorType::neutral(constant);
414 for(IndexType f=0; f<gm_.numberOfFactors();++f){
417 for(IndexType i=0; i<gm_[f].numberOfVariables(); ++i){
418 const IndexType var = gm_[f].variableIndex(i);
419 if(fixVariable[var]){
420 fixedVars.push_back(PositionAndLabel<IndexType,LabelType>(i,fixVariableLabel[var]));
422 MVars.push_back(varMap[var]);
425 if(fixedVars.size()==gm_[f].numberOfVariables()){
426 std::vector<LabelType> fixedStates(gm_[f].numberOfVariables(),0);
427 for(IndexType i=0; i<gm_[f].numberOfVariables(); ++i){
428 fixedStates[i]=fixVariableLabel[ gm_[f].variableIndex(i)];
430 GM::OperatorType::op(gm_[f](fixedStates.begin()),constant);
432 if(MVars.size()<2 || isTreeFactor_[f]){
433 const ViewFixVariablesFunction<GM> func(gm_[f], fixedVars);
434 mgm.addFactor(mgm.addFunction(func),MVars.begin(), MVars.end());
436 std::vector<IndexType> variablesIndices(optimizedFactor_[f].numberOfVariables());
437 for(
size_t i=0; i<variablesIndices.size(); ++i)
438 variablesIndices[i] = varMap[optimizedFactor_[f].variableIndex(i)];
439 LabelType numberOfLabels = optimizedFactor_[f].numberOfLabels(0);
441 for(
LabelType i=0; i<numberOfLabels; ++i)
442 func(i) = optimizedFactor_[f](i);
443 mgm.addFactor(mgm.addFunction(func),variablesIndices.begin(),variablesIndices.end() );
444 OPENGM_ASSERT(mgm[mgm.numberOfFactors()-1].numberOfVariables()==1);
450 ConstantFunction<ValueType, IndexType, LabelType> func(&temp, &temp, constant);
451 mgm.addFactor(mgm.addFunction(func),MVars.begin(), MVars.begin());
454 typename MessagePassing<MGM, ACC, UpdateRules, opengm::MaxDistance>::Parameter bpPara;
456 bpPara.maximumNumberOfSteps_ = mgm.numberOfVariables();
458 MessagePassing<MGM, ACC, UpdateRules, opengm::MaxDistance> bp(mgm,bpPara);
463 throw RuntimeError(
"bp failed in astar");
465 ACC::op(bp.value(),aboveBound_,aboveBound_);
466 std::vector<LabelType> conf(mgm.numberOfVariables());
468 std::vector<IndexType> theVar(1, varMap[parameter_.nodeOrder_[subconfsize]]);
470 std::vector<LabelType> theLabel(1,0);
471 a.conf.resize(subconfsize+1);
472 for(
size_t i=0; i<numStates_[parameter_.nodeOrder_[subconfsize]]; ++i) {
473 a.conf[subconfsize] = i;
475 bp.constrainedOptimum(theVar,theLabel,conf);
476 a.value = mgm.evaluate(conf);
478 push_heap(array_.begin(), array_.end(), comp1);
481 if( parameter_.heuristic_ == parameter_.FASTHEURISTIC) {
482 std::vector<LabelType> conf(subconfsize);
483 for(
size_t i=0;i<subconfsize;++i)
485 std::vector<ValueType> bound = fastHeuristic(conf);
486 a.conf.resize(subconfsize+1);
487 for(
size_t i=0; i<numStates_[parameter_.nodeOrder_[subconfsize]]; ++i) {
488 a.conf[subconfsize] = i;
492 push_heap(array_.begin(), array_.end(), comp1);
498 template<
class GM,
class ACC>
499 std::vector<typename AStar<GM, ACC>::ValueType>
500 AStar<GM, ACC>::fastHeuristic(
typename AStar<GM, ACC>::ConfVec conf)
502 std::list<size_t> factorList;
503 std::vector<size_t> nodeDegree(numNodes_,0);
504 std::vector<int> nodeLabel(numNodes_,-1);
505 std::vector<std::vector<ValueType > > nodeEnergy(numNodes_);
506 size_t nextNode = parameter_.nodeOrder_[conf.size()];
507 for(
size_t i=0; i<numNodes_; ++i) {
508 nodeEnergy[i].resize(numStates_[i]);
509 for(
size_t j=0;j<numStates_[i];++j)
510 OperatorType::neutral(nodeEnergy[i][j]);
512 for(
size_t i=0;i<conf.size();++i) {
513 nodeLabel[parameter_.nodeOrder_[i]] = conf[i];
519 for(
size_t i=0; i<gm_.numberOfFactors(); ++i) {
520 const FactorType & f = gm_[i];
521 size_t nvar = f.numberOfVariables();
524 int index = f.variableIndex(0);
525 if(nodeLabel[index]>=0) {
526 nodeEnergy[index].resize(1);
529 OperatorType::op(f(coordinates), nodeEnergy[index][0]);
532 OPENGM_ASSERT(numStates_[index] == nodeEnergy[index].size());
533 for(
size_t j=0;j<numStates_[index];++j) {
536 OperatorType::op(f(coordinates),nodeEnergy[index][j]);
541 size_t index1 = f.variableIndex(0);
542 size_t index2 = f.variableIndex(1);
543 if(nodeLabel[index1]>=0) {
544 if(nodeLabel[index2]>=0) {
545 nodeEnergy[index1].resize(1);
548 static_cast<LabelType>(nodeLabel[index1]),
549 static_cast<LabelType>(nodeLabel[index2])
551 OperatorType::op(f(coordinates),nodeEnergy[index1][0]);
554 OPENGM_ASSERT(numStates_[index2] == nodeEnergy[index2].size());
555 for(
size_t j=0;j<numStates_[index2];++j) {
558 static_cast<LabelType>(nodeLabel[index1]),
559 static_cast<LabelType>(j)
561 OperatorType::op(f(coordinates), nodeEnergy[index2][j]);
565 else if(nodeLabel[index2]>=0) {
566 OPENGM_ASSERT(numStates_[index1] == nodeEnergy[index1].size());
567 for(
size_t j=0;j<numStates_[index1];++j) {
571 static_cast<LabelType>(nodeLabel[index2])
573 OperatorType::op(f(coordinates),nodeEnergy[index1][j]);
576 else if(isTreeFactor_[i]) {
577 factorList.push_front(i);
578 ++nodeDegree[index1];
579 ++nodeDegree[index2];
583 for(
size_t j=0;j<numStates_[index1];++j) {
586 OperatorType::op(optimizedFactor_[i](coordinates), nodeEnergy[index1][j]);
592 std::vector<size_t> state(nvar);
593 for(
size_t j=0; j<nvar; ++j) {
594 if(nodeLabel[f.variableIndex(j)]<0) {
595 state[j] = nodeLabel[f.variableIndex(j)];
600 nodeEnergy[f.variableIndex(0)][0] = f(state.begin());
602 for(
size_t j=0;j<numStates_[f.variableIndex(0)];++j) {
605 OperatorType::op(optimizedFactor_[i](coordinates), nodeEnergy[f.variableIndex(0)][j]);
610 nodeDegree[nextNode] += numNodes_;
612 while(factorList.size()>0) {
613 size_t id = factorList.front();
614 factorList.pop_front();
615 const FactorType & f = gm_[id];
616 size_t index1 = f.variableIndex(0);
617 size_t index2 = f.variableIndex(1);
618 typename FactorType::ValueType temp;
621 OPENGM_ASSERT(gm_.numberOfLabels(index1) == numStates_[index1]);
622 OPENGM_ASSERT(gm_.numberOfLabels(index2) == numStates_[index2]);
623 if(nodeDegree[index1]==1) {
624 typename FactorType::ValueType min;
625 OPENGM_ASSERT(numStates_[index2] == nodeEnergy[index2].size());
626 for(
size_t j2=0;j2<numStates_[index2];++j2) {
628 OPENGM_ASSERT(numStates_[index1] == nodeEnergy[index1].size());
629 for(
size_t j1=0;j1<numStates_[index1];++j1) {
631 OperatorType::op(f(coordinates),nodeEnergy[index1][j1],temp);
632 ACC::op(min,temp,min);
635 OperatorType::op(min,nodeEnergy[index2][j2]);
637 --nodeDegree[index1];
638 --nodeDegree[index2];
639 nodeEnergy[index1].resize(1);
640 OperatorType::neutral(nodeEnergy[index1][0]);
642 else if(nodeDegree[index2]==1) {
643 typename FactorType::ValueType min;
644 OPENGM_ASSERT(numStates_[index1] == nodeEnergy[index1].size());
645 for(
size_t j1=0;j1<numStates_[index1];++j1) {
647 OPENGM_ASSERT(numStates_[index2] == nodeEnergy[index2].size());
648 for(
size_t j2=0;j2<numStates_[index2];++j2) {
650 OperatorType::op(f(coordinates),nodeEnergy[index2][j2],temp);
651 ACC::op(min,temp,min);
655 OperatorType::op(min,nodeEnergy[index1][j1]);
657 --nodeDegree[index1];
658 --nodeDegree[index2];
659 nodeEnergy[index2].resize(1);
660 OperatorType::neutral(nodeEnergy[index2][0]);
663 factorList.push_back(
id);
669 OperatorType::neutral(result);
670 std::vector<ValueType > bound;
671 for(
size_t i=0;i<numNodes_;++i) {
672 if(i==nextNode)
continue;
674 for(
size_t j=0; j<nodeEnergy[i].size();++j)
675 ACC::op(min,nodeEnergy[i][j],min);
677 OperatorType::op(min,result);
679 bound.resize(nodeEnergy[nextNode].size());
680 for(
size_t j=0; j<nodeEnergy[nextNode].size();++j) {
682 OperatorType::op(nodeEnergy[nextNode][j],result,bound[j]);
687 template<
class GM,
class ACC>
688 inline const typename AStar<GM, ACC>::GraphicalModelType&
696 #endif // #ifndef OPENGM_ASTAR_HXX
ValueType bound() const
return a bound on the solution
Update rules for the MessagePassing framework.
opengm::visitors::EmptyVisitor< AStar< GM, ACC > > EmptyVisitorType
static const size_t DEFAULTHEURISTIC
DEFAULTHEURISTIC.
opengm::visitors::VerboseVisitor< AStar< GM, ACC > > VerboseVisitorType
visitor
Discrete space in which variables can have differently many labels.
virtual std::string name() const
opengm::visitors::TimingVisitor< AStar< GM, ACC > > TimingVisitorType
std::vector< LabelType > ConfVec
configuration vector type
AStar(const GM &gm, Parameter para=Parameter())
constructor
virtual InferenceTermination arg(std::vector< LabelType > &v, const size_t=1) const
output a solution
size_t numberOfOpt_
number of N-best solutions that should be found
virtual InferenceTermination args(std::vector< std::vector< LabelType > > &v) const
args
#define OPENGM_ASSERT(expression)
static const size_t STANDARDHEURISTIC
STANDARDHEURISTIC.
ValueType objectiveBound_
objective bound
void addTreeFactorId(size_t id)
add a tree factor id
ACC AccumulationType
accumulation type
ValueType value() const
return the solution (value)
const GraphicalModelType & graphicalModel() const
std::vector< size_t > treeFactorIds_
size_t heuristic_
heuristic
GraphicalModelType::ValueType ValueType
ConfVec::iterator ConfVecIt
configuration iterator
virtual InferenceTermination marginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
GM GraphicalModelType
graphical model type
Inference algorithm interface.
size_t maxHeapSize_
maxHeapSize_ maximum size of the heap
virtual void reset()
reset
static const size_t FASTHEURISTIC
FASTHEURISTIC.
virtual InferenceTermination infer()
std::vector< IndexType > nodeOrder_
GraphicalModelType::LabelType LabelType
virtual InferenceTermination factorMarginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for all variables connected to a factor
GraphicalModelType::IndependentFactorType IndependentFactorType