Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7573e4c
add proxsuite
PierreQuinton May 7, 2026
69e42a3
Translate from qpsolvers to proxsuite (done by opencode)
PierreQuinton May 7, 2026
952e617
remove `project_weights_vector`
PierreQuinton May 7, 2026
b1b8c8d
Change from qplayer to BatchQP (lower level)
PierreQuinton May 8, 2026
6dcd0f9
Merge branch 'main' into use-proxsuite
PierreQuinton May 10, 2026
8968715
Make proxsuite run QPs in parallel, put the cast of result to tensor …
PierreQuinton May 10, 2026
14e1f2d
Merge branch 'main' into use-proxsuite
ValerianRey May 11, 2026
edf5272
Merge branch 'main' into use-proxsuite
ValerianRey May 12, 2026
fcc19d3
refactor!: Add `DualConeProjector` (#678)
PierreQuinton May 14, 2026
5b37228
expose ProxsuiteProjector documentation.
PierreQuinton May 15, 2026
d546f91
test UPGrad and DualProj with ProxsuiteProjector
PierreQuinton May 15, 2026
3d866fd
Add back ProxsuiteProjector
PierreQuinton May 15, 2026
bcb04c0
Merge branch 'main' into use-proxsuite
PierreQuinton May 15, 2026
8f2bb91
Make `_project_weight_vector_batch` handle only numpy arrays.
PierreQuinton May 15, 2026
ab04850
readd qpsolvers in pyproject.toml
PierreQuinton May 15, 2026
88c4557
Add repr tests.
PierreQuinton May 15, 2026
b1c3a45
add initial guess (doens't solve permutation invaiance but is a good …
PierreQuinton May 17, 2026
b7b6c42
Decrease absolute precision (Which makes convergence faster).
PierreQuinton May 17, 2026
cc46744
Update precision of permutation invariance tests (it's not too bad gi…
PierreQuinton May 17, 2026
2254415
Merge branch 'main' into use-proxsuite
ValerianRey May 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/docs/linalg/dual_cone.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ Dual Cone Projectors
:members: __call__

.. autoclass:: torchjd.linalg.QuadprogProjector

.. autoclass:: torchjd.linalg.ProxsuiteProjector
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"qpsolvers>=1.0.1", # Does not work before 1.0.1
"proxsuite>=0.7.2",
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -101,6 +102,7 @@ plot = [
lower_bounds = [
"torch==2.3.0",
"numpy==1.21.2",
"proxsuite==0.7.2",
"quadprog==0.1.9",
"qpsolvers==1.0.1",
]
Expand Down
8 changes: 7 additions & 1 deletion src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ._dual_cone import DualConeProjector, QuadprogProjector, projector_or_default
from ._dual_cone import (
DualConeProjector,
ProxsuiteProjector,
QuadprogProjector,
projector_or_default,
)
from ._generalized_gramian import flatten, movedim, reshape
from ._gramian import compute_gramian, normalize, regularize
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
Expand All @@ -18,5 +23,6 @@
"movedim",
"DualConeProjector",
"QuadprogProjector",
"ProxsuiteProjector",
"projector_or_default",
]
74 changes: 74 additions & 0 deletions src/torchjd/_linalg/_dual_cone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from abc import ABC, abstractmethod

import numpy as np
import torch
from proxsuite import proxqp

Check failure on line 6 in src/torchjd/_linalg/_dual_cone.py

View workflow job for this annotation

GitHub Actions / Code quality (ty and pre-commit hooks)

ty (unresolved-import)

src/torchjd/_linalg/_dual_cone.py:6:6: unresolved-import: Cannot resolve imported module `proxsuite` info: Searched in the following paths during module resolution: info: 1. /home/runner/work/TorchJD/TorchJD/src (first-party code) info: 2. /home/runner/work/TorchJD/TorchJD (first-party code) info: 3. vendored://stdlib (stdlib typeshed stubs vendored by ty) info: 4. /home/runner/work/TorchJD/TorchJD/.venv/lib/python3.14/site-packages (site-packages) info: 5. /home/runner/work/TorchJD/TorchJD/.venv/lib64/python3.14/site-packages (site-packages) info: make sure your Python environment is properly configured: https://docs.astral.sh/ty/modules/#python-environment
from qpsolvers import solve_qp
from torch import Tensor

Expand Down Expand Up @@ -115,6 +117,78 @@
return w


class ProxsuiteProjector(DualConeProjector):
r"""
Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the
`proxsuite <https://github.com/Simple-Robotics/proxsuite>`_ QP solver.
"""

def __init__(self) -> None:
pass

def __repr__(self) -> str:
return "ProxsuiteProjector()"

def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor:
original_shape = U.shape
m = G.shape[0]
G_ = _to_array(G)
U_flat = _to_array(U.reshape(-1, m)) # [nBatch, m]

W = self._project_weight_vector_batch(U_flat, G_)

return torch.as_tensor(W, device=G.device, dtype=G.dtype).reshape(original_shape)

@torch.no_grad()
def _project_weight_vector_batch(self, U: np.ndarray, G: np.ndarray) -> np.ndarray:

n, m = U.shape

Q_np = G
p_np = np.zeros(m, dtype=np.float64)
C_np = -np.eye(m, dtype=np.float64)
lb_np = np.full(m, -1e20, dtype=np.float64)
ub_np = U

batch_qps = proxqp.dense.BatchQP()
default_rho = 5.0e-5

for i in range(n):
qp = batch_qps.init_qp_in_place(m, 0, m)
qp.settings.primal_infeasibility_solving = False
qp.settings.max_iter = 1000
qp.settings.max_iter_in = 100
qp.settings.default_rho = default_rho
qp.settings.refactor_rho_threshold = default_rho
qp.settings.eps_abs = 1e-6

u = -ub_np[i]

qp.init(
H=Q_np,
g=p_np,
A=None,
b=None,
C=C_np,
l=lb_np,
u=u,
rho=default_rho,
)

# Initial guess
qp.results.x = u.copy()
qp.results.z = np.maximum(0.0, Q_np @ u)

num_threads = max(1, (os.cpu_count() or 2) // 2)
proxqp.dense.solve_in_parallel(num_threads=num_threads, qps=batch_qps)

zhats_np = np.empty((n, m), dtype=np.float64)
for i in range(n):
zhats_np[i] = batch_qps.get(i).results.x

return zhats_np


def _to_array(tensor: Tensor) -> np.ndarray:
"""Transforms a tensor into a numpy array with float64 dtype."""

Expand Down
2 changes: 2 additions & 0 deletions src/torchjd/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torchjd._linalg import (
DualConeProjector,
Matrix,
ProxsuiteProjector,
PSDMatrix,
QuadprogProjector,
)
Expand All @@ -15,4 +16,5 @@
"Matrix",
"PSDMatrix",
"QuadprogProjector",
"ProxsuiteProjector",
]
28 changes: 22 additions & 6 deletions tests/unit/aggregation/test_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import Tensor
from utils.tensors import ones_

from torchjd._linalg import QuadprogProjector
from torchjd._linalg import ProxsuiteProjector, QuadprogProjector
from torchjd.aggregation import ConstantWeighting, DualProj

from ._asserts import (
Expand All @@ -15,10 +15,26 @@
)
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices

scaled_pairs = [(DualProj(), matrix) for matrix in scaled_matrices]
typical_pairs = [(DualProj(), matrix) for matrix in typical_matrices]
non_strong_pairs = [(DualProj(), matrix) for matrix in non_strong_matrices]
requires_grad_pairs = [(DualProj(), ones_(3, 5, requires_grad=True))]
projectors = [QuadprogProjector(), ProxsuiteProjector()]

scaled_pairs = [
(DualProj(projector=projector), matrix)
for matrix in scaled_matrices
for projector in projectors
]
typical_pairs = [
(DualProj(projector=projector), matrix)
for matrix in typical_matrices
for projector in projectors
]
non_strong_pairs = [
(DualProj(projector=projector), matrix)
for matrix in non_strong_matrices
for projector in projectors
]
requires_grad_pairs = [
(DualProj(projector=projector), ones_(3, 5, requires_grad=True)) for projector in projectors
]


@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
Expand All @@ -33,7 +49,7 @@ def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None:

@mark.parametrize(["aggregator", "matrix"], typical_pairs)
def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None:
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07)
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=6e-05, rtol=2e-07)


@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
Expand Down
24 changes: 18 additions & 6 deletions tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import Tensor
from utils.tensors import ones_

from torchjd._linalg import QuadprogProjector
from torchjd._linalg import ProxsuiteProjector, QuadprogProjector
from torchjd.aggregation import ConstantWeighting, UPGrad

from ._asserts import (
Expand All @@ -16,10 +16,22 @@
)
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices

scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices]
typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices]
non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices]
requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))]
projectors = [QuadprogProjector(), ProxsuiteProjector()]

