diff --git a/async_lru/__init__.py b/async_lru/__init__.py index 447e9cdb..22313bdd 100644 --- a/async_lru/__init__.py +++ b/async_lru/__init__.py @@ -63,6 +63,38 @@ def cancel(self) -> None: self.later_call = None +@final +class _DoneCallback: + def __init__(self, wrapper: "_LRUCacheWrapper[_R]", fut: "asyncio.Future[_R]", key: Hashable) -> None: + self.wrapper = wrapper + self.fut = fut + self.key = key + + def __call__(self, task: "asyncio.Task[_R]") -> None: + wrapper = self.wrapper + wrapper.__tasks.discard(task) + + if task.cancelled(): + fut.cancel() + wrapper.__cache.pop(key, None) + return + + exc = task.exception() + if exc is not None: + fut.set_exception(exc) + wrapper.__cache.pop(key, None) + return + + cache_item = wrapper.__cache.get(key) + if wrapper.__ttl is not None and cache_item is not None: + loop = asyncio.get_running_loop() + cache_item.later_call = loop.call_later( + wrapper.__ttl, wrapper.__cache.pop, key, None + ) + + fut.set_result(task.result()) + + @final class _LRUCacheWrapper(Generic[_R]): def __init__( @@ -167,31 +199,6 @@ def _cache_hit(self, key: Hashable) -> None: def _cache_miss(self, key: Hashable) -> None: self.__misses += 1 - def _task_done_callback( - self, fut: "asyncio.Future[_R]", key: Hashable, task: "asyncio.Task[_R]" - ) -> None: - self.__tasks.discard(task) - - if task.cancelled(): - fut.cancel() - self.__cache.pop(key, None) - return - - exc = task.exception() - if exc is not None: - fut.set_exception(exc) - self.__cache.pop(key, None) - return - - cache_item = self.__cache.get(key) - if self.__ttl is not None and cache_item is not None: - loop = asyncio.get_running_loop() - cache_item.later_call = loop.call_later( - self.__ttl, self.__cache.pop, key, None - ) - - fut.set_result(task.result()) - async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: if self.__closed: raise RuntimeError(f"alru_cache is closed for {self}") @@ -213,7 +220,7 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: coro = self.__wrapped__(*fn_args, **fn_kwargs) task: asyncio.Task[_R] = loop.create_task(coro) self.__tasks.add(task) - task.add_done_callback(partial(self._task_done_callback, fut, key)) + task.add_done_callback(_DoneCallback(self, fut, key)) self.__cache[key] = _CacheItem(fut, None) diff --git a/benchmark.py b/benchmark.py index b01e0e8a..830021d2 100644 --- a/benchmark.py +++ b/benchmark.py @@ -4,7 +4,7 @@ import pytest -from async_lru import _LRUCacheWrapper, alru_cache +from async_lru import _DoneCallback, _LRUCacheWrapper, alru_cache try: @@ -306,10 +306,9 @@ async def dummy_coro(): iterations = range(1000) create_future = loop.create_future - callback_fn = func._task_done_callback @benchmark def run() -> None: for i in iterations: - callback = partial(callback_fn, create_future(), i) + callback = _DoneCallback(func, create_future(), i) callback(task) diff --git a/tests/test_internals.py b/tests/test_internals.py index e5a055c4..5232f98f 100644 --- a/tests/test_internals.py +++ b/tests/test_internals.py @@ -1,5 +1,4 @@ import asyncio -from functools import partial from unittest import mock import pytest @@ -15,7 +14,7 @@ async def test_done_callback_cancelled() -> None: key = 1 - task.add_done_callback(partial(wrapped._task_done_callback, fut, key)) + task.add_done_callback(wrapped._get_done_callback(fut, key)) wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined] task.cancel() @@ -33,7 +32,7 @@ async def test_done_callback_exception() -> None: key = 1 - task.add_done_callback(partial(wrapped._task_done_callback, fut, key)) + task.add_done_callback(wrapped._get_done_callback(fut, key)) wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined] exc = ZeroDivisionError() @@ -59,7 +58,7 @@ async def test_done_callback() -> None: key = 1 fut = loop.create_future() - task.add_done_callback(partial(wrapped._task_done_callback, fut, key)) + task.add_done_callback(wrapped._get_done_callback(fut, key)) wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined] task.set_result(1)