|
22 | 22 | #include <gtsam/base/utilities.h> |
23 | 23 | #include <gtsam/discrete/DecisionTree-inl.h> |
24 | 24 | #include <gtsam/discrete/DecisionTree.h> |
| 25 | +#include <gtsam/discrete/DiscreteValues.h> |
25 | 26 | #include <gtsam/hybrid/HybridFactor.h> |
26 | 27 | #include <gtsam/hybrid/HybridGaussianFactor.h> |
27 | 28 | #include <gtsam/hybrid/HybridGaussianProductFactor.h> |
@@ -193,16 +194,41 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( |
193 | 194 | } |
194 | 195 |
|
195 | 196 | /* *******************************************************************************/ |
196 | | -double HybridGaussianFactor::error(const HybridValues& values) const { |
| 197 | +double HybridGaussianFactor::error(const HybridValues& hybridValues) const { |
197 | 198 | // Directly index to get the component, no need to build the whole tree. |
198 | | - const GaussianFactorValuePair pair = factors_(values.discrete()); |
199 | | - return PotentiallyPrunedComponentError(pair, values.continuous()); |
| 199 | + const GaussianFactorValuePair pair = factors_(hybridValues.discrete()); |
| 200 | + return PotentiallyPrunedComponentError(pair, hybridValues.continuous()); |
200 | 201 | } |
201 | 202 |
|
202 | 203 | /* ************************************************************************ */ |
203 | 204 | std::shared_ptr<Factor> HybridGaussianFactor::restrict( |
204 | | - const DiscreteValues& assignment) const { |
205 | | - throw std::runtime_error("HybridGaussianFactor::restrict not implemented"); |
| 205 | + const DiscreteValues& assignment) const { |
| 206 | + FactorValuePairs restrictedTree = this->factors_; // Start with the original tree |
| 207 | + |
| 208 | + const DiscreteKeys& currentFactorDiscreteKeys = this->discreteKeys(); |
| 209 | + DiscreteKeys newFactorDiscreteKeys; // For the new, restricted factor |
| 210 | + |
| 211 | + // Iterate over the discrete keys of the current factor |
| 212 | + for (const DiscreteKey& discreteKey : currentFactorDiscreteKeys) { |
| 213 | + const Key& key = discreteKey.first; |
| 214 | + |
| 215 | + // Check if this key is specified in the assignment |
| 216 | + if (assignment.find(key) != assignment.end()) { |
| 217 | + // Key is in assignment: restrict the tree by choosing the branch |
| 218 | + restrictedTree = restrictedTree.choose(key, assignment.at(key)); |
| 219 | + // This key is now fixed, so it's not a discrete key for the new factor |
| 220 | + } |
| 221 | + else { |
| 222 | + // Key is not in assignment: it remains a discrete key for the new factor |
| 223 | + newFactorDiscreteKeys.push_back(discreteKey); |
| 224 | + } |
| 225 | + } |
| 226 | + |
| 227 | + // Create and return the new HybridGaussianFactor. |
| 228 | + // Its constructor will derive continuous keys from the GaussianFactor |
| 229 | + // shared_ptrs within the restrictedTree. |
| 230 | + return std::make_shared<HybridGaussianFactor>(newFactorDiscreteKeys, |
| 231 | + restrictedTree); |
206 | 232 | } |
207 | 233 |
|
208 | 234 | /* ************************************************************************ */ |
|
0 commit comments