Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 10 additions & 118 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,25 @@

import copy
import math
import os

from contextlib import suppress
from typing import Any

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.formatters_handler import rerank_knowledge_mem
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.dream.contextualization import CONTEXT_MEMORY_TYPE
from memos.log import get_logger
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
cosine_similarity_matrix,
)
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
from memos.plugins.hooks import hookable
from memos.plugins.hook_defs import H
from memos.plugins.hooks import hookable, trigger_hook


logger = get_logger(__name__)

_ENV_CONTEXT_RECALL = "MEMOS_DREAM_CONTEXT_RECALL"
_ENV_CONTEXT_RECALL_TOP_K = "MEMOS_DREAM_CONTEXT_RECALL_TOP_K"
_DEFAULT_CONTEXT_RECALL_TOP_K = 2


def _env_enabled(name: str, default: str = "off") -> bool:
return os.getenv(name, default).strip().lower() not in {"0", "false", "no", "off"}


def _env_int(name: str, default: int) -> int:
with suppress(TypeError, ValueError):
return int(os.getenv(name, str(default)))
return default


class SearchHandler(BaseHandler):
"""
Expand Down Expand Up @@ -88,7 +72,14 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
# Search and deduplicate
cube_view = self._build_cube_view(search_req_local)
results = cube_view.search_memories(search_req_local)
self._merge_context_recall(results=results, search_req=search_req_local)
hooked_results = trigger_hook(
H.SEARCH_MEMORY_RESULTS,
handler=self,
search_req=search_req_local,
results=results,
)
if hooked_results is not None:
results = hooked_results
if not search_req_local.relativity:
search_req_local.relativity = 0
self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}")
Expand Down Expand Up @@ -120,105 +111,6 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
data=results,
)

def _merge_context_recall(
self, *, results: dict[str, Any], search_req: APISearchRequest
) -> None:
if not _env_enabled(_ENV_CONTEXT_RECALL, "off"):
return

top_k = max(0, _env_int(_ENV_CONTEXT_RECALL_TOP_K, _DEFAULT_CONTEXT_RECALL_TOP_K))
if top_k <= 0:
return

context_buckets = self._recall_context_buckets(search_req=search_req, top_k=top_k)
if not context_buckets:
return

results.setdefault("text_mem", []).extend(context_buckets)

def _recall_context_buckets(
self, *, search_req: APISearchRequest, top_k: int
) -> list[dict[str, Any]]:
graph_db = self.graph_db or getattr(self.searcher, "graph_store", None)
embedder = self.embedder or getattr(self.searcher, "embedder", None)
if graph_db is None or embedder is None:
self.logger.info(
"[SearchHandler] Context recall skipped: graph_db or embedder unavailable."
)
return []

try:
query_embedding = embedder.embed([search_req.query])[0]
except Exception:
self.logger.warning("[SearchHandler] Context recall embedding failed.", exc_info=True)
return []

buckets: list[dict[str, Any]] = []
for cube_id in self._resolve_cube_ids(search_req):
try:
hits = graph_db.search_by_embedding(
query_embedding,
top_k=top_k,
scope=CONTEXT_MEMORY_TYPE,
status="activated",
user_name=cube_id,
return_fields=[
"memory",
"key",
"created_at",
"updated_at",
"source",
"internal_info",
],
)
except Exception:
self.logger.warning(
"[SearchHandler] Context recall search failed for cube=%s.",
cube_id,
exc_info=True,
)
continue

memories = [self._format_context_hit(hit) for hit in hits or [] if hit.get("memory")]
if not memories:
continue
buckets.append(
{
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
)
return buckets

@staticmethod
def _format_context_hit(hit: dict[str, Any]) -> dict[str, Any]:
context_id = str(hit.get("id", ""))
score = float(hit.get("score", 0.0) or 0.0)
metadata = {
"id": context_id,
"memory": hit.get("memory", ""),
"memory_type": CONTEXT_MEMORY_TYPE,
"source": hit.get("source") or "dream",
"key": hit.get("key", ""),
"relativity": score,
"score": score,
"embedding": [],
"sources": [],
"usage": [],
"ref_id": f"[{context_id.split('-')[0]}]" if context_id else "[context]",
}
for field in ("created_at", "updated_at", "internal_info"):
if hit.get(field) is not None:
metadata[field] = hit[field]

return {
"id": context_id,
"memory": hit.get("memory", ""),
"metadata": metadata,
"ref_id": metadata["ref_id"],
}

@staticmethod
def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]:
if relativity <= 0:
Expand Down
3 changes: 3 additions & 0 deletions src/memos/dream/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from memos.dream.routers.diary_router import create_diary_router
from memos.dream.routers.trigger_router import create_trigger_router
from memos.dream.search import DreamContextSearchExtension
from memos.dream.signal_store import DreamSignalStore
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import MEM_DREAM_TASK_LABEL
Expand Down Expand Up @@ -50,6 +51,7 @@ def on_load(self) -> None:
self.context: dict[str, Any] = {"shared": {}, "configs": {}}
self.signal_store = DreamSignalStore()
self.heuristic_enricher = DreamHeuristicEnricher()
self.search_extension = DreamContextSearchExtension()
self.pipeline = AbstractDreamPipeline(
context_strategy=DreamContextualizer(),
motive_strategy=MotiveFormation(),
Expand All @@ -62,6 +64,7 @@ def on_load(self) -> None:
# Hook registration happens at load time because scheduler-triggered Dream
# execution does not depend on FastAPI route binding.
self.register_hook(H.DREAM_EXECUTE, partial(on_dream_execute, self))
self.register_hook(H.SEARCH_MEMORY_RESULTS, self.search_extension.merge_context_recall)
self.register_hook(H.ADD_AFTER, partial(on_add_signal, self))
self.register_hook(
H.MEMORY_ITEMS_AFTER_FINE_EXTRACT,
Expand Down
139 changes: 139 additions & 0 deletions src/memos/dream/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

import logging

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from memos.dream.contextualization import CONTEXT_MEMORY_TYPE


if TYPE_CHECKING:
from memos.api.product_models import APISearchRequest


logger = logging.getLogger(__name__)

_DEFAULT_CONTEXT_RECALL_TOP_K = 2
_CONTEXT_RETURN_FIELDS = [
"memory",
"key",
"created_at",
"updated_at",
"source",
"internal_info",
]


@dataclass
class DreamContextSearchExtension:
"""Dream-owned search extension for recalling Context nodes.

