Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Changed

- **BREAKING**: Removed `numpy`, `quadprog` and `qpsolvers` from the main dependencies of `torchjd`,
(which now only has `torch` as its main dependency). This makes the base version of `torchjd`
(installed with `pip install torchjd`) much lighter, but it means that users of `UPGrad` and
`DualProj` now have to install the new optional dependency group `quadprog_projector` explicitly
(with e.g. `pip install "torchjd[quadprog_projector]"`).

## [0.11.0] - 2026-05-18

### Changed
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports
<!-- start installation -->
TorchJD can be installed directly with pip:
```bash
pip install torchjd
pip install "torchjd[quadprog_projector]"
```
<!-- end installation -->
Some aggregators may have additional dependencies. Please refer to the
This includes the dependencies required by UPGrad and DualProj. Some other aggregators may have
additional dependencies. Please refer to the
[installation documentation](https://torchjd.org/stable/installation) for them.

## Usage
Expand Down
16 changes: 8 additions & 8 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

Note that `torchjd` requires Python 3.10, 3.11, 3.12, 3.13 or 3.14 and `torch>=2.0`.

Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
when installing `torchjd`. To install them, you can use:
```
pip install "torchjd[cagrad]"
```
```
pip install "torchjd[nash_mtl]"
```
Some aggregators have additional dependencies that are not included by default when installing
`torchjd`. The following table lists the optional dependency groups and the aggregators they enable:

Group | Classes | Dependencies | Install command |
|-----|---------|--------------|-----------------|
| `quadprog_projector` | {class}`~torchjd.linalg.QuadprogProjector` (used in {class}`~torchjd.aggregation.UPGrad` and {class}`~torchjd.aggregation.DualProj`) | [numpy](https://github.com/numpy/numpy), [quadprog](https://github.com/quadprog/quadprog), [qpsolvers](https://github.com/qpsolvers/qpsolvers) | `pip install "torchjd[quadprog_projector]"` |
| `cagrad` | {class}`~torchjd.aggregation.CAGrad` | [numpy](https://github.com/numpy/numpy), [cvxpy](https://github.com/cvxpy/cvxpy/) | `pip install "torchjd[cagrad]"` |
| `nash_mtl` | {class}`~torchjd.aggregation.NashMTL` | [numpy](https://github.com/numpy/numpy), [cvxpy](https://github.com/cvxpy/cvxpy/), [ecos](https://github.com/embotech/ecos) | `pip install "torchjd[nash_mtl]"` |

To install `torchjd` with all of its optional dependencies, you can also use:
```
Expand Down
17 changes: 13 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ authors = [
requires-python = ">=3.10"
dependencies = [
"torch>=2.3.0", # Problems before 2.4.0, especially with autogram.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"qpsolvers>=1.0.1", # Does not work before 1.0.1
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -93,11 +90,13 @@ test = [
]

plot = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"plotly[kaleido]>=5.19.0", # Recent version to avoid problems, could be relaxed
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
"matplotlib>=3.10.0", # Recent version to avoid problems, could be relaxed
]
# Dependency group allowing to easily resolve version of the core dependencies to the lower bound.
# Dependency group allowing to easily resolve version of the recommended dependencies to the lower
# bound.
lower_bounds = [
"torch==2.3.0",
"numpy==1.21.2",
Expand All @@ -106,14 +105,24 @@ lower_bounds = [
]

[project.optional-dependencies]
quadprog_projector = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"qpsolvers>=1.0.1", # Does not work before 1.0.1
]
nash_mtl = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"cvxpy>=1.3.0", # Could be relaxed
"ecos>=2.0.14", # Does not work before 2.0.14
]
cagrad = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
]
full = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"qpsolvers>=1.0.1", # Does not work before 1.0.1
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
"ecos>=2.0.14", # Does not work before 2.0.14
]
Expand Down
15 changes: 12 additions & 3 deletions src/torchjd/_linalg/_dual_cone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import contextlib
from abc import ABC, abstractmethod

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from torchjd._mixins import _WithOptionalDeps

with contextlib.suppress(ImportError):
import numpy as np
from qpsolvers import solve_qp

from ._gramian import normalize, regularize
from ._matrix import PSDMatrix

Expand Down Expand Up @@ -49,7 +54,7 @@ def projector_or_default(projector: DualConeProjector | None) -> DualConeProject
return projector


class QuadprogProjector(DualConeProjector):
class QuadprogProjector(_WithOptionalDeps, DualConeProjector):
r"""
Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the
`quadprog <https://github.com/quadprog/quadprog>`_ QP solver.
Expand All @@ -61,12 +66,16 @@ class QuadprogProjector(DualConeProjector):
ensures that it is positive definite.
"""

_REQUIRED_DEPS = ["numpy", "qpsolvers", "quadprog"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[quadprog_projector]"'

def __init__(
self,
*,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
) -> None:
super().__init__()
self._norm_eps = norm_eps
self._reg_eps = reg_eps

Expand Down
28 changes: 28 additions & 0 deletions src/torchjd/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from importlib.util import find_spec
from typing import Any


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
are not installed.

Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via
:func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message).

.. warning::
This mixin must appear **first** in the inheritance list so that its :meth:`__init__`
runs before any base class that uses the optional dependencies.
"""

_REQUIRED_DEPS: list[str]
_INSTALL_HINT: str

def __init__(self, *args: Any, **kwargs: Any) -> None:
missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None]
if len(missing) != 0:
raise ImportError(
f"{self.__class__.__name__} requires {missing} to be installed. "
f"{self._INSTALL_HINT}"
)
super().__init__(*args, **kwargs)
7 changes: 4 additions & 3 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import contextlib
from typing import cast

import numpy as np
import torch
from torch import Tensor

from torchjd._linalg import normalize
from torchjd._mixins import _WithOptionalDeps
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mixins import _NonDifferentiable, _WithOptionalDeps
from ._mixins import _NonDifferentiable
from ._weighting_bases import _GramianWeighting

with contextlib.suppress(ImportError):
import cvxpy as cp
import numpy as np


# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
_REQUIRED_DEPS = ["cvxpy", "clarabel"]
_REQUIRED_DEPS = ["numpy", "cvxpy", "clarabel"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"'
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
Expand Down
27 changes: 0 additions & 27 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any

import torch
from torch import nn


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
are not installed.

Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via
:func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message).

.. warning::
This mixin must appear **first** in the inheritance list so that its :meth:`__init__`
runs before any base class that uses the optional dependencies.
"""

_REQUIRED_DEPS: list[str]
_INSTALL_HINT: str

def __init__(self, *args: Any, **kwargs: Any) -> None:
missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None]
if len(missing) != 0:
raise ImportError(
f"{self.__class__.__name__} requires {missing} to be installed. "
f"{self._INSTALL_HINT}"
)
super().__init__(*args, **kwargs)


class Stateful(ABC):
"""Mixin adding a reset method."""

Expand Down
7 changes: 4 additions & 3 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@

import contextlib

import numpy as np
import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable, _WithOptionalDeps
from torchjd._mixins import _WithOptionalDeps
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import _MatrixWeighting

with contextlib.suppress(ImportError):
import cvxpy as cp
import numpy as np
from cvxpy import Expression, SolverError


# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting):
_REQUIRED_DEPS = ["cvxpy", "ecos"]
_REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
:class:`~torchjd.aggregation._mixins.Stateful`
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from utils.optional_deps import skip_if_deps_not_installed

pytest.importorskip("cvxpy")
pytest.importorskip("clarabel")
from torchjd.aggregation import CAGradWeighting

skip_if_deps_not_installed(CAGradWeighting)

from contextlib import nullcontext as does_not_raise

Expand All @@ -11,7 +12,6 @@
from utils.tensors import ones_

from torchjd.aggregation import CAGrad
from torchjd.aggregation._cagrad import CAGradWeighting

from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
from ._inputs import scaled_matrices, typical_matrices
Expand Down
31 changes: 15 additions & 16 deletions tests/unit/aggregation/test_cr_mogm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from utils.optional_deps import base_weighting
from utils.tensors import randn_, tensor_

from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting
from torchjd.aggregation import GradVacWeighting, MeanWeighting
from torchjd.aggregation._aggregator_bases import (
GramianWeightedAggregator,
WeightedAggregator,
Expand All @@ -15,12 +16,10 @@

# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in
# scaled_matrices, so the gramian-path structural test only uses typical_matrices.
matrix_pairs = [
(WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m)
for m in typical_matrices + scaled_matrices
]
matrix_pairs = [(WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m) for m in typical_matrices]
gramian_pairs = [
(GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices
(GramianWeightedAggregator(CRMOGMWeighting(base_weighting())), m)
for m in typical_matrices + scaled_matrices
]


Expand All @@ -40,14 +39,14 @@ def test_expected_structure_gramian_weighting(

def test_reset_restores_first_step_behavior() -> None:
"""
Use ``UPGradWeighting`` so the weights actually depend on the input — with
Use ``base_weighting`` so the weights actually depend on the input — with
``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would
be trivial.
"""

J = randn_((3, 8))
G = J @ J.T
W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5)
W = CRMOGMWeighting(base_weighting(), alpha=0.5)
first = W(G)
W(G)
W.reset()
Expand Down Expand Up @@ -105,8 +104,8 @@ def test_alpha_zero_reduces_to_bare_weighting() -> None:

J = randn_((3, 8))
G = J @ J.T
bare = UPGradWeighting()
smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0)
bare = base_weighting()
smoothed = CRMOGMWeighting(base_weighting(), alpha=0.0)

expected = bare(G)
assert_close(smoothed(G), expected)
Expand All @@ -122,7 +121,7 @@ def test_alpha_one_freezes_weights() -> None:

J = randn_((3, 8))
G = J @ J.T
W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0)
W = CRMOGMWeighting(base_weighting(), alpha=1.0)
first = W(G)

assert_close(W(G), first)
Expand All @@ -138,8 +137,8 @@ def test_ema_is_applied() -> None:
G1 = J1 @ J1.T
G2 = J2 @ J2.T

bare = UPGradWeighting()
smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha)
bare = base_weighting()
smoothed = CRMOGMWeighting(base_weighting(), alpha=alpha)

lambda_hat_1 = bare(G1)
lambda_hat_2 = bare(G2)
Expand All @@ -160,8 +159,8 @@ def test_initial_weights_used_as_lambda_0() -> None:
G = J @ J.T
initial = tensor_([0.5, 0.3, 0.2])

bare = UPGradWeighting()
W = CRMOGMWeighting(UPGradWeighting(), alpha=alpha, initial_weights=initial)
bare = base_weighting()
W = CRMOGMWeighting(base_weighting(), alpha=alpha, initial_weights=initial)

lambda_hat_1 = bare(G)
expected_1 = alpha * initial + (1.0 - alpha) * lambda_hat_1
Expand All @@ -177,7 +176,7 @@ def test_reset_restores_initial_weights() -> None:
G = J @ J.T
initial = tensor_([0.5, 0.3, 0.2])

W = CRMOGMWeighting(UPGradWeighting(), alpha=alpha, initial_weights=initial)
W = CRMOGMWeighting(base_weighting(), alpha=alpha, initial_weights=initial)
first = W(G)
W(G)
W.reset()
Expand Down
Loading
Loading