6 #ifndef OPENGM_EXTERNAL_TRWS_HXX
7 #define OPENGM_EXTERNAL_TRWS_HXX
17 #include "MRFEnergy.h"
18 #include "instances.h"
19 #include "MRFEnergy.cpp"
20 #include "minimize.cpp"
21 #include "treeProbabilities.cpp"
22 #include "ordering.cpp"
70 numberOfIterations_ = 1000;
71 useRandomStart_ =
false;
72 useZeroStart_ =
false;
76 minDualChange_ = 0.00001;
77 calculateMinMarginals_ =
false;
81 TRWS(
const GraphicalModelType& gm,
const Parameter para = Parameter());
85 std::string
name()
const;
88 template<
class VISITOR>
93 typename GM::ValueType
bound()
const;
94 typename GM::ValueType
value()
const;
96 const GraphicalModelType& gm_;
108 TypeGeneral::REAL* minMarginals_;
109 size_t* minMarginalsOffsets_;
115 std::vector<LabelType> state_;
118 bool hasSameLabelNumber_;
119 void checkLabelNumber();
121 void generateMRFView();
122 void generateMRFTables();
123 void generateMRFTL1();
124 void generateMRFTL2();
130 bool truncatedAbsoluteDifferenceFactors()
const;
133 bool truncatedSquaredDifferenceFactors()
const;
135 template <
class ENERGYTYPE>
138 template<
class VISITOR,
class ENERGYTYPE>
142 template<
class GM,
class ENERGYTYPE>
144 static void*
create(
typename GM::IndexType numLabels);
149 static void*
create(
typename GM::IndexType numLabels);
154 static void*
create(
typename GM::IndexType numLabels);
159 static void*
create(
typename GM::IndexType numLabels);
164 static void*
create(
typename GM::IndexType numLabels);
167 template<
class GM,
class ENERGYTYPE>
197 : gm_(gm), parameter_(para), mrfView_(NULL), nodesView_(NULL), mrfGeneral_(NULL), nodesGeneral_(NULL),
198 mrfTL1_(NULL), nodesTL1_(NULL), mrfTL2_(NULL), nodesTL2_(NULL), minMarginals_(NULL), minMarginalsOffsets_(NULL),
199 numNodes_(gm_.numberOfVariables()), maxNumLabels_(gm_.numberOfLabels(0)) {
204 minMarginalsOffsets_ =
new size_t[gm_.numberOfVariables()];
205 for(
size_t i=0; i<gm_.numberOfVariables(); ++i){
206 minMarginalsOffsets_[i] = count;
207 count += gm_.numberOfLabels(i);
209 minMarginals_ =
new TypeGeneral::REAL[count];
223 if(!hasSameLabelNumber_) {
224 throw(
RuntimeError(
"TRWS TL1 only supports graphical models where each variable has the same number of states."));
230 if(!hasSameLabelNumber_) {
231 throw(
RuntimeError(
"TRWS TL2 only supports graphical models where each variable has the same number of states."));
263 delete[] nodesGeneral_;
280 delete[] minMarginals_;
281 delete[] minMarginalsOffsets_;
304 EmptyVisitorType visitor;
305 return this->infer(visitor);
309 template<
class VISITOR>
315 switch(parameter_.energyType_) {
317 return inferImpl(visitor, mrfView_);
321 return inferImpl(visitor, mrfGeneral_);
325 return inferImpl(visitor, mrfTL1_);
329 return inferImpl(visitor, mrfTL2_);
346 std::vector<LabelType>& arg,
354 arg.resize(numNodes_);
355 switch(parameter_.energyType_) {
357 for(
IndexType i = 0; i < numNodes_; i++) {
358 arg[i] = mrfView_->GetSolution(nodesView_[i]);
364 for(
IndexType i = 0; i < numNodes_; i++) {
365 arg[i] = mrfGeneral_->GetSolution(nodesGeneral_[i]);
371 for(
IndexType i = 0; i < numNodes_; i++) {
372 arg[i] = mrfTL1_->GetSolution(nodesTL1_[i]);
378 for(
IndexType i = 0; i < numNodes_; i++) {
379 arg[i] = mrfTL2_->GetSolution(nodesTL2_[i]);
404 const size_t variableIndex,
409 if(parameter_.calculateMinMarginals_){
410 out.assign(gm_, &variableIndex, &variableIndex+1, 0);
411 for(
size_t i=0; i<gm_.numberOfLabels(variableIndex); ++i){
412 out(i) = minMarginals_[i+minMarginalsOffsets_[variableIndex]];
421 inline typename GM::ValueType
423 return lowerBound_+constTerm_;
426 inline typename GM::ValueType
428 return value_+constTerm_;
433 hasSameLabelNumber_ =
true;
434 for(IndexType i = 1; i < gm_.numberOfVariables(); i++) {
435 if(gm_.numberOfLabels(i) != maxNumLabels_) {
436 hasSameLabelNumber_ =
false;
438 if(gm_.numberOfLabels(i) > maxNumLabels_) {
439 maxNumLabels_ = gm_.numberOfLabels(i);
445 inline void TRWS<GM>::generateMRFView() {
450 for(IndexType i = 0; i < numNodes_; i++) {
451 std::vector<typename GM::IndexType> factors;
452 for(
typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
453 if(gm_[*iter].numberOfVariables() == 1) {
454 factors.push_back(*iter);
462 for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
463 if(gm_[i].numberOfVariables() == 0){
465 constTerm_ += gm_[i](&l);
467 if(gm_[i].numberOfVariables() == 2) {
468 IndexType a = gm_[i].variableIndex(0);
469 IndexType b = gm_[i].variableIndex(1);
474 if(parameter_.useRandomStart_) {
475 mrfView_->AddRandomMessages(1, 0.0, 1.0);
476 }
else if(parameter_.useZeroStart_) {
477 mrfView_->ZeroMessages();
482 inline void TRWS<GM>::generateMRFTables() {
484 typename TypeGeneral::REAL* D =
new typename TypeGeneral::REAL[maxNumLabels_];
485 addNodes(mrfGeneral_, nodesGeneral_, D);
491 for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
492 if(gm_[i].numberOfVariables() == 0){
494 constTerm_ += gm_[i](&l);
496 if(gm_[i].numberOfVariables() == 2) {
497 IndexType a = gm_[i].variableIndex(0);
498 IndexType b = gm_[i].variableIndex(1);
499 IndexType numLabels_a = gm_.numberOfLabels(a);
500 IndexType numLabels_b = gm_.numberOfLabels(b);
501 typename TypeGeneral::REAL* V =
new typename TypeGeneral::REAL[numLabels_a * numLabels_b];
502 for(
size_t j = 0; j < numLabels_a; j++) {
503 for(
size_t k = 0; k < numLabels_b; k++) {
506 V[j + k * numLabels_a] = gm_[i](index);
509 mrfGeneral_->AddEdge(nodesGeneral_[a], nodesGeneral_[b], TypeGeneral::EdgeData(TypeGeneral::GENERAL, V));
515 if(parameter_.useRandomStart_) {
516 mrfGeneral_->AddRandomMessages(1, 0.0, 1.0);
517 }
else if(parameter_.useZeroStart_) {
518 mrfGeneral_->ZeroMessages();
523 inline void TRWS<GM>::generateMRFTL1() {
527 typename TypeTruncatedLinear::REAL* D =
new typename TypeTruncatedLinear::REAL[maxNumLabels_];
528 addNodes(mrfTL1_, nodesTL1_, D);
533 for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
534 if(gm_[i].numberOfVariables() == 0){
536 constTerm_ += gm_[i](&l);
538 if(gm_[i].numberOfVariables() == 2) {
540 ValueType t = getT(i);
544 IndexType index[] = {0, 1};
545 ValueType w = gm_[i](index);
549 IndexType a = gm_[i].variableIndex(0);
550 IndexType b = gm_[i].variableIndex(1);
551 mrfTL1_->AddEdge(nodesTL1_[a], nodesTL1_[b], TypeTruncatedLinear::EdgeData(w, w * t));
556 if(parameter_.useRandomStart_) {
557 mrfTL1_->AddRandomMessages(1, 0.0, 1.0);
558 }
else if(parameter_.useZeroStart_) {
559 mrfTL1_->ZeroMessages();
564 inline void TRWS<GM>::generateMRFTL2() {
568 typename TypeTruncatedQuadratic::REAL* D =
new typename TypeTruncatedQuadratic::REAL[maxNumLabels_];
569 addNodes(mrfTL2_, nodesTL2_, D);
574 for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
575 if(gm_[i].numberOfVariables() == 0){
577 constTerm_ += gm_[i](&l);
579 if(gm_[i].numberOfVariables() == 2) {
581 ValueType t = getT(i);
585 IndexType index[] = {0, 1};
586 ValueType w = gm_[i](index);
590 IndexType a = gm_[i].variableIndex(0);
591 IndexType b = gm_[i].variableIndex(1);
592 mrfTL2_->AddEdge(nodesTL2_[a], nodesTL2_[b], TypeTruncatedQuadratic::EdgeData(w, w * t));
599 if(parameter_.useRandomStart_) {
600 mrfTL2_->AddRandomMessages(1, 0.0, 1.0);
601 }
else if(parameter_.useZeroStart_) {
602 mrfTL2_->ZeroMessages();
612 inline typename GM::ValueType TRWS<GM>::getT(IndexType factor)
const {
615 IndexType index1[] = {0, 1};
616 IndexType index0[] = {0, maxNumLabels_-1};
618 return gm_[factor](index0)/gm_[factor](index1);
622 inline bool TRWS<GM>::truncatedAbsoluteDifferenceFactors()
const {
623 for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
624 if(gm_.numberOfVariables(i) == 2) {
625 if(gm_[i].isTruncatedAbsoluteDifference() ==
false) {
634 inline bool TRWS<GM>::truncatedSquaredDifferenceFactors()
const {
635 for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
636 if(gm_.numberOfVariables(i) == 2) {
637 if(gm_[i].isTruncatedSquaredDifference() ==
false) {
646 template <
class ENERGYTYPE>
652 for(IndexType i = 0; i < numNodes_; i++) {
653 for(IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
656 for(
typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
657 if(gm_[*iter].numberOfVariables() == 1) {
658 for(IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
659 D[j] += gm_[*iter](&j);
667 template<
class GM,
class ENERGYTYPE>
693 template<
class GM,
class ENERGYTYPE>
706 return mrf->AddNode(
typename TypeGeneral::LocalSize(numLabels),
typename TypeGeneral::NodeData(D));
711 return mrf->AddNode(
typename TypeTruncatedLinear::LocalSize(),
typename TypeTruncatedLinear::NodeData(D));
716 return mrf->AddNode(
typename TypeTruncatedQuadratic::LocalSize(),
typename TypeTruncatedQuadratic::NodeData(D));
720 template<
class VISITOR,
class ENERGYTYPE>
723 options.m_iterMax = 1;
724 options.m_printIter = 2 * parameter_.numberOfIterations_;
725 visitor.begin(*
this);
728 if(parameter_.doBPS_) {
729 typename ENERGYTYPE::REAL v;
730 for(
size_t i = 0; i < parameter_.numberOfIterations_; ++i) {
731 mrf->Minimize_BP(options, v, minMarginals_);
738 typename ENERGYTYPE::REAL v;
739 typename ENERGYTYPE::REAL b;
740 typename ENERGYTYPE::REAL d;
741 for(
size_t i = 0; i < parameter_.numberOfIterations_; ++i) {
742 mrf->Minimize_TRW_S(options, b, v, minMarginals_);
749 if(fabs(value_ - lowerBound_) /
opengmMax(static_cast<double>(fabs(value_)), 1.0) < parameter_.tolerance_) {
752 if(d<parameter_.minDualChange_){
766 #endif // #ifndef OPENGM_EXTERNAL_TRWS_HXX
const GraphicalModelType & graphicalModel() const
bool calculateMinMarginals_
Calculate MinMarginals.
visitors::EmptyVisitor< TRWS< GM > > EmptyVisitorType
bool useZeroStart_
zero starting message
InferenceTermination marginal(const size_t variableIndex, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
T opengmMax(const T &x, const T &y)
InferenceTermination infer()
bool useRandomStart_
random starting message
void create(const hid_t &, const std::string &, ShapeIterator, ShapeIterator, CoordinateOrder)
Create and close an HDF5 dataset to store Marray data.
#define OPENGM_ASSERT(expression)
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
GM::ValueType bound() const
return a bound on the solution
EnergyType
possible energy types for TRWS
static MRFEnergy< ENERGYTYPE >::NodeId add(MRFEnergy< ENERGYTYPE > *mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL *D)
double minDualChange_
TRWS terminates if fabs(bound(t)-bound(t+1)) < minDualChange_.
GraphicalModelType::IndexType IndexType
GraphicalModelType::ValueType ValueType
static T ineutral()
inverse neutral element (with return)
Inference algorithm interface.
double tolerance_
TRWS terminates if fabs(value - bound) / max(fabs(value), 1) < tolerance_.
size_t numberOfIterations_
number of iterations
visitors::VerboseVisitor< TRWS< GM > > VerboseVisitorType
static T neutral()
neutral element (with return)
TRWS(const GraphicalModelType &gm, const Parameter para=Parameter())
GM::ValueType value() const
return the solution (value)
Minimization as a unary accumulation.
static void * create(typename GM::IndexType numLabels)
message passing (BPS, TRWS): [?]
static const size_t ContinueInf
bool doBPS_
use normal LBP
EnergyType energyType_
selected energy type
opengm::Minimizer AccumulationType
visitors::TimingVisitor< TRWS< GM > > TimingVisitorType
GraphicalModelType::IndependentFactorType IndependentFactorType