1 #ifndef OPENGM_EXTERNAL_SRMP_HXX_
2 #define OPENGM_EXTERNAL_SRMP_HXX_
10 #include <srmp/SRMP.h>
11 #include <srmp/FactorTypes/PottsType.h>
12 #include <srmp/FactorTypes/GeneralType.h>
30 struct Parameter :
public srmpLib::Energy::Options {
61 std::string
name()
const;
65 template<
class VISITOR>
68 typename GM::ValueType
bound()
const;
69 typename GM::ValueType
value()
const;
// Reference to the graphical model this solver instance was built from.
71 const GraphicalModelType& gm_;
// Options object that is filled from parameter_ and passed to srmpSolver_.Solve().
78 srmpLib::Energy::Options srmpOptions_;
// The SRMP energy object that receives the nodes and factors and is solved.
79 srmpLib::Energy srmpSolver_;
// Factor-type objects created with new; the entries of both lists are deleted
// in the clean-up loops further down in this file.
81 std::vector<srmpLib::PottsFactorType*> pottsFactorList_;
82 std::vector<srmpLib::GeneralFactorType*> generalFactorList_;
// Helpers that transfer the factor with index FactorID from gm_ to srmpSolver_,
// one helper per kind of factor.
84 void addUnaryFactor(
const IndexType FactorID);
85 void addPairwiseFactor(
const IndexType FactorID);
86 void addPottsFactor(
const IndexType FactorID);
87 void addGeneralFactor(
const IndexType FactorID);
// Constructor: initialises the members, mirrors the graphical model into the
// SRMP solver and copies the user parameters into the solver options.
// The solver is constructed with the number of variables of the model.
101 : gm_(gm), parameter_(para), constTerm_(0.0), value_(), lowerBound_(),
102 srmpOptions_(), srmpSolver_(gm_.numberOfVariables()), pottsFactorList_(),
103 generalFactorList_() {
// Add one solver node per model variable, sized by that variable's label count.
105 for(
IndexType i = 0; i < gm_.numberOfVariables(); ++i) {
106 srmpSolver_.AddNode(gm_.numberOfLabels(i));
// Visit every factor and treat it according to its number of variables.
110 for(
IndexType i = 0; i < gm_.numberOfFactors(); ++i) {
111 if(gm_[i].numberOfVariables() == 0) {
// Factor without variables: its value is accumulated in constTerm_, which is
// added to lowerBound_ where the bound is returned.
114 constTerm_ += gm_[i](&l);
115 }
else if(gm_[i].numberOfVariables() == 1) {
// NOTE(review): the body of this branch is not visible here; presumably it
// calls addUnaryFactor(i) - confirm.
118 }
else if(gm_[i].numberOfVariables() == 2) {
// Pairwise factor: check whether both variables have the same number of labels
// and the factor reports itself as Potts.
119 if(gm_[i].numberOfLabels(0) == gm_[i].numberOfLabels(1) && gm_[i].isPotts()) {
// NOTE(review): the lines between the Potts test and this call are not visible;
// presumably this is the non-Potts fallback - confirm.
126 addPairwiseFactor(i);
// Copy the solver options one by one from parameter_ into srmpOptions_, the
// object that is later handed to srmpSolver_.Solve().
136 srmpOptions_.method = parameter_.method;
137 srmpOptions_.iter_max = parameter_.iter_max;
138 srmpOptions_.time_max = parameter_.time_max;
139 srmpOptions_.eps = parameter_.eps;
140 srmpOptions_.compute_solution_period = parameter_.compute_solution_period;
141 srmpOptions_.print_times = parameter_.print_times;
142 srmpOptions_.sort_flag = parameter_.sort_flag;
143 srmpOptions_.verbose = parameter_.verbose;
144 srmpOptions_.TRWS_weighting = parameter_.TRWS_weighting;
// Release every Potts factor-type object stored in pottsFactorList_
// (addPottsFactor allocates them with new and pushes them into the list).
153 for(
size_t i = 0; i < pottsFactorList_.size(); ++i) {
154 delete pottsFactorList_[i];
// Release every object stored in generalFactorList_.
// NOTE(review): the push_back into generalFactorList_ is not visible in this
// listing; presumably addGeneralFactor registers its factor object - confirm.
156 for(
size_t i = 0; i < generalFactorList_.size(); ++i) {
157 delete generalFactorList_[i];
// Run inference with a default-constructed EmptyVisitorType by forwarding to
// the visitor-taking infer overload and returning its result.
173 EmptyVisitorType visitor;
174 return this->infer(visitor);
// Visitor-taking inference routine (signature and closing lines are not
// visible in this listing).
178 template<
class VISITOR>
// Notify the visitor that inference starts on this object.
180 visitor.begin(*
this);
// Choose the relaxation. The flags are tested in the order BLPRelaxation_,
// FullRelaxation_, FullDualRelaxation_, and only the first one that is set
// triggers its edge set-up call on the solver.
182 if (parameter_.BLPRelaxation_) {
183 srmpSolver_.SetMinimalEdges();
184 }
else if (parameter_.FullRelaxation_) {
185 srmpSolver_.SetFullEdges(parameter_.FullRelaxationMethod_);
186 }
else if (parameter_.FullDualRelaxation_) {
187 srmpSolver_.SetFullEdgesDual(parameter_.FullDualRelaxationMethod_);
// Solve with the configured options; the value returned by Solve() is stored
// in lowerBound_.
191 lowerBound_ = srmpSolver_.Solve(srmpOptions_);
192 std::vector<LabelType> l;
// value_ is set to the model's evaluation of labeling l.
// NOTE(review): the line that fills l is not visible here; presumably it is
// obtained through arg(l) - confirm.
194 value_ = gm_.evaluate(l);
// Resize the output to one entry per model variable, then copy the label the
// solver reports for each variable into it.
206 arg.resize(gm_.numberOfVariables());
207 for(
IndexType i = 0; i < gm_.numberOfVariables(); ++i) {
208 arg[i] = srmpSolver_.GetSolution(i);
// The reported value is the lower bound stored from Solve() plus the constant
// term accumulated in constTerm_ while the factors were transferred.
216 return lowerBound_ + constTerm_;
// Body of addUnaryFactor (signature not visible in this listing).
// Temporary buffer with one entry per label of the factor's first variable.
227 double* values =
new double[gm_[FactorID].numberOfLabels(0)];
// Evaluate the factor once per label and store the result converted to double.
// NOTE(review): the declaration of `label` and its assignment from the loop
// index are not visible here - confirm that label follows i.
229 for(
LabelType i = 0; i < gm_[FactorID].numberOfLabels(0); ++i) {
230 values[i] =
static_cast<double>(gm_[FactorID](&label));
// Hand the values to the solver for the node given by the factor's variable index.
// NOTE(review): the release of `values` is not visible in this listing - confirm
// that the buffer is freed after this call.
233 srmpSolver_.AddUnaryFactor(static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(0)), values);
// Transfers the second-order factor with index FactorID to the solver as a
// pairwise factor given by its full value table.
238 inline void SRMP<GM>::addPairwiseFactor(
const IndexType FactorID) {
// Buffer for the whole table: numberOfLabels(0) * numberOfLabels(1) entries.
239 double* values =
new double[gm_[FactorID].numberOfLabels(0) * gm_[FactorID].numberOfLabels(1)];
// Outer loop: labels of the factor's first variable.
241 for(
LabelType i = 0; i < gm_[FactorID].numberOfLabels(0); ++i) {
// Inner loop: labels of the factor's second variable.
243 for(
LabelType j = 0; j < gm_[FactorID].numberOfLabels(1); ++j) {
// The table is laid out so that j varies fastest: entry (i, j) is stored at
// i * numberOfLabels(1) + j.
// NOTE(review): the declaration of `labeling` and its assignment from i and j
// are not visible in this listing - confirm.
245 values[(i * gm_[FactorID].numberOfLabels(1)) + j] = static_cast<double>(gm_[FactorID](labeling));
// Register the table with the solver for the two nodes of the factor.
// NOTE(review): the release of `values` is not visible in this listing - confirm.
248 srmpSolver_.AddPairwiseFactor(static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(0)), static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(1)), values);
// Transfers the second-order factor with index FactorID to the solver as a
// Potts factor described by a single parameter.
253 inline void SRMP<GM>::addPottsFactor(
const IndexType FactorID) {
254 ValueType valueEqual;
255 ValueType valueNotEqual;
// Factor value at the initial `labeling`.
// NOTE(review): the declaration and initialisation of `labeling` are not
// visible; presumably both labels are equal at this point - confirm.
258 valueEqual = gm_[FactorID](labeling);
// Look at both variables of the factor and consider those with more than one label.
// NOTE(review): the body of the if is not visible; presumably it changes one
// label so that the two labels differ - confirm.
259 for(IndexType j = 0; j < 2; ++j) {
260 if(gm_[FactorID].numberOfLabels(j) > 1) {
// Factor value at `labeling` after the loop above.
265 valueNotEqual = gm_[FactorID](labeling);
// The factor-type object is allocated here and stored in pottsFactorList_,
// whose entries are deleted in the clean-up loop over that list.
267 srmpLib::PottsFactorType* pottsFactor =
new srmpLib::PottsFactorType;
268 pottsFactorList_.push_back(pottsFactor);
// Only the difference between the two values is passed to the solver;
// valueEqual itself is moved into the constant term.
271 double lambda = valueNotEqual - valueEqual;
272 constTerm_ += valueEqual;
// Node ids of the factor's two variables.
274 srmpLib::Energy::NodeId nodes[2] = {
static_cast<srmpLib::Energy::NodeId
>(gm_[FactorID].variableIndex(0)), static_cast<srmpLib::Energy::NodeId>(gm_[FactorID].variableIndex(1))};
// Register an arity-2 factor whose data is the single value lambda.
276 srmpSolver_.AddFactor(2, nodes, &lambda, pottsFactor);
// Transfers the factor with index FactorID to the solver as a general factor
// given by its full value table.
280 inline void SRMP<GM>::addGeneralFactor(
const IndexType FactorID) {
// Buffer with one entry per entry of the factor (gm_[FactorID].size()).
281 double* values =
new double[gm_[FactorID].size()];
// Walker over the factor's shape, used to produce the coordinate at which the
// factor is evaluated for each buffer position.
283 ShapeWalkerSwitchedOrder<typename FactorType::ShapeIteratorType> shapeWalker(gm_[FactorID].shapeBegin(), gm_[FactorID].dimension());
// NOTE(review): the statement advancing shapeWalker is not visible in this
// listing - confirm that it is incremented once per iteration.
284 for(
size_t i = 0; i < gm_[FactorID].size(); ++i) {
285 values[i] = gm_[FactorID](shapeWalker.coordinateTuple().begin());
// Node ids of all variables of the factor, converted to the solver's NodeId type.
289 srmpLib::Energy::NodeId* nodes =
new srmpLib::Energy::NodeId[gm_[FactorID].numberOfVariables()];
290 for(IndexType i = 0; i < gm_[FactorID].numberOfVariables(); ++i) {
291 nodes[i] =
static_cast<srmpLib::Energy::NodeId
>(gm_[FactorID].variableIndex(i));
// NOTE(review): the line storing generalFactor in generalFactorList_ and the
// release of `values` and `nodes` are not visible in this listing - confirm.
294 srmpLib::GeneralFactorType* generalFactor =
new srmpLib::GeneralFactorType;
// Register the factor with its arity, node ids and value table.
296 srmpSolver_.AddFactor(gm_[FactorID].numberOfVariables(), nodes, values, generalFactor);
InferenceTermination arg(std::vector< LabelType > &arg, const size_t &n=1) const
const GraphicalModelType & graphicalModel() const
int FullDualRelaxationMethod_
opengm::Minimizer AccumulationType
visitors::VerboseVisitor< SRMP< GM > > VerboseVisitorType
int FullRelaxationMethod_
visitors::EmptyVisitor< SRMP< GM > > EmptyVisitorType
GM::ValueType bound() const
Returns a bound on the solution.
GraphicalModelType::IndexType IndexType
GraphicalModelType::ValueType ValueType
static T ineutral()
inverse neutral element (with return)
Inference algorithm interface.
SRMP(const GraphicalModelType &gm, const Parameter para=Parameter())
visitors::TimingVisitor< SRMP< GM > > TimingVisitorType
static T neutral()
neutral element (with return)
GraphicalModelType::LabelType LabelType
Minimization as a unary accumulation.
InferenceTermination infer()
GM::ValueType value() const
Returns the solution (value).