scaled_pairs = [
(UPGrad(projector=projector), matrix) for matrix in scaled_matrices for projector in projectors
]
typical_pairs = [
(UPGrad(projector=projector), matrix) for matrix in typical_matrices for projector in projectors
]
non_strong_pairs = [
(UPGrad(projector=projector), matrix)
for matrix in non_strong_matrices
for projector in projectors
]
requires_grad_pairs = [
(UPGrad(projector=projector), ones_(3, 5, requires_grad=True)) for projector in projectors
]


@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
Expand All @@ -34,7 +46,7 @@ def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None:

@mark.parametrize(["aggregator", "matrix"], typical_pairs)
def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None:
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07)
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=7e-05, rtol=5e-07)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
Expand Down
24 changes: 22 additions & 2 deletions tests/unit/linalg/test_dual_cone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
from torch.testing import assert_close
from utils.tensors import rand_, randn_

from torchjd._linalg import DualConeProjector, PSDMatrix, QuadprogProjector, compute_gramian
from torchjd._linalg import (
DualConeProjector,
ProxsuiteProjector,
PSDMatrix,
QuadprogProjector,
compute_gramian,
)

projectors = [QuadprogProjector(reg_eps=0.0, norm_eps=0.0), ProxsuiteProjector()]

@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)])

@mark.parametrize("projector", projectors)
@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None:
r"""
Expand Down Expand Up @@ -130,6 +138,18 @@ def test_reg_eps_setter_rejects_negative() -> None:
projector.reg_eps = -1e-9


def test_quadprog_repr() -> None:
A = QuadprogProjector(norm_eps=0.001, reg_eps=0.01)
assert repr(A) == "QuadprogProjector(norm_eps=0.001, reg_eps=0.01)"
assert str(A) == "QuadprogProjector(norm_eps=0.001, reg_eps=0.01)"


def test_proxsuite_repr() -> None:
A = ProxsuiteProjector()
assert repr(A) == "ProxsuiteProjector()"
assert str(A) == "ProxsuiteProjector()"


def test_qp_solver_based_failure() -> None:
"""
Tests that `QPSolverBased._project_weight_vector` raises an error when the input G has too large
Expand Down
Loading