The core SearchHandler only exposes a generic plugin hook. This extension
owns Dream-specific retrieval details such as the Context memory type,
graph scope, metadata formatting, and fallback behavior.
"""

top_k: int = _DEFAULT_CONTEXT_RECALL_TOP_K

def merge_context_recall(
self,
*,
handler,
search_req: APISearchRequest,
results: dict[str, Any],
) -> dict[str, Any]:
top_k = max(0, int(self.top_k or 0))
if top_k <= 0:
return results

context_buckets = self._recall_context_buckets(
handler=handler,
search_req=search_req,
top_k=top_k,
)
if context_buckets:
results.setdefault("text_mem", []).extend(context_buckets)
return results

def _recall_context_buckets(
self, *, handler, search_req: APISearchRequest, top_k: int
) -> list[dict[str, Any]]:
graph_db = getattr(handler, "graph_db", None) or getattr(
handler.searcher, "graph_store", None
)
embedder = getattr(handler, "embedder", None) or getattr(handler.searcher, "embedder", None)
if graph_db is None or embedder is None:
logger.info("[Dream Search] Context recall skipped: graph_db or embedder unavailable.")
return []

try:
query_embedding = embedder.embed([search_req.query])[0]
except Exception:
logger.warning("[Dream Search] Context recall embedding failed.", exc_info=True)
return []

buckets: list[dict[str, Any]] = []
for cube_id in _resolve_cube_ids(search_req):
try:
hits = graph_db.search_by_embedding(
query_embedding,
top_k=top_k,
scope=CONTEXT_MEMORY_TYPE,
status="activated",
user_name=cube_id,
return_fields=_CONTEXT_RETURN_FIELDS,
)
except Exception:
logger.warning(
"[Dream Search] Context recall search failed for cube=%s.",
cube_id,
exc_info=True,
)
continue

memories = [_format_context_hit(hit) for hit in hits or [] if hit.get("memory")]
if not memories:
continue
buckets.append(
{
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
)
return buckets


def _resolve_cube_ids(search_req: APISearchRequest) -> list[str]:
if search_req.readable_cube_ids:
return list(dict.fromkeys(search_req.readable_cube_ids))
return [search_req.user_id]


def _format_context_hit(hit: dict[str, Any]) -> dict[str, Any]:
context_id = str(hit.get("id", ""))
score = float(hit.get("score", 0.0) or 0.0)
metadata = {
"id": context_id,
"memory": hit.get("memory", ""),
"memory_type": CONTEXT_MEMORY_TYPE,
"source": hit.get("source") or "dream",
"key": hit.get("key", ""),
"relativity": score,
"score": score,
"embedding": [],
"sources": [],
"usage": [],
"ref_id": f"[{context_id.split('-')[0]}]" if context_id else "[context]",
}
for field in ("created_at", "updated_at", "internal_info"):
if hit.get(field) is not None:
metadata[field] = hit[field]

return {
"id": context_id,
"memory": hit.get("memory", ""),
"metadata": metadata,
"ref_id": metadata["ref_id"],
}
13 changes: 13 additions & 0 deletions src/memos/plugins/hook_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class H:
SEARCH_BEFORE = "search.before"
SEARCH_AFTER = "search.after"

# Search extension point before core threshold/dedup/rerank processing.
SEARCH_MEMORY_RESULTS = "search.memory_results"

# Custom Hook (manually triggered via trigger_hook)
ADD_MEMORIES_POST_PROCESS = "add.memories.post_process"

Expand Down Expand Up @@ -106,6 +109,16 @@ class H:
pipe_key="prompt",
)

define_hook(
H.SEARCH_MEMORY_RESULTS,
description=(
"Allow plugins to merge additional search result buckets before core "
"threshold, deduplication, and reranking."
),
params=["handler", "search_req", "results"],
pipe_key="results",
)

define_hook(
H.MEMORY_ITEMS_AFTER_FINE_EXTRACT,
description="Post-process memory items after mem_reader fine extraction completes",
Expand Down
19 changes: 19 additions & 0 deletions src/memos/plugins/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import importlib.metadata
import logging
import os

from typing import TYPE_CHECKING

Expand All @@ -29,6 +30,17 @@ def __init__(self):
def plugins(self) -> dict[str, MemOSPlugin]:
return dict(self._plugins)

@staticmethod
def _parse_plugin_names(value: str | None) -> set[str]:
if not value:
return set()
return {item.strip() for item in value.split(",") if item.strip()}

@classmethod
def _is_plugin_enabled(cls, plugin: MemOSPlugin) -> bool:
disabled = cls._parse_plugin_names(os.getenv("MEMOS_DISABLED_PLUGINS"))
return plugin.name not in disabled

@staticmethod
def _select_plugin_winners(
candidates: list[tuple[str, MemOSPlugin]],
Expand Down Expand Up @@ -107,6 +119,13 @@ def discover(self) -> None:

winners = self._select_plugin_winners(candidates)
for plugin_name, plugin in winners.items():
if not self._is_plugin_enabled(plugin):
logger.info(
"Plugin discovered but disabled: %s v%s (MEMOS_DISABLED_PLUGINS)",
plugin.name,
plugin.version,
)
continue
plugin.on_load()
self._plugins[plugin_name] = plugin
logger.info(
Expand Down
Loading
Loading