Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ jobs:
run: cargo run --example portfolio

msrv:
name: Check MSRV (1.70)
name: Check MSRV (1.71)
runs-on: ubuntu-latest

steps:
Expand Down
44 changes: 36 additions & 8 deletions src/canon/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use nalgebra_sparse::CscMatrix;

use super::lin_expr::{LinExpr, QuadExpr};
use crate::expr::{Array, Expr, ExprId, IndexSpec, Shape, VariableBuilder};
use crate::sparse::{csc_repeat_rows, csc_to_dense, csc_vstack, dense_to_csc};
use crate::sparse::{csc_add, csc_repeat_rows, csc_to_dense, csc_vstack, dense_to_csc};

/// A cone constraint in standard form: Ax + b in K.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -903,15 +903,43 @@ impl CanonContext {

fn canonicalize_sum_squares_lin(&mut self, x: &LinExpr, for_objective: bool) -> CanonExpr {
if for_objective {
// For objective, use native QP: ||x||^2 = x' I x
// The (1/2) factor for Clarabel is handled in stuffing.rs
// For objective, use native QP: ||Ax + c||^2 = x'(A'A)x + 2c'Ax + ||c||^2
// stuffing.rs doubles P to account for Clarabel's (1/2)x'Px convention.
let vars = x.variables();
if vars.len() == 1 && x.constant.iter().all(|&v| v == 0.0) {
let var_id = vars[0];
let size = x.size();
let identity = CscMatrix::identity(size);
return CanonExpr::Quadratic(QuadExpr::quadratic(var_id, identity));
let c = &x.constant; // dense (m, 1) constant

let mut quad_coeffs = std::collections::HashMap::new();
let mut linear_coeffs = std::collections::HashMap::new();

for &var_i in &vars {
let ai = csc_to_dense(&x.coeffs[&var_i]); // (m, ni)
for &var_j in &vars {
let aj = csc_to_dense(&x.coeffs[&var_j]); // (m, nj)
// A_i' * A_j: (ni, m) * (m, nj) = (ni, nj)
let ai_t_aj = dense_to_csc(&(ai.transpose() * &aj));
quad_coeffs
.entry((var_i, var_j))
.and_modify(|existing| *existing = csc_add(existing, &ai_t_aj))
.or_insert(ai_t_aj);
}
// Linear term: 2 * c' * A_i → (1, ni) row coefficient
let q_col = ai.transpose() * c; // (ni, 1)
let q_row = dense_to_csc(&(q_col * 2.0).transpose()); // (1, ni)
linear_coeffs.insert(var_i, q_row);
}

// Constant: ||c||^2
let constant: f64 = c.iter().map(|v| v * v).sum();

return CanonExpr::Quadratic(QuadExpr {
quad_coeffs,
linear: LinExpr {
coeffs: linear_coeffs,
constant: DMatrix::zeros(1, 1),
shape: Shape::scalar(),
},
constant,
});
}

// SOC reformulation: ||x||^2 <= t iff SOC(sqrt(t), x)
Expand Down
107 changes: 107 additions & 0 deletions tests/sum_squares_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use cvxrust::prelude::*;

/// sum_squares(Ax - b) must be canonicalized as ||r||^2, not ||r||_2.
/// The old bug used a plain SOC `||r||_2 <= t`, which minimized the L2 norm
/// instead of its square, causing solution.value to be off by a square root.

#[test]
fn test_sum_squares_scalar_constrained() {
// minimize (x - 3)^2 s.t. x <= 1
// True optimum: x* = 1, obj* = (1 - 3)^2 = 4
// Old bug would report: |1 - 3| = 2
let x = variable(());
let residual = x.clone() - constant(3.0);

let sol = Problem::minimize(sum_squares(&residual))
.constraint(constraint!(x <= 1.0))
.solve()
.unwrap();

let reported = sol.value.unwrap();
let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap();
let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap();

assert!(
(reported - 4.0).abs() < 1e-4,
"objective should be 4.0, got {reported}"
);
assert!(
(reported - eval_sq).abs() < 1e-4,
"reported obj must equal sum_squares evaluated at solution"
);
// Sanity check: the L2 norm (sqrt(4) = 2) is clearly different from the correct answer
assert!((eval_norm - 2.0).abs() < 1e-4);
assert!(
(reported - eval_norm).abs() > 0.5,
"objective must not equal the L2 norm (old bug)"
);
}

#[test]
fn test_sum_squares_vector_constrained() {
// minimize ||x - [3, 4]||^2 s.t. x <= 2
// True optimum: x* = [2, 2], obj* = (2-3)^2 + (2-4)^2 = 5
// Old bug would report: sqrt(5) ≈ 2.236
let x = variable(2);
let residual = x.clone() - constant_vec(vec![3.0, 4.0]);

let sol = Problem::minimize(sum_squares(&residual))
.constraint(constraint!(x <= 2.0))
.solve()
.unwrap();

let reported = sol.value.unwrap();
let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap();
let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap();

assert!(
(reported - 5.0).abs() < 1e-4,
"objective should be 5.0, got {reported}"
);
assert!(
(reported - eval_sq).abs() < 1e-4,
"reported obj must equal sum_squares evaluated at solution"
);
// Sanity check: the L2 norm (sqrt(5) ≈ 2.236) is clearly different
assert!((eval_norm - 5f64.sqrt()).abs() < 1e-4);
assert!(
(reported - eval_norm).abs() > 0.5,
"objective must not equal the L2 norm (old bug)"
);
}

#[test]
fn test_sum_squares_matmul_constrained() {
// minimize ||Ax - b||^2 s.t. x <= 1
// A = [[1], [1]] (2x1), b = [2, 4], x scalar
// Unconstrained LS: x* = 3, obj* = (3-2)^2 + (3-4)^2 = 2
// With x <= 1: x* = 1, residual = [-1, -3], obj* = 1 + 9 = 10
// Old bug would report: sqrt(10) ≈ 3.162
let a = constant_matrix(vec![1.0, 1.0], 2, 1);
let b = constant_vec(vec![2.0, 4.0]);
let x = variable(());
let residual = matmul(&a, &x) - &b;

let sol = Problem::minimize(sum_squares(&residual))
.constraint(constraint!(x <= 1.0))
.solve()
.unwrap();

let reported = sol.value.unwrap();
let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap();
let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap();

assert!(
(reported - 10.0).abs() < 1e-4,
"objective should be 10.0, got {reported}"
);
assert!(
(reported - eval_sq).abs() < 1e-4,
"reported obj must equal sum_squares evaluated at solution"
);
assert!((eval_norm - 10f64.sqrt()).abs() < 1e-4);
assert!(
(reported - eval_norm).abs() > 0.5,
"objective must not equal the L2 norm (old bug)"
);
}