diff --git a/expression/collections/map.py b/expression/collections/map.py index 2383c9c..6c7171a 100644 --- a/expression/collections/map.py +++ b/expression/collections/map.py @@ -19,15 +19,16 @@ from collections.abc import Callable, ItemsView, Iterable, Iterator, Mapping from typing import Any, TypeVar, cast -from expression.core import Option, PipeMixin, SupportsLessThan, curry_flip, pipe +from expression.core import Option, PipeMixin, curry_flip, pipe +from expression.core.typing import SupportsLessThanAndHash from . import maptree, seq from .block import Block from .maptree import MapTree -_Key = TypeVar("_Key", bound=SupportsLessThan) -_Key_ = TypeVar("_Key_", bound=SupportsLessThan) +_Key = TypeVar("_Key", bound=SupportsLessThanAndHash) +_Key_ = TypeVar("_Key_", bound=SupportsLessThanAndHash) _Value = TypeVar("_Value") _Result = TypeVar("_Result") diff --git a/expression/collections/maptree.py b/expression/collections/maptree.py index efde6e1..2507171 100644 --- a/expression/collections/maptree.py +++ b/expression/collections/maptree.py @@ -23,17 +23,18 @@ """ import builtins -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Hashable, Iterable, Iterator from dataclasses import dataclass from typing import Any, Generic, TypeVar -from expression.core import Nothing, Option, Some, SupportsLessThan, failwith, pipe +from expression.core import Nothing, Option, Some, failwith, pipe +from expression.core.typing import SupportsLessThanAndHash from . import block, seq from .block import Block -Key = TypeVar("Key", bound=SupportsLessThan) +Key = TypeVar("Key", bound=SupportsLessThanAndHash) Value = TypeVar("Value") Result = TypeVar("Result") @@ -55,7 +56,7 @@ class MapTreeNode(MapTreeLeaf[Key, Value]): height: int -empty: MapTree[Any, Any] = Nothing +empty: MapTree[Hashable, Any] = Nothing def is_empty(m: MapTree[Any, Any]): @@ -72,7 +73,7 @@ def size_aux(acc: int, m: MapTree[Key, Value]) -> int: return acc -def size(x: MapTree[Any, Any]): +def size(x: MapTree[Hashable, Any]): return size_aux(0, x) diff --git a/expression/core/typing.py b/expression/core/typing.py index 71c30ec..0d9f607 100644 --- a/expression/core/typing.py +++ b/expression/core/typing.py @@ -17,6 +17,11 @@ def __lt__(self, __other: Any) -> bool: raise NotImplementedError +class SupportsLessThanAndHash(SupportsLessThan, Protocol): + def __hash__(self) -> int: + return super().__hash__() + + class SupportsSum(Protocol): @abstractmethod def __radd__(self, __other: Any) -> Any: