Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports
<!-- start installation -->
TorchJD can be installed directly with pip:
```bash
pip install torchjd
pip install "torchjd[quadprog_projector]"
```
<!-- end installation -->
Some aggregators may have additional dependencies. Please refer to the
This includes the dependencies required by UPGrad and DualProj. Some other aggregators may have
additional dependencies. Please refer to the
[installation documentation](https://torchjd.org/stable/installation) for them.

## Usage
Expand Down
16 changes: 8 additions & 8 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

Note that `torchjd` requires Python 3.10, 3.11, 3.12, 3.13 or 3.14 and `torch>=2.0`.

Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
when installing `torchjd`. To install them, you can use:
```
pip install "torchjd[cagrad]"
```
```
pip install "torchjd[nash_mtl]"
```
Some aggregators have additional dependencies that are not included by default when installing
`torchjd`. The following table lists the optional dependency groups and the aggregators they enable:

| Group | Classes | Dependencies | Install command |
|-------|---------|--------------|-----------------|
| `quadprog_projector` | QuadprogProjector (used in UPGrad and DualProj) | `numpy` (BSD-3-Clause), `quadprog` (GPL-2.0+), `qpsolvers` (LGPL-3.0) | `pip install "torchjd[quadprog_projector]"` |
| `cagrad` | CAGrad | `numpy` (BSD-3-Clause), `cvxpy` (Apache-2.0) | `pip install "torchjd[cagrad]"` |
| `nash_mtl` | NashMTL | `numpy` (BSD-3-Clause), `cvxpy` (Apache-2.0), `ecos` (GPL-3.0) | `pip install "torchjd[nash_mtl]"` |
Comment thread
ValerianRey marked this conversation as resolved.
Outdated

To install `torchjd` with all of its optional dependencies, you can also use:
```
Expand Down
16 changes: 12 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ authors = [
requires-python = ">=3.10"
dependencies = [
"torch>=2.3.0", # Problems before 2.4.0, especially with autogram.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"qpsolvers>=1.0.1", # Does not work before 1.0.1
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -93,11 +90,12 @@ test = [
]

plot = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"plotly[kaleido]>=5.19.0", # Recent version to avoid problems, could be relaxed
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
"matplotlib>=3.10.0", # Recent version to avoid problems, could be relaxed
]
# Dependency group allowing to easily resolve version of the core dependencies to the lower bound.
# Dependency group allowing to easily resolve version of the optional dependencies to the lower bound.
Comment thread
ValerianRey marked this conversation as resolved.
Outdated
lower_bounds = [
"torch==2.3.0",
"numpy==1.21.2",
Expand All @@ -106,14 +104,24 @@ lower_bounds = [
]

[project.optional-dependencies]
quadprog_projector = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"qpsolvers>=1.0.1", # Does not work before 1.0.1
]
nash_mtl = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"cvxpy>=1.3.0", # Could be relaxed
"ecos>=2.0.14", # Does not work before 2.0.14
]
cagrad = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
]
full = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"qpsolvers>=1.0.1", # Does not work before 1.0.1
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
"ecos>=2.0.14", # Does not work before 2.0.14
]
Expand Down
15 changes: 12 additions & 3 deletions src/torchjd/_linalg/_dual_cone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import contextlib
from abc import ABC, abstractmethod

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from torchjd._mixins import _WithOptionalDeps

with contextlib.suppress(ImportError):
import numpy as np
from qpsolvers import solve_qp

from ._gramian import normalize, regularize
from ._matrix import PSDMatrix

Expand Down Expand Up @@ -49,7 +54,7 @@ def projector_or_default(projector: DualConeProjector | None) -> DualConeProject
return projector


class QuadprogProjector(DualConeProjector):
class QuadprogProjector(_WithOptionalDeps, DualConeProjector):
r"""
Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the
`quadprog <https://github.com/quadprog/quadprog>`_ QP solver.
Expand All @@ -61,12 +66,16 @@ class QuadprogProjector(DualConeProjector):
ensures that it is positive definite.
"""

_REQUIRED_DEPS = ["numpy", "qpsolvers", "quadprog"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[quadprog_projector]"'

def __init__(
self,
*,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
) -> None:
super().__init__()
self._norm_eps = norm_eps
self._reg_eps = reg_eps

Expand Down
28 changes: 28 additions & 0 deletions src/torchjd/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from importlib.util import find_spec
from typing import Any


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
are not installed.

Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via
:func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message).

.. warning::
This mixin must appear **first** in the inheritance list so that its :meth:`__init__`
runs before any base class that uses the optional dependencies.
"""

_REQUIRED_DEPS: list[str]
_INSTALL_HINT: str

def __init__(self, *args: Any, **kwargs: Any) -> None:
missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None]
if len(missing) != 0:
raise ImportError(

Check warning on line 24 in src/torchjd/_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/_mixins.py#L24

Added line #L24 was not covered by tests
f"{self.__class__.__name__} requires {missing} to be installed. "
f"{self._INSTALL_HINT}"
)
super().__init__(*args, **kwargs)
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from torch import Tensor

from torchjd._linalg import normalize
from torchjd._mixins import _WithOptionalDeps
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mixins import _NonDifferentiable, _WithOptionalDeps
from ._mixins import _NonDifferentiable
from ._weighting_bases import _GramianWeighting

with contextlib.suppress(ImportError):
Expand All @@ -18,7 +19,7 @@

# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
_REQUIRED_DEPS = ["cvxpy", "clarabel"]
_REQUIRED_DEPS = ["numpy", "cvxpy", "clarabel"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"'
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
Expand Down
27 changes: 0 additions & 27 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any

import torch
from torch import nn


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
are not installed.

Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via
:func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message).

.. warning::
This mixin must appear **first** in the inheritance list so that its :meth:`__init__`
runs before any base class that uses the optional dependencies.
"""

_REQUIRED_DEPS: list[str]
_INSTALL_HINT: str

def __init__(self, *args: Any, **kwargs: Any) -> None:
missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None]
if len(missing) != 0:
raise ImportError(
f"{self.__class__.__name__} requires {missing} to be installed. "
f"{self._INSTALL_HINT}"
)
super().__init__(*args, **kwargs)


class Stateful(ABC):
"""Mixin adding a reset method."""

Expand Down
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable, _WithOptionalDeps
from torchjd._mixins import _WithOptionalDeps
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import _MatrixWeighting
Expand All @@ -21,7 +22,7 @@

# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting):
_REQUIRED_DEPS = ["cvxpy", "ecos"]
_REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
:class:`~torchjd.aggregation._mixins.Stateful`
Expand Down
Loading