diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 9c3896f0f..7f4bad798 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -7,15 +7,12 @@ import copy import math -import os -from contextlib import suppress from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse -from memos.dream.contextualization import CONTEXT_MEMORY_TYPE from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( cosine_similarity_matrix, @@ -23,25 +20,12 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView -from memos.plugins.hooks import hookable +from memos.plugins.hook_defs import H +from memos.plugins.hooks import hookable, trigger_hook logger = get_logger(__name__) -_ENV_CONTEXT_RECALL = "MEMOS_DREAM_CONTEXT_RECALL" -_ENV_CONTEXT_RECALL_TOP_K = "MEMOS_DREAM_CONTEXT_RECALL_TOP_K" -_DEFAULT_CONTEXT_RECALL_TOP_K = 2 - - -def _env_enabled(name: str, default: str = "off") -> bool: - return os.getenv(name, default).strip().lower() not in {"0", "false", "no", "off"} - - -def _env_int(name: str, default: int) -> int: - with suppress(TypeError, ValueError): - return int(os.getenv(name, str(default))) - return default - class SearchHandler(BaseHandler): """ @@ -88,7 +72,14 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Search and deduplicate cube_view = self._build_cube_view(search_req_local) results = cube_view.search_memories(search_req_local) - self._merge_context_recall(results=results, search_req=search_req_local) + hooked_results = trigger_hook( + H.SEARCH_MEMORY_RESULTS, + handler=self, + search_req=search_req_local, + results=results, + ) + if hooked_results is not None: + results = hooked_results if not search_req_local.relativity: search_req_local.relativity = 0 self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}") @@ -120,105 +111,6 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse data=results, ) - def _merge_context_recall( - self, *, results: dict[str, Any], search_req: APISearchRequest - ) -> None: - if not _env_enabled(_ENV_CONTEXT_RECALL, "off"): - return - - top_k = max(0, _env_int(_ENV_CONTEXT_RECALL_TOP_K, _DEFAULT_CONTEXT_RECALL_TOP_K)) - if top_k <= 0: - return - - context_buckets = self._recall_context_buckets(search_req=search_req, top_k=top_k) - if not context_buckets: - return - - results.setdefault("text_mem", []).extend(context_buckets) - - def _recall_context_buckets( - self, *, search_req: APISearchRequest, top_k: int - ) -> list[dict[str, Any]]: - graph_db = self.graph_db or getattr(self.searcher, "graph_store", None) - embedder = self.embedder or getattr(self.searcher, "embedder", None) - if graph_db is None or embedder is None: - self.logger.info( - "[SearchHandler] Context recall skipped: graph_db or embedder unavailable." - ) - return [] - - try: - query_embedding = embedder.embed([search_req.query])[0] - except Exception: - self.logger.warning("[SearchHandler] Context recall embedding failed.", exc_info=True) - return [] - - buckets: list[dict[str, Any]] = [] - for cube_id in self._resolve_cube_ids(search_req): - try: - hits = graph_db.search_by_embedding( - query_embedding, - top_k=top_k, - scope=CONTEXT_MEMORY_TYPE, - status="activated", - user_name=cube_id, - return_fields=[ - "memory", - "key", - "created_at", - "updated_at", - "source", - "internal_info", - ], - ) - except Exception: - self.logger.warning( - "[SearchHandler] Context recall search failed for cube=%s.", - cube_id, - exc_info=True, - ) - continue - - memories = [self._format_context_hit(hit) for hit in hits or [] if hit.get("memory")] - if not memories: - continue - buckets.append( - { - "cube_id": cube_id, - "memories": memories, - "total_nodes": len(memories), - } - ) - return buckets - - @staticmethod - def _format_context_hit(hit: dict[str, Any]) -> dict[str, Any]: - context_id = str(hit.get("id", "")) - score = float(hit.get("score", 0.0) or 0.0) - metadata = { - "id": context_id, - "memory": hit.get("memory", ""), - "memory_type": CONTEXT_MEMORY_TYPE, - "source": hit.get("source") or "dream", - "key": hit.get("key", ""), - "relativity": score, - "score": score, - "embedding": [], - "sources": [], - "usage": [], - "ref_id": f"[{context_id.split('-')[0]}]" if context_id else "[context]", - } - for field in ("created_at", "updated_at", "internal_info"): - if hit.get(field) is not None: - metadata[field] = hit[field] - - return { - "id": context_id, - "memory": hit.get("memory", ""), - "metadata": metadata, - "ref_id": metadata["ref_id"], - } - @staticmethod def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]: if relativity <= 0: diff --git a/src/memos/dream/plugin.py b/src/memos/dream/plugin.py index 4b787f1d5..fbe098df7 100644 --- a/src/memos/dream/plugin.py +++ b/src/memos/dream/plugin.py @@ -22,6 +22,7 @@ ) from memos.dream.routers.diary_router import create_diary_router from memos.dream.routers.trigger_router import create_trigger_router +from memos.dream.search import DreamContextSearchExtension from memos.dream.signal_store import DreamSignalStore from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import MEM_DREAM_TASK_LABEL @@ -50,6 +51,7 @@ def on_load(self) -> None: self.context: dict[str, Any] = {"shared": {}, "configs": {}} self.signal_store = DreamSignalStore() self.heuristic_enricher = DreamHeuristicEnricher() + self.search_extension = DreamContextSearchExtension() self.pipeline = AbstractDreamPipeline( context_strategy=DreamContextualizer(), motive_strategy=MotiveFormation(), @@ -62,6 +64,7 @@ def on_load(self) -> None: # Hook registration happens at load time because scheduler-triggered Dream # execution does not depend on FastAPI route binding. self.register_hook(H.DREAM_EXECUTE, partial(on_dream_execute, self)) + self.register_hook(H.SEARCH_MEMORY_RESULTS, self.search_extension.merge_context_recall) self.register_hook(H.ADD_AFTER, partial(on_add_signal, self)) self.register_hook( H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, diff --git a/src/memos/dream/search.py b/src/memos/dream/search.py new file mode 100644 index 000000000..7e33b0381 --- /dev/null +++ b/src/memos/dream/search.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import logging + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from memos.dream.contextualization import CONTEXT_MEMORY_TYPE + + +if TYPE_CHECKING: + from memos.api.product_models import APISearchRequest + + +logger = logging.getLogger(__name__) + +_DEFAULT_CONTEXT_RECALL_TOP_K = 2 +_CONTEXT_RETURN_FIELDS = [ + "memory", + "key", + "created_at", + "updated_at", + "source", + "internal_info", +] + + +@dataclass +class DreamContextSearchExtension: + """Dream-owned search extension for recalling Context nodes. + + The core SearchHandler only exposes a generic plugin hook. This extension + owns Dream-specific retrieval details such as the Context memory type, + graph scope, metadata formatting, and fallback behavior. + """ + + top_k: int = _DEFAULT_CONTEXT_RECALL_TOP_K + + def merge_context_recall( + self, + *, + handler, + search_req: APISearchRequest, + results: dict[str, Any], + ) -> dict[str, Any]: + top_k = max(0, int(self.top_k or 0)) + if top_k <= 0: + return results + + context_buckets = self._recall_context_buckets( + handler=handler, + search_req=search_req, + top_k=top_k, + ) + if context_buckets: + results.setdefault("text_mem", []).extend(context_buckets) + return results + + def _recall_context_buckets( + self, *, handler, search_req: APISearchRequest, top_k: int + ) -> list[dict[str, Any]]: + graph_db = getattr(handler, "graph_db", None) or getattr( + handler.searcher, "graph_store", None + ) + embedder = getattr(handler, "embedder", None) or getattr(handler.searcher, "embedder", None) + if graph_db is None or embedder is None: + logger.info("[Dream Search] Context recall skipped: graph_db or embedder unavailable.") + return [] + + try: + query_embedding = embedder.embed([search_req.query])[0] + except Exception: + logger.warning("[Dream Search] Context recall embedding failed.", exc_info=True) + return [] + + buckets: list[dict[str, Any]] = [] + for cube_id in _resolve_cube_ids(search_req): + try: + hits = graph_db.search_by_embedding( + query_embedding, + top_k=top_k, + scope=CONTEXT_MEMORY_TYPE, + status="activated", + user_name=cube_id, + return_fields=_CONTEXT_RETURN_FIELDS, + ) + except Exception: + logger.warning( + "[Dream Search] Context recall search failed for cube=%s.", + cube_id, + exc_info=True, + ) + continue + + memories = [_format_context_hit(hit) for hit in hits or [] if hit.get("memory")] + if not memories: + continue + buckets.append( + { + "cube_id": cube_id, + "memories": memories, + "total_nodes": len(memories), + } + ) + return buckets + + +def _resolve_cube_ids(search_req: APISearchRequest) -> list[str]: + if search_req.readable_cube_ids: + return list(dict.fromkeys(search_req.readable_cube_ids)) + return [search_req.user_id] + + +def _format_context_hit(hit: dict[str, Any]) -> dict[str, Any]: + context_id = str(hit.get("id", "")) + score = float(hit.get("score", 0.0) or 0.0) + metadata = { + "id": context_id, + "memory": hit.get("memory", ""), + "memory_type": CONTEXT_MEMORY_TYPE, + "source": hit.get("source") or "dream", + "key": hit.get("key", ""), + "relativity": score, + "score": score, + "embedding": [], + "sources": [], + "usage": [], + "ref_id": f"[{context_id.split('-')[0]}]" if context_id else "[context]", + } + for field in ("created_at", "updated_at", "internal_info"): + if hit.get(field) is not None: + metadata[field] = hit[field] + + return { + "id": context_id, + "memory": hit.get("memory", ""), + "metadata": metadata, + "ref_id": metadata["ref_id"], + } diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py index 5ec73cc86..3650e60c0 100644 --- a/src/memos/plugins/hook_defs.py +++ b/src/memos/plugins/hook_defs.py @@ -72,6 +72,9 @@ class H: SEARCH_BEFORE = "search.before" SEARCH_AFTER = "search.after" + # Search extension point before core threshold/dedup/rerank processing. + SEARCH_MEMORY_RESULTS = "search.memory_results" + # Custom Hook (manually triggered via trigger_hook) ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" @@ -106,6 +109,16 @@ class H: pipe_key="prompt", ) +define_hook( + H.SEARCH_MEMORY_RESULTS, + description=( + "Allow plugins to merge additional search result buckets before core " + "threshold, deduplication, and reranking." + ), + params=["handler", "search_req", "results"], + pipe_key="results", +) + define_hook( H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, description="Post-process memory items after mem_reader fine extraction completes", diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py index fb473e0d9..a4478a22d 100644 --- a/src/memos/plugins/manager.py +++ b/src/memos/plugins/manager.py @@ -4,6 +4,7 @@ import importlib.metadata import logging +import os from typing import TYPE_CHECKING @@ -29,6 +30,17 @@ def __init__(self): def plugins(self) -> dict[str, MemOSPlugin]: return dict(self._plugins) + @staticmethod + def _parse_plugin_names(value: str | None) -> set[str]: + if not value: + return set() + return {item.strip() for item in value.split(",") if item.strip()} + + @classmethod + def _is_plugin_enabled(cls, plugin: MemOSPlugin) -> bool: + disabled = cls._parse_plugin_names(os.getenv("MEMOS_DISABLED_PLUGINS")) + return plugin.name not in disabled + @staticmethod def _select_plugin_winners( candidates: list[tuple[str, MemOSPlugin]], @@ -107,6 +119,13 @@ def discover(self) -> None: winners = self._select_plugin_winners(candidates) for plugin_name, plugin in winners.items(): + if not self._is_plugin_enabled(plugin): + logger.info( + "Plugin discovered but disabled: %s v%s (MEMOS_DISABLED_PLUGINS)", + plugin.name, + plugin.version, + ) + continue plugin.on_load() self._plugins[plugin_name] = plugin logger.info( diff --git a/tests/dream/test_context_recall.py b/tests/dream/test_context_recall.py index ca8e8c35c..1aafec833 100644 --- a/tests/dream/test_context_recall.py +++ b/tests/dream/test_context_recall.py @@ -3,6 +3,9 @@ from memos.api.handlers.base_handler import HandlerDependencies from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import APISearchRequest +from memos.dream.search import DreamContextSearchExtension +from memos.plugins.hook_defs import H +from memos.plugins.hooks import _hooks, register_hook class FakeEmbedder: @@ -76,8 +79,11 @@ def _search_req(): ) -def test_context_recall_disabled_by_default(monkeypatch): - monkeypatch.delenv("MEMOS_DREAM_CONTEXT_RECALL", raising=False) +def setup_function(): + _hooks.clear() + + +def test_context_recall_disabled_without_dream_search_hook(): graph = FakeGraphDB( hits=[ { @@ -95,9 +101,11 @@ def test_context_recall_disabled_by_default(monkeypatch): assert response.data["text_mem"] == [] -def test_context_recall_searches_context_scope_and_returns_summary(monkeypatch): - monkeypatch.setenv("MEMOS_DREAM_CONTEXT_RECALL", "on") - monkeypatch.setenv("MEMOS_DREAM_CONTEXT_RECALL_TOP_K", "1") +def test_context_recall_searches_context_scope_and_returns_summary(): + register_hook( + H.SEARCH_MEMORY_RESULTS, + DreamContextSearchExtension(top_k=1).merge_context_recall, + ) graph = FakeGraphDB( hits=[ { @@ -140,8 +148,11 @@ def test_context_recall_searches_context_scope_and_returns_summary(monkeypatch): assert memories[0]["metadata"]["internal_info"] == {"dream": {"memory_ids": ["m1", "m2"]}} -def test_context_recall_gracefully_skips_without_graph_db(monkeypatch): - monkeypatch.setenv("MEMOS_DREAM_CONTEXT_RECALL", "on") +def test_context_recall_gracefully_skips_without_graph_db(): + register_hook( + H.SEARCH_MEMORY_RESULTS, + DreamContextSearchExtension(top_k=1).merge_context_recall, + ) handler = _handler(graph_db=None) response = handler.handle_search_memories(_search_req()) @@ -149,8 +160,11 @@ def test_context_recall_gracefully_skips_without_graph_db(monkeypatch): assert response.data["text_mem"] == [] -def test_context_recall_gracefully_skips_on_embedding_failure(monkeypatch): - monkeypatch.setenv("MEMOS_DREAM_CONTEXT_RECALL", "on") +def test_context_recall_gracefully_skips_on_embedding_failure(): + register_hook( + H.SEARCH_MEMORY_RESULTS, + DreamContextSearchExtension(top_k=1).merge_context_recall, + ) graph = FakeGraphDB() handler = _handler(graph_db=graph, embedder=FailingEmbedder())