Skip to content

Commit 9db8d95

Browse files
authored
Merge pull request #2120 from borglab/docs-hybrid
2 parents 2e353e8 + 045a5f5 commit 9db8d95

39 files changed

+5490
-158
lines changed

gtsam/discrete/DiscreteValues.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,6 @@ string DiscreteValues::html(const KeyFormatter& keyFormatter,
152152
return ss.str();
153153
}
154154

155-
/* ************************************************************************ */
156-
void PrintDiscreteValues(const DiscreteValues& values, const std::string& s,
157-
const KeyFormatter& keyFormatter) {
158-
values.print(s, keyFormatter);
159-
}
160-
161155
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
162156
const DiscreteValues::Names& names) {
163157
return values.markdown(keyFormatter, names);

gtsam/discrete/DiscreteValues.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,6 @@ inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
197197
return DiscreteValues::CartesianProduct(keys);
198198
}
199199

200-
/// Free version of print for wrapper
201-
void GTSAM_EXPORT
202-
PrintDiscreteValues(const DiscreteValues& values, const std::string& s = "",
203-
const KeyFormatter& keyFormatter = DefaultKeyFormatter);
204-
205200
/// Free version of markdown.
206201
std::string GTSAM_EXPORT
207202
markdown(const DiscreteValues& values,

gtsam/discrete/discrete.i

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ class DiscreteKeys {
2323
std::vector<gtsam::DiscreteValues> cartesianProduct(
2424
const gtsam::DiscreteKeys& keys);
2525

26-
void PrintDiscreteValues(
27-
const gtsam::DiscreteValues& values, const std::string& s = "",
28-
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
29-
3026
string markdown(
3127
const gtsam::DiscreteValues& values,
3228
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);

gtsam/discrete/tests/testDiscreteBayesTree.cpp

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <gtsam/discrete/DiscreteBayesNet.h>
2222
#include <gtsam/discrete/DiscreteBayesTree.h>
2323
#include <gtsam/discrete/DiscreteFactorGraph.h>
24+
#include <gtsam/base/TestableAssertions.h>
2425

2526
#include <CppUnitLite/TestHarness.h>
2627

@@ -290,30 +291,31 @@ TEST(DiscreteBayesTree, Dot) {
290291
std::string actual = self.bayesTree->dot();
291292
// print actual:
292293
if (debug) std::cout << actual << std::endl;
293-
EXPECT(actual ==
294-
"digraph G{\n"
295-
"0[label=\"13, 11, 6, 7\"];\n"
296-
"0->1\n"
297-
"1[label=\"14 : 11, 13\"];\n"
298-
"1->2\n"
299-
"2[label=\"9, 12 : 14\"];\n"
300-
"2->3\n"
301-
"3[label=\"3 : 9, 12\"];\n"
302-
"2->4\n"
303-
"4[label=\"2 : 9, 12\"];\n"
304-
"2->5\n"
305-
"5[label=\"8 : 12, 14\"];\n"
306-
"5->6\n"
307-
"6[label=\"1 : 8, 12\"];\n"
308-
"5->7\n"
309-
"7[label=\"0 : 8, 12\"];\n"
310-
"1->8\n"
311-
"8[label=\"10 : 13, 14\"];\n"
312-
"8->9\n"
313-
"9[label=\"5 : 10, 13\"];\n"
314-
"8->10\n"
315-
"10[label=\"4 : 10, 13\"];\n"
316-
"}");
294+
std::string expected =
295+
R"(digraph G{
296+
13[label="13, 11, 6, 7"];
297+
13->14
298+
14[label="14 : 11, 13"];
299+
14->9
300+
9[label="9, 12 : 14"];
301+
9->3
302+
3[label="3 : 9, 12"];
303+
9->2
304+
2[label="2 : 9, 12"];
305+
9->8
306+
8[label="8 : 12, 14"];
307+
8->1
308+
1[label="1 : 8, 12"];
309+
8->0
310+
0[label="0 : 8, 12"];
311+
14->10
312+
10[label="10 : 13, 14"];
313+
10->5
314+
5[label="5 : 10, 13"];
315+
10->4
316+
4[label="4 : 10, 13"];
317+
})";
318+
EXPECT(assert_equal(expected, actual));
317319
}
318320

319321
/* ************************************************************************* */

gtsam/hybrid/HybridEliminationTree.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,22 @@ class GTSAM_EXPORT HybridEliminationTree
4343
/// @{
4444

4545
/**
46-
* Build the elimination tree of a factor graph using pre-computed column
46+
* Construct the elimination tree of a factor graph using pre-computed column
4747
* structure.
4848
* @param factorGraph The factor graph for which to build the elimination tree
4949
* @param structure The set of factors involving each variable. If this is
5050
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
5151
* named constructor instead.
52-
* @return The elimination tree
52+
* @param order The ordering of the variables.
5353
*/
5454
HybridEliminationTree(const HybridGaussianFactorGraph& factorGraph,
5555
const VariableIndex& structure, const Ordering& order);
5656

57-
/** Build the elimination tree of a factor graph. Note that this has to
57+
/** Construct the elimination tree of a factor graph. Note that this has to
5858
* compute the column structure as a VariableIndex, so if you already have
5959
* this precomputed, use the other constructor instead.
6060
* @param factorGraph The factor graph for which to build the elimination tree
61+
* @param order The ordering of the variables.
6162
*/
6263
HybridEliminationTree(const HybridGaussianFactorGraph& factorGraph,
6364
const Ordering& order);

gtsam/hybrid/HybridFactor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
135135

136136
/// Compute tree of linear errors.
137137
virtual AlgebraicDecisionTree<Key> errorTree(
138-
const VectorValues &values) const = 0;
138+
const VectorValues &continuousValues) const = 0;
139139

140140
/// Restrict the factor to the given discrete values.
141141
virtual std::shared_ptr<Factor> restrict(

gtsam/hybrid/HybridGaussianFactor.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <gtsam/base/utilities.h>
2323
#include <gtsam/discrete/DecisionTree-inl.h>
2424
#include <gtsam/discrete/DecisionTree.h>
25+
#include <gtsam/discrete/DiscreteValues.h>
2526
#include <gtsam/hybrid/HybridFactor.h>
2627
#include <gtsam/hybrid/HybridGaussianFactor.h>
2728
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
@@ -193,16 +194,41 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
193194
}
194195

195196
/* *******************************************************************************/
196-
double HybridGaussianFactor::error(const HybridValues& values) const {
197+
double HybridGaussianFactor::error(const HybridValues& hybridValues) const {
197198
// Directly index to get the component, no need to build the whole tree.
198-
const GaussianFactorValuePair pair = factors_(values.discrete());
199-
return PotentiallyPrunedComponentError(pair, values.continuous());
199+
const GaussianFactorValuePair pair = factors_(hybridValues.discrete());
200+
return PotentiallyPrunedComponentError(pair, hybridValues.continuous());
200201
}
201202

202203
/* ************************************************************************ */
203204
std::shared_ptr<Factor> HybridGaussianFactor::restrict(
204-
const DiscreteValues& assignment) const {
205-
throw std::runtime_error("HybridGaussianFactor::restrict not implemented");
205+
const DiscreteValues& assignment) const {
206+
FactorValuePairs restrictedTree = this->factors_; // Start with the original tree
207+
208+
const DiscreteKeys& currentFactorDiscreteKeys = this->discreteKeys();
209+
DiscreteKeys newFactorDiscreteKeys; // For the new, restricted factor
210+
211+
// Iterate over the discrete keys of the current factor
212+
for (const DiscreteKey& discreteKey : currentFactorDiscreteKeys) {
213+
const Key& key = discreteKey.first;
214+
215+
// Check if this key is specified in the assignment
216+
if (assignment.find(key) != assignment.end()) {
217+
// Key is in assignment: restrict the tree by choosing the branch
218+
restrictedTree = restrictedTree.choose(key, assignment.at(key));
219+
// This key is now fixed, so it's not a discrete key for the new factor
220+
}
221+
else {
222+
// Key is not in assignment: it remains a discrete key for the new factor
223+
newFactorDiscreteKeys.push_back(discreteKey);
224+
}
225+
}
226+
227+
// Create and return the new HybridGaussianFactor.
228+
// Its constructor will derive continuous keys from the GaussianFactor
229+
// shared_ptrs within the restrictedTree.
230+
return std::make_shared<HybridGaussianFactor>(newFactorDiscreteKeys,
231+
restrictedTree);
206232
}
207233

208234
/* ************************************************************************ */

gtsam/hybrid/HybridGaussianFactor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
144144
* @brief Compute the log-likelihood, including the log-normalizing constant.
145145
* @return double
146146
*/
147-
double error(const HybridValues &values) const override;
147+
double error(const HybridValues &hybridValues) const override;
148148

149149
/// Getter for GaussianFactor decision tree
150150
const FactorValuePairs &factors() const { return factors_; }

gtsam/hybrid/HybridJunctionTree.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,7 @@ class GTSAM_EXPORT HybridJunctionTree
5757
typedef HybridJunctionTree This; ///< This class
5858
typedef std::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
5959

60-
/**
61-
* Build the elimination tree of a factor graph using precomputed column
62-
* structure.
63-
* @param factorGraph The factor graph for which to build the elimination tree
64-
* @param structure The set of factors involving each variable. If this is
65-
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
66-
* named constructor instead.
67-
* @return The elimination tree
68-
*/
60+
/// Construct the junction tree from an elimination tree
6961
HybridJunctionTree(const HybridEliminationTree& eliminationTree);
7062
};
7163

gtsam/hybrid/HybridNonlinearFactor.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ double HybridNonlinearFactor::error(
117117
}
118118

119119
/* *******************************************************************************/
120-
double HybridNonlinearFactor::error(const HybridValues& values) const {
121-
return error(values.nonlinear(), values.discrete());
120+
double HybridNonlinearFactor::error(const HybridValues& hybridValues) const {
121+
return error(hybridValues.nonlinear(), hybridValues.discrete());
122122
}
123123

124124
/* *******************************************************************************/
@@ -138,6 +138,7 @@ void HybridNonlinearFactor::print(const std::string& s,
138138
auto [factor, val] = v;
139139
if (factor) {
140140
RedirectCout rd;
141+
std::cout << "(val=" << val << ") ";
141142
factor->print("", keyFormatter);
142143
return rd.str();
143144
} else {

0 commit comments

Comments
 (